Use a stateless decoder.

Fangjun Kuang 2021-12-28 10:49:50 +08:00
parent 2cf1b56cb3
commit ec083e93d8
5 changed files with 77 additions and 292 deletions

View File

@@ -78,7 +78,7 @@ class Decoder(nn.Module):
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with BOS prepended.
+            A 2-D tensor of shape (N, U) with blank prepended.
           states:
             A tuple of two tensors containing the states information of
             LSTM layers in this decoder.
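
This docstring fix reflects how the transducer recipes construct the decoder input: there is no separate BOS symbol; the blank ID itself is prepended as the start-of-sequence token. A minimal sketch with plain tensors (blank_id = 0 and the helper name are illustrative assumptions, not this repo's API):

import torch

def prepend_blank(y: torch.Tensor, blank_id: int = 0) -> torch.Tensor:
    """(N, U) labels -> (N, U+1) decoder input with blank prepended."""
    blanks = torch.full((y.size(0), 1), blank_id, dtype=y.dtype)
    return torch.cat([blanks, y], dim=1)

y = torch.tensor([[2, 5, 9]])   # (N=1, U=3) label sequence
print(prepend_blank(y))         # tensor([[0, 2, 5, 9]])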

View File

@@ -1,219 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import torch
from model import Transducer


def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
    """
    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
    Returns:
      Return the decoded result.
    """
    assert encoder_out.ndim == 3

    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)

    blank_id = model.decoder.blank_id
    device = model.device

    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
    decoder_out, (h, c) = model.decoder(sos)
    T = encoder_out.size(1)
    t = 0
    hyp = []

    sym_per_frame = 0
    sym_per_utt = 0

    max_sym_per_utt = 1000
    max_sym_per_frame = 3

    while t < T and sym_per_utt < max_sym_per_utt:
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :]
        # fmt: on
        logits = model.joiner(current_encoder_out, decoder_out)
        # logits is (1, 1, 1, vocab_size)

        log_prob = logits.log_softmax(dim=-1)
        # log_prob is (1, 1, 1, vocab_size)
        # TODO: Use logits.argmax()
        y = log_prob.argmax()
        if y != blank_id:
            hyp.append(y.item())
            y = y.reshape(1, 1)
            decoder_out, (h, c) = model.decoder(y, (h, c))

            sym_per_utt += 1
            sym_per_frame += 1

        if y == blank_id or sym_per_frame > max_sym_per_frame:
            sym_per_frame = 0
            t += 1

    return hyp


@dataclass
class Hypothesis:
    ys: List[int]  # the predicted sequences so far
    log_prob: float  # The log prob of ys

    # Optional decoder state. We assume it is LSTM for now,
    # so the state is a tuple (h, c)
    decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None


def beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 5,
) -> List[int]:
    """
    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf

    espnet/nets/beam_search_transducer.py#L247 is used as a reference.

    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
    Returns:
      Return the decoded result.
    """
    assert encoder_out.ndim == 3

    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)

    blank_id = model.decoder.blank_id
    device = model.device

    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
    decoder_out, (h, c) = model.decoder(sos)
    T = encoder_out.size(1)
    t = 0
    B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
    max_u = 20000  # terminate after this number of steps
    u = 0

    cache: Dict[
        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
    ] = {}

    while t < T and u < max_u:
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :]
        # fmt: on
        A = B
        B = []
        #  for hyp in A:
        #      for h in A:
        #          if h.ys == hyp.ys[:-1]:
        #              # update the score of hyp
        #              decoder_input = torch.tensor(
        #                  [h.ys[-1]], device=device
        #              ).reshape(1, 1)
        #              decoder_out, _ = model.decoder(
        #                  decoder_input, h.decoder_state
        #              )
        #              logits = model.joiner(current_encoder_out, decoder_out)
        #              log_prob = logits.log_softmax(dim=-1)
        #              log_prob = log_prob.squeeze()
        #              hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()

        while u < max_u:
            y_star = max(A, key=lambda hyp: hyp.log_prob)
            A.remove(y_star)

            # Note: y_star.ys is unhashable, i.e., cannot be used
            # as a key into a dict
            cached_key = "_".join(map(str, y_star.ys))

            if cached_key not in cache:
                decoder_input = torch.tensor(
                    [y_star.ys[-1]], device=device
                ).reshape(1, 1)

                decoder_out, decoder_state = model.decoder(
                    decoder_input,
                    y_star.decoder_state,
                )
                cache[cached_key] = (decoder_out, decoder_state)
            else:
                decoder_out, decoder_state = cache[cached_key]

            logits = model.joiner(current_encoder_out, decoder_out)
            log_prob = logits.log_softmax(dim=-1)
            # log_prob is (1, 1, 1, vocab_size)
            log_prob = log_prob.squeeze()
            # Now log_prob is (vocab_size,)

            # If we choose blank here, add the new hypothesis to B.
            # Otherwise, add the new hypothesis to A

            # First, choose blank
            skip_log_prob = log_prob[blank_id]
            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()

            # ys[:] returns a copy of ys
            new_y_star = Hypothesis(
                ys=y_star.ys[:],
                log_prob=new_y_star_log_prob,
                # Caution: Use y_star.decoder_state here
                decoder_state=y_star.decoder_state,
            )
            B.append(new_y_star)

            # Second, choose other labels
            for i, v in enumerate(log_prob.tolist()):
                if i == blank_id:
                    continue
                new_ys = y_star.ys + [i]
                new_log_prob = y_star.log_prob + v
                new_hyp = Hypothesis(
                    ys=new_ys,
                    log_prob=new_log_prob,
                    decoder_state=decoder_state,
                )
                A.append(new_hyp)

            u += 1
            # check whether B contains more than "beam" elements more probable
            # than the most probable in A
            A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
            B = sorted(
                [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
                key=lambda hyp: hyp.log_prob,
                reverse=True,
            )
            if len(B) >= beam:
                B = B[:beam]
                break
        t += 1

    best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
    ys = best_hyp.ys[1:]  # [1:] to remove the blank
    return ys
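
This whole file is deleted because every decoder call in it threads LSTM states (h, c) through the search loop, which a stateless decoder does not produce. Its replacement is the symlinked transducer_stateless/beam_search.py shown in the next file. As a rough illustration of how the greedy loop changes (a sketch assuming the new Decoder below with its need_pad and context_size arguments, not the verbatim contents of the symlinked file):

def greedy_search_stateless(
    model: Transducer, encoder_out: torch.Tensor
) -> List[int]:
    """A sketch of greedy search with the stateless decoder: instead of
    carrying (h, c), re-run the decoder on the last `context_size` tokens."""
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) == 1, encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = model.device

    # Start from `context_size` blanks instead of a single sos token.
    hyp = [blank_id] * context_size
    decoder_input = torch.tensor([hyp], device=device)  # (1, context_size)
    decoder_out = model.decoder(decoder_input, need_pad=False)

    T = encoder_out.size(1)
    t = 0
    sym_per_frame = 0
    max_sym_per_frame = 3

    while t < T:
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :]
        # fmt: on
        logits = model.joiner(current_encoder_out, decoder_out)
        # The "TODO: Use logits.argmax()" above is resolved for free here:
        # argmax over logits and over log_softmax(logits) pick the same symbol.
        y = logits.argmax().item()
        if y != blank_id and sym_per_frame < max_sym_per_frame:
            hyp.append(y)
            sym_per_frame += 1
            # No (h, c) to carry around: feed only the newest context.
            decoder_input = torch.tensor([hyp[-context_size:]], device=device)
            decoder_out = model.decoder(decoder_input, need_pad=False)
        else:
            sym_per_frame = 0
            t += 1

    return hyp[context_size:]  # strip the leading blanks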

View File

@@ -0,0 +1 @@
../transducer_stateless/beam_search.py

View File

@@ -114,6 +114,14 @@ def get_parser():
         help="Used only when --decoding-method is beam_search",
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
     return parser
@@ -124,14 +132,10 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "encoder_out_dim": 512,
             "subsampling_factor": 4,
-            "encoder_hidden_size": 1024,
-            "num_encoder_layers": 4,
+            "encoder_hidden_size": 2048,
+            "num_encoder_layers": 6,
             "proj_size": 512,
             "vgg_frontend": False,
-            # decoder params
-            "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
-            "decoder_hidden_dim": 512,
             "env_info": get_env_info(),
         }
     )
@@ -153,11 +157,9 @@ def get_encoder_model(params: AttributeDict):
 
 def get_decoder_model(params: AttributeDict):
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.decoder_embedding_dim,
+        embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
-        num_layers=params.num_decoder_layers,
-        hidden_dim=params.decoder_hidden_dim,
-        output_dim=params.encoder_out_dim,
+        context_size=params.context_size,
     )
     return decoder
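
The stateless decoder collapses the old five decoder hyperparameters into one: embedding_dim is tied to encoder_out_dim, and the only remaining knob is --context-size. A sketch of what get_decoder_model() builds with the defaults above (the vocab_size and blank_id values here are placeholders; in the recipe they come from the BPE vocabulary):

decoder = Decoder(
    vocab_size=500,      # placeholder; params.vocab_size comes from the BPE model
    embedding_dim=512,   # params.encoder_out_dim, per get_params() above
    blank_id=0,          # placeholder; params.blank_id
    context_size=2,      # --context-size default, i.e. a tri-gram decoder
)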

View File

@@ -14,24 +14,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple
-
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
-# TODO(fangjun): Support switching between LSTM and GRU
+
 class Decoder(nn.Module):
+    """This class modifies the stateless decoder from the following paper:
+
+        RNN-transducer with stateless prediction network
+        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
+
+    It removes the recurrent connection from the decoder, i.e., the prediction
+    network. Different from the above paper, it adds an extra Conv1d
+    right after the embedding layer.
+
+    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
+    """
+
     def __init__(
         self,
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
-        num_layers: int,
-        hidden_dim: int,
-        output_dim: int,
-        embedding_dropout: float = 0.0,
-        rnn_dropout: float = 0.0,
+        context_size: int,
     ):
         """
         Args:
@@ -41,16 +47,9 @@ class Decoder(nn.Module):
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
-          num_layers:
-            Number of LSTM layers.
-          hidden_dim:
-            Hidden dimension of LSTM layers.
-          output_dim:
-            Output dimension of the decoder.
-          embedding_dropout:
-            Dropout rate for the embedding layer.
-          rnn_dropout:
-            Dropout for LSTM layers.
+          context_size:
+            Number of previous words to use to predict the next word.
+            1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
         self.embedding = nn.Embedding(
@@ -58,40 +57,42 @@ class Decoder(nn.Module):
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
         )
-        self.embedding_dropout = nn.Dropout(embedding_dropout)
-        # TODO(fangjun): Use layer normalized LSTM
-        self.rnn = nn.LSTM(
-            input_size=embedding_dim,
-            hidden_size=hidden_dim,
-            num_layers=num_layers,
-            batch_first=True,
-            dropout=rnn_dropout,
-        )
         self.blank_id = blank_id
-        self.output_linear = nn.Linear(hidden_dim, output_dim)
 
-    def forward(
-        self,
-        y: torch.Tensor,
-        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        assert context_size >= 1, context_size
+        self.context_size = context_size
+        if context_size > 1:
+            self.conv = nn.Conv1d(
+                in_channels=embedding_dim,
+                out_channels=embedding_dim,
+                kernel_size=context_size,
+                padding=0,
+                groups=embedding_dim,
+                bias=False,
+            )
+
+    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with BOS prepended.
-          states:
-            A tuple of two tensors containing the states information of
-            LSTM layers in this decoder.
+            A 2-D tensor of shape (N, U) with blank prepended.
+          need_pad:
+            True to left pad the input. Should be True during training.
+            False to not pad the input. Should be False during inference.
         Returns:
-          Return a tuple containing:
-            - rnn_output, a tensor of shape (N, U, C)
-            - (h, c), containing the state information for LSTM layers.
-              Both are of shape (num_layers, N, C)
+          Return a tensor of shape (N, U, embedding_dim).
         """
         embeding_out = self.embedding(y)
-        embeding_out = self.embedding_dropout(embeding_out)
-        rnn_out, (h, c) = self.rnn(embeding_out, states)
-        out = self.output_linear(rnn_out)
-
-        return out, (h, c)
+        if self.context_size > 1:
+            embeding_out = embeding_out.permute(0, 2, 1)
+            if need_pad is True:
+                embeding_out = F.pad(
+                    embeding_out, pad=(self.context_size - 1, 0)
+                )
+            else:
+                # During inference time, there is no need to do extra padding
+                # as we only need one output
+                assert embeding_out.size(-1) == self.context_size
+            embeding_out = self.conv(embeding_out)
+            embeding_out = embeding_out.permute(0, 2, 1)
+        return embeding_out
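
To make the padding logic concrete, here is a shape walk-through of forward() using the Decoder defined above (a sketch; the sizes are arbitrary). Left-padding by context_size - 1 keeps the convolution causal during training so the output length matches the input; at inference only the last context_size tokens are fed, so no padding is needed and exactly one output frame is produced:

import torch

decoder = Decoder(vocab_size=500, embedding_dim=512, blank_id=0, context_size=2)

# Training: y is (N, U) with blank prepended.
# (8, 10) -> embedding (8, 10, 512) -> pad + conv -> (8, 10, 512)
y = torch.randint(low=1, high=500, size=(8, 10))
assert decoder(y, need_pad=True).shape == (8, 10, 512)

# Inference: feed exactly the last `context_size` tokens, no padding;
# (1, 2) -> embedding (1, 2, 512) -> conv -> (1, 1, 512)
y_last = torch.randint(low=1, high=500, size=(1, 2))
assert decoder(y_last, need_pad=False).shape == (1, 1, 512)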

View File

@@ -131,6 +131,14 @@ def get_parser():
         help="The lr_factor for Noam optimizer",
     )
 
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
     return parser
@@ -172,9 +180,6 @@ def get_params() -> AttributeDict:
 
         - subsampling_factor: The subsampling factor for the model.
 
-        - use_feat_batchnorm: Whether to do batch normalization for the
-          input features.
-
         - attention_dim: Hidden dim for multi-head attention model.
 
         - num_decoder_layers: Number of decoder layer of transformer decoder.
@@ -195,14 +200,10 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "encoder_out_dim": 512,
             "subsampling_factor": 4,
-            "encoder_hidden_size": 1024,
-            "num_encoder_layers": 4,
+            "encoder_hidden_size": 2048,
+            "num_encoder_layers": 6,
             "proj_size": 512,
             "vgg_frontend": False,
-            # decoder params
-            "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
-            "decoder_hidden_dim": 512,
             # parameters for Noam
             "warm_step": 80000,  # For the 100h subset, use 8k
             "env_info": get_env_info(),
@@ -227,12 +228,11 @@ def get_encoder_model(params: AttributeDict):
 
 def get_decoder_model(params: AttributeDict):
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.decoder_embedding_dim,
+        embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
-        num_layers=params.num_decoder_layers,
-        hidden_dim=params.decoder_hidden_dim,
-        output_dim=params.encoder_out_dim,
+        context_size=params.context_size,
     )
     return decoder
@@ -573,11 +573,11 @@ def run(rank, world_size, args):
     logging.info("About to create model")
     model = get_transducer_model(params)
 
-    checkpoints = load_checkpoint_if_available(params=params, model=model)
-
     num_param = sum([p.numel() for p in model.parameters() if p.requires_grad])
     logging.info(f"Number of model parameters: {num_param}")
 
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
     model.to(device)
     if world_size > 1:
         logging.info("Using DDP")