Use a stateless decoder for transducer_lstm.

Fangjun Kuang 2022-04-21 13:58:43 +08:00
parent 3607c516d6
commit 52b3ed2920
7 changed files with 140 additions and 384 deletions

egs/librispeech/ASR/README.md

@@ -14,7 +14,7 @@ The following table lists the differences among them.
 | `transducer` | Conformer | LSTM | |
 | `transducer_stateless` | Conformer | Embedding + Conv1d | Using optimized_transducer for computing RNN-T loss |
 | `transducer_stateless2` | Conformer | Embedding + Conv1d | Using torchaudio for computing RNN-T loss |
-| `transducer_lstm` | LSTM | LSTM | |
+| `transducer_lstm` | LSTM | Embedding + Conv1d | Using torchaudio for computing RNN-T loss |
 | `transducer_stateless_multi_datasets` | Conformer | Embedding + Conv1d | Using data from GigaSpeech as extra training data |
 | `pruned_transducer_stateless` | Conformer | Embedding + Conv1d | Using k2 pruned RNN-T loss |
 | `pruned_transducer_stateless2` | Conformer(modified) | Embedding + Conv1d | Using k2 pruned RNN-T loss |

egs/librispeech/ASR/transducer_lstm/beam_search.py

@@ -1,222 +0,0 @@
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
-
-import torch
-from model import Transducer
-
-
-def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]:
-    """
-    Args:
-      model:
-        An instance of `Transducer`.
-      encoder_out:
-        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
-    Returns:
-      Return the decoded result.
-    """
-    assert encoder_out.ndim == 3
-
-    # support only batch_size == 1 for now
-    assert encoder_out.size(0) == 1, encoder_out.size(0)
-    blank_id = model.decoder.blank_id
-    device = model.device
-
-    sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(
-        1, 1
-    )
-    decoder_out, (h, c) = model.decoder(sos)
-    T = encoder_out.size(1)
-    t = 0
-    hyp = []
-
-    sym_per_frame = 0
-    sym_per_utt = 0
-
-    max_sym_per_utt = 1000
-    max_sym_per_frame = 3
-
-    while t < T and sym_per_utt < max_sym_per_utt:
-        # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
-        # fmt: on
-        logits = model.joiner(current_encoder_out, decoder_out)
-        # logits is (1, 1, 1, vocab_size)
-
-        log_prob = logits.log_softmax(dim=-1)
-        # log_prob is (1, 1, 1, vocab_size)
-        # TODO: Use logits.argmax()
-        y = log_prob.argmax()
-        if y != blank_id:
-            hyp.append(y.item())
-            y = y.reshape(1, 1)
-            decoder_out, (h, c) = model.decoder(y, (h, c))
-
-            sym_per_utt += 1
-            sym_per_frame += 1
-
-        if y == blank_id or sym_per_frame > max_sym_per_frame:
-            sym_per_frame = 0
-            t += 1
-
-    return hyp
-
-
-@dataclass
-class Hypothesis:
-    ys: List[int]  # the predicted sequences so far
-    log_prob: float  # The log prob of ys
-
-    # Optional decoder state. We assume it is LSTM for now,
-    # so the state is a tuple (h, c)
-    decoder_state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
-
-
-def beam_search(
-    model: Transducer,
-    encoder_out: torch.Tensor,
-    beam: int = 5,
-) -> List[int]:
-    """
-    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
-
-    espnet/nets/beam_search_transducer.py#L247 is used as a reference.
-
-    Args:
-      model:
-        An instance of `Transducer`.
-      encoder_out:
-        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
-      beam:
-        Beam size.
-    Returns:
-      Return the decoded result.
-    """
-    assert encoder_out.ndim == 3
-
-    # support only batch_size == 1 for now
-    assert encoder_out.size(0) == 1, encoder_out.size(0)
-    blank_id = model.decoder.blank_id
-    sos_id = model.decoder.sos_id
-    device = model.device
-
-    sos = torch.tensor([blank_id], device=device).reshape(1, 1)
-    decoder_out, (h, c) = model.decoder(sos)
-    T = encoder_out.size(1)
-    t = 0
-    B = [Hypothesis(ys=[blank_id], log_prob=0.0, decoder_state=None)]
-    max_u = 20000  # terminate after this number of steps
-    u = 0
-
-    cache: Dict[
-        str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
-    ] = {}
-
-    while t < T and u < max_u:
-        # fmt: off
-        current_encoder_out = encoder_out[:, t:t+1, :]
-        # fmt: on
-        A = B
-        B = []
-        # for hyp in A:
-        #     for h in A:
-        #         if h.ys == hyp.ys[:-1]:
-        #             # update the score of hyp
-        #             decoder_input = torch.tensor(
-        #                 [h.ys[-1]], device=device
-        #             ).reshape(1, 1)
-        #             decoder_out, _ = model.decoder(
-        #                 decoder_input, h.decoder_state
-        #             )
-        #             logits = model.joiner(current_encoder_out, decoder_out)
-        #             log_prob = logits.log_softmax(dim=-1)
-        #             log_prob = log_prob.squeeze()
-        #             hyp.log_prob += h.log_prob + log_prob[hyp.ys[-1]].item()
-
-        while u < max_u:
-            y_star = max(A, key=lambda hyp: hyp.log_prob)
-            A.remove(y_star)
-
-            # Note: y_star.ys is unhashable, i.e., cannot be used
-            # as a key into a dict
-            cached_key = "_".join(map(str, y_star.ys))
-
-            if cached_key not in cache:
-                decoder_input = torch.tensor(
-                    [y_star.ys[-1]], device=device
-                ).reshape(1, 1)
-
-                decoder_out, decoder_state = model.decoder(
-                    decoder_input,
-                    y_star.decoder_state,
-                )
-                cache[cached_key] = (decoder_out, decoder_state)
-            else:
-                decoder_out, decoder_state = cache[cached_key]
-
-            logits = model.joiner(current_encoder_out, decoder_out)
-            log_prob = logits.log_softmax(dim=-1)
-            # log_prob is (1, 1, 1, vocab_size)
-            log_prob = log_prob.squeeze()
-            # Now log_prob is (vocab_size,)
-
-            # If we choose blank here, add the new hypothesis to B.
-            # Otherwise, add the new hypothesis to A
-
-            # First, choose blank
-            skip_log_prob = log_prob[blank_id]
-            new_y_star_log_prob = y_star.log_prob + skip_log_prob.item()
-
-            # ys[:] returns a copy of ys
-            new_y_star = Hypothesis(
-                ys=y_star.ys[:],
-                log_prob=new_y_star_log_prob,
-                # Caution: Use y_star.decoder_state here
-                decoder_state=y_star.decoder_state,
-            )
-            B.append(new_y_star)
-
-            # Second, choose other labels
-            for i, v in enumerate(log_prob.tolist()):
-                if i in (blank_id, sos_id):
-                    continue
-                new_ys = y_star.ys + [i]
-                new_log_prob = y_star.log_prob + v
-                new_hyp = Hypothesis(
-                    ys=new_ys,
-                    log_prob=new_log_prob,
-                    decoder_state=decoder_state,
-                )
-                A.append(new_hyp)
-
-            u += 1
-
-            # check whether B contains more than "beam" elements more probable
-            # than the most probable in A
-            A_most_probable = max(A, key=lambda hyp: hyp.log_prob)
-            B = sorted(
-                [hyp for hyp in B if hyp.log_prob > A_most_probable.log_prob],
-                key=lambda hyp: hyp.log_prob,
-                reverse=True,
-            )
-            if len(B) >= beam:
-                B = B[:beam]
-                break
-        t += 1
-
-    best_hyp = max(B, key=lambda hyp: hyp.log_prob / len(hyp.ys[1:]))
-    ys = best_hyp.ys[1:]  # [1:] to remove the blank
-    return ys
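
With the stateless decoder there is no (h, c) state to thread through the search, which is why this file can be replaced by a symlink to the transducer_stateless version (next hunk). A minimal sketch of greedy search on top of the new decoder, assuming the Decoder, Joiner, and Transducer interfaces shown in the diffs below (greedy_search_stateless is a hypothetical name, not part of this commit):

from typing import List

import torch

from model import Transducer


def greedy_search_stateless(
    model: Transducer, encoder_out: torch.Tensor
) -> List[int]:
    """Sketch of greedy search that re-computes the decoder output from
    the last `context_size` labels instead of carrying an LSTM state.
    It emits at most one symbol per frame (max_sym_per_frame == 1).
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) == 1, encoder_out.size(0)  # N == 1 only

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = encoder_out.device

    # Start from `context_size` blanks; the decoder never needs more
    # than the most recent `context_size` labels.
    hyp = [blank_id] * context_size
    decoder_input = torch.tensor([hyp], device=device, dtype=torch.int64)
    decoder_out = model.decoder(decoder_input, need_pad=False)  # (1, 1, C)

    for t in range(encoder_out.size(1)):
        current_encoder_out = encoder_out[:, t : t + 1, :]  # (1, 1, C)
        logits = model.joiner(current_encoder_out, decoder_out)
        y = logits.argmax().item()
        if y != blank_id:
            hyp.append(y)
            # Re-run the decoder only when a non-blank symbol is emitted.
            decoder_input = torch.tensor(
                [hyp[-context_size:]], device=device, dtype=torch.int64
            )
            decoder_out = model.decoder(decoder_input, need_pad=False)

    return hyp[context_size:]  # strip the leading blanks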

egs/librispeech/ASR/transducer_lstm/beam_search.py

@@ -0,0 +1 @@
+../transducer_stateless/beam_search.py

egs/librispeech/ASR/transducer_lstm/decode.py

@@ -46,14 +46,15 @@ import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from beam_search import beam_search, greedy_search
-from decoder import Decoder
-from encoder import LstmEncoder
-from joiner import Joiner
-from model import Transducer
+from beam_search import (
+    beam_search,
+    greedy_search,
+    greedy_search_batch,
+    modified_beam_search,
+)
+from train import get_params, get_transducer_model

 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -104,6 +105,7 @@ def get_parser():
         help="""Possible values are:
           - greedy_search
           - beam_search
+          - modified_beam_search
         """,
     )
@@ -114,76 +116,25 @@ def get_parser():
         help="Used only when --decoding-method is beam_search",
     )

+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=1,
+        help="""Maximum number of symbols per frame.
+        Used only when --decoding_method is greedy_search""",
+    )
+
     return parser


-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "encoder_hidden_size": 1024,
-            "num_encoder_layers": 4,
-            "proj_size": 512,
-            "vgg_frontend": False,
-            # decoder params
-            "decoder_embedding_dim": 1024,
-            "num_decoder_layers": 4,
-            "decoder_hidden_dim": 512,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict):
-    encoder = LstmEncoder(
-        num_features=params.feature_dim,
-        hidden_size=params.encoder_hidden_size,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict):
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.decoder_embedding_dim,
-        blank_id=params.blank_id,
-        sos_id=params.sos_id,
-        num_layers=params.num_decoder_layers,
-        hidden_dim=params.decoder_hidden_dim,
-        output_dim=params.encoder_out_dim,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict):
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict):
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -228,24 +179,47 @@ def decode_one_batch(
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
-    hyps = []
-    batch_size = encoder_out.size(0)
-
-    for i in range(batch_size):
-        # fmt: off
-        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
-        # fmt: on
-        if params.decoding_method == "greedy_search":
-            hyp = greedy_search(model=model, encoder_out=encoder_out_i)
-        elif params.decoding_method == "beam_search":
-            hyp = beam_search(
-                model=model, encoder_out=encoder_out_i, beam=params.beam_size
-            )
-        else:
-            raise ValueError(
-                f"Unsupported decoding method: {params.decoding_method}"
-            )
-        hyps.append(sp.decode(hyp).split())
+    hyp_list: List[List[int]] = []
+
+    if (
+        params.decoding_method == "greedy_search"
+        and params.max_sym_per_frame == 1
+    ):
+        hyp_list = greedy_search_batch(
+            model=model,
+            encoder_out=encoder_out,
+        )
+    elif params.decoding_method == "modified_beam_search":
+        hyp_list = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+    else:
+        batch_size = encoder_out.size(0)
+        for i in range(batch_size):
+            # fmt: off
+            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+            # fmt: on
+            if params.decoding_method == "greedy_search":
+                hyp = greedy_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    max_sym_per_frame=params.max_sym_per_frame,
+                )
+            elif params.decoding_method == "beam_search":
+                hyp = beam_search(
+                    model=model,
+                    encoder_out=encoder_out_i,
+                    beam=params.beam_size,
+                )
+            else:
+                raise ValueError(
+                    f"Unsupported decoding method: {params.decoding_method}"
+                )
+            hyp_list.append(hyp)
+
+    hyps = [sp.decode(hyp).split() for hyp in hyp_list]

     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
@@ -393,9 +367,8 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)
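
Taken together, the new dispatch in decode_one_batch can be exercised like this (a sketch; model, encoder_out, params, and sp are the objects already in scope in decode_one_batch, and the signatures are the ones visible in the diff above):

# Batched greedy search: one call for the whole batch instead of a
# Python loop over utterances; valid only when --max-sym-per-frame=1.
hyp_list = greedy_search_batch(model=model, encoder_out=encoder_out)

# modified_beam_search also processes the whole batch at once.
hyp_list = modified_beam_search(
    model=model,
    encoder_out=encoder_out,
    beam=params.beam_size,
)

# Either way the result is a List[List[int]] of token IDs that is
# mapped back to words with the BPE model:
hyps = [sp.decode(hyp).split() for hyp in hyp_list]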

egs/librispeech/ASR/transducer_lstm/decoder.py

@@ -14,25 +14,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Optional, Tuple
-
 import torch
 import torch.nn as nn
+import torch.nn.functional as F


-# TODO(fangjun): Support switching between LSTM and GRU
 class Decoder(nn.Module):
+    """This class modifies the stateless decoder from the following paper:
+
+        RNN-transducer with stateless prediction network
+        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
+
+    It removes the recurrent connection from the decoder, i.e., the prediction
+    network. Different from the above paper, it adds an extra Conv1d
+    right after the embedding layer.
+
+    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
+    """
+
     def __init__(
         self,
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
-        sos_id: int,
-        num_layers: int,
-        hidden_dim: int,
-        output_dim: int,
-        embedding_dropout: float = 0.0,
-        rnn_dropout: float = 0.0,
+        context_size: int,
     ):
         """
         Args:
@@ -42,18 +47,9 @@ class Decoder(nn.Module):
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
-          sos_id:
-            The ID of the SOS symbol.
-          num_layers:
-            Number of LSTM layers.
-          hidden_dim:
-            Hidden dimension of LSTM layers.
-          output_dim:
-            Output dimension of the decoder.
-          embedding_dropout:
-            Dropout rate for the embedding layer.
-          rnn_dropout:
-            Dropout for LSTM layers.
+          context_size:
+            Number of previous words to use to predict the next word.
+            1 means bigram; 2 means trigram. n means (n+1)-gram.
         """
         super().__init__()
         self.embedding = nn.Embedding(
@@ -61,41 +57,42 @@ class Decoder(nn.Module):
             embedding_dim=embedding_dim,
             padding_idx=blank_id,
         )
-        self.embedding_dropout = nn.Dropout(embedding_dropout)
-        # TODO(fangjun): Use layer normalized LSTM
-        self.rnn = nn.LSTM(
-            input_size=embedding_dim,
-            hidden_size=hidden_dim,
-            num_layers=num_layers,
-            batch_first=True,
-            dropout=rnn_dropout,
-        )
         self.blank_id = blank_id
-        self.sos_id = sos_id
-        self.output_linear = nn.Linear(hidden_dim, output_dim)

-    def forward(
-        self,
-        y: torch.Tensor,
-        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        assert context_size >= 1, context_size
+        self.context_size = context_size
+        if context_size > 1:
+            self.conv = nn.Conv1d(
+                in_channels=embedding_dim,
+                out_channels=embedding_dim,
+                kernel_size=context_size,
+                padding=0,
+                groups=embedding_dim,
+                bias=False,
+            )
+
+    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
         Args:
           y:
-            A 2-D tensor of shape (N, U) with BOS prepended.
-          states:
-            A tuple of two tensors containing the states information of
-            LSTM layers in this decoder.
+            A 2-D tensor of shape (N, U).
+          need_pad:
+            True to left pad the input. Should be True during training.
+            False to not pad the input. Should be False during inference.
         Returns:
-          Return a tuple containing:
-            - rnn_output, a tensor of shape (N, U, C)
-            - (h, c), containing the state information for LSTM layers.
-              Both are of shape (num_layers, N, C)
+          Return a tensor of shape (N, U, embedding_dim).
         """
-        embeding_out = self.embedding(y)
-        embeding_out = self.embedding_dropout(embeding_out)
-        rnn_out, (h, c) = self.rnn(embeding_out, states)
-        out = self.output_linear(rnn_out)
-
-        return out, (h, c)
+        embedding_out = self.embedding(y)
+        if self.context_size > 1:
+            embedding_out = embedding_out.permute(0, 2, 1)
+            if need_pad is True:
+                embedding_out = F.pad(
+                    embedding_out, pad=(self.context_size - 1, 0)
+                )
+            else:
+                # During inference time, there is no need to do extra padding
+                # as we only need one output
+                assert embedding_out.size(-1) == self.context_size
+            embedding_out = self.conv(embedding_out)
+            embedding_out = embedding_out.permute(0, 2, 1)
+        return embedding_out
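
A quick shape check for the new decoder (a sketch; the sizes are made up except context_size=2, the recipe default):

import torch

from decoder import Decoder

decoder = Decoder(
    vocab_size=500,
    embedding_dim=512,
    blank_id=0,
    context_size=2,
)

# Training: full label sequences. forward() left-pads by
# context_size - 1, so U is preserved.
y = torch.randint(low=1, high=500, size=(8, 20))  # (N, U)
assert decoder(y, need_pad=True).shape == (8, 20, 512)

# Inference: feed exactly the last context_size labels and skip the
# padding; the conv then yields a single output frame.
y_last = torch.randint(low=1, high=500, size=(8, 2))  # (N, context_size)
assert decoder(y_last, need_pad=False).shape == (8, 1, 512)

Since the Conv1d is depthwise (groups=embedding_dim), each embedding dimension is simply mixed across the last context_size positions: the decoder behaves like a feed-forward n-gram-style predictor rather than an RNN, which is what makes the (h, c) state and the SOS symbol unnecessary.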

egs/librispeech/ASR/transducer_lstm/encoder.py

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 from typing import Tuple

 import torch
@@ -87,7 +88,9 @@ class LstmEncoder(EncoderInterface):
         x = self.encoder_embed(x)

         # Caution: We assume the subsampling factor is 4!
-        lengths = ((x_lens - 1) // 2 - 1) // 2
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            lengths = ((x_lens - 1) // 2 - 1) // 2
         assert x.size(1) == lengths.max().item(), (
             x.size(1),
             lengths.max(),
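
The silenced warning is presumably PyTorch's floor-division deprecation warning on tensors; the arithmetic itself is unchanged. It models two successive stride-2 subsamplings (overall factor 4), e.g. with assumed input lengths:

import torch

x_lens = torch.tensor([100, 101, 102, 103])
lengths = ((x_lens - 1) // 2 - 1) // 2
print(lengths)  # tensor([24, 24, 24, 25])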

egs/librispeech/ASR/transducer_lstm/model.py

@@ -49,7 +49,7 @@ class Transducer(nn.Module):
           decoder:
             It is the prediction network in the paper. Its input shape
             is (N, U) and its output shape is (N, U, C). It should contain
-            two attributes: `blank_id` and `sos_id`.
+            one attribute: `blank_id`.
           joiner:
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
@@ -58,7 +58,6 @@ class Transducer(nn.Module):
         super().__init__()
         assert isinstance(encoder, EncoderInterface)
         assert hasattr(decoder, "blank_id")
-        assert hasattr(decoder, "sos_id")

         self.encoder = encoder
         self.decoder = decoder
@@ -97,13 +96,12 @@ class Transducer(nn.Module):
         y_lens = row_splits[1:] - row_splits[:-1]

         blank_id = self.decoder.blank_id
-        sos_id = self.decoder.sos_id
-        sos_y = add_sos(y, sos_id=sos_id)
+        sos_y = add_sos(y, sos_id=blank_id)

         sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
         sos_y_padded = sos_y_padded.to(torch.int64)

-        decoder_out, _ = self.decoder(sos_y_padded)
+        decoder_out = self.decoder(sos_y_padded)

         logits = self.joiner(encoder_out, decoder_out)
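
With no dedicated SOS symbol, the blank ID now doubles as the start-of-sequence token. In plain Python the preprocessing amounts to the following (illustrative only; the real code applies icefall's add_sos to a k2.RaggedTensor and then pads it):

blank_id = 0
y = [[5, 9, 3], [7, 2]]  # token IDs of two utterances

# Prepend blank as SOS, then right-pad with blank to get a
# rectangular (N, 1 + U_max) decoder input.
sos_y = [[blank_id] + utt for utt in y]
u_max = max(len(utt) for utt in sos_y)
sos_y_padded = [utt + [blank_id] * (u_max - len(utt)) for utt in sos_y]
assert sos_y_padded == [[0, 5, 9, 3], [0, 7, 2, 0]]

Because the decoder's embedding uses padding_idx=blank_id, the padded positions embed to zero vectors.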

egs/librispeech/ASR/transducer_lstm/train.py

@@ -139,6 +139,14 @@ def get_parser():
         help="The seed for random generators intended for reproducibility",
     )

+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+
     return parser
@@ -235,15 +243,12 @@ def get_encoder_model(params: AttributeDict):
     return encoder


-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.decoder_embedding_dim,
+        embedding_dim=params.encoder_out_dim,
         blank_id=params.blank_id,
-        sos_id=params.sos_id,
-        num_layers=params.num_decoder_layers,
-        hidden_dim=params.decoder_hidden_dim,
-        output_dim=params.encoder_out_dim,
+        context_size=params.context_size,
     )
     return decoder
@@ -400,9 +405,11 @@ def compute_loss(
     info = MetricsTracker()
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
-        info["frames"] = (
-            (feature_lens // params.subsampling_factor).sum().item()
-        )
+        with warnings.catch_warnings():
+            warnings.simplefilter("ignore")
+            info["frames"] = (
+                (feature_lens // params.subsampling_factor).sum().item()
+            )

     # Note: We use reduction=sum while computing the loss.
     info["loss"] = loss.detach().cpu().item()
@@ -580,9 +587,8 @@ def run(rank, world_size, args):
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)

-    # <blk> and <sos/eos> are defined in local/train_bpe_model.py
+    # <blk> is defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
-    params.sos_id = sp.piece_to_id("<sos/eos>")
     params.vocab_size = sp.get_piece_size()

     logging.info(params)
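
End to end, building the decoder now takes only four hyperparameters. A sketch with assumed values (encoder_out_dim=512 echoes the get_params() block removed from decode.py above; the vocab size depends on the BPE model, and running this requires the recipe directory on the path):

from icefall.utils import AttributeDict

from decoder import Decoder
from train import get_decoder_model

params = AttributeDict(
    {
        "vocab_size": 500,       # sp.get_piece_size(); assumed here
        "encoder_out_dim": 512,  # reused as the decoder embedding_dim
        "blank_id": 0,           # sp.piece_to_id("<blk>")
        "context_size": 2,       # the new --context-size default
    }
)
decoder: Decoder = get_decoder_model(params)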