do some changes for pruned RNN-T

luomingshuang 2022-04-11 16:25:12 +08:00
parent b7e7629168
commit cd6a2d903b
8 changed files with 172 additions and 268 deletions

View File

@@ -483,8 +483,9 @@ def main():
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)
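
For orientation, a minimal sketch (not part of the commit; the model path is hypothetical) of what these SentencePiece lookups return:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path

blank_id = sp.piece_to_id("<blk>")  # 0 in typical icefall BPE models
unk_id = sp.piece_to_id("<unk>")    # id of the unknown token
vocab_size = sp.get_piece_size()    # total number of BPE pieces

# The new unk_id ends up stored on the decoder (see decoder.py below) so
# that modified_beam_search can avoid appending <unk> to hypotheses.
print(blank_id, unk_id, vocab_size)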

View File

@@ -37,6 +37,7 @@ class Decoder(nn.Module):
         vocab_size: int,
         embedding_dim: int,
         blank_id: int,
+        unk_id: int,
         context_size: int,
     ):
         """
@@ -47,6 +48,8 @@ class Decoder(nn.Module):
             Dimension of the input embedding.
           blank_id:
             The ID of the blank symbol.
+          unk_id:
+            The ID of the unk symbol.
           context_size:
             Number of previous words to use to predict the next word.
             1 means bigram; 2 means trigram. n means (n+1)-gram.
@@ -58,6 +61,7 @@ class Decoder(nn.Module):
             padding_idx=blank_id,
         )
         self.blank_id = blank_id
+        self.unk_id = unk_id
 
         assert context_size >= 1, context_size
         self.context_size = context_size

View File

@@ -319,6 +319,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
         vocab_size=params.vocab_size,
         embedding_dim=params.embedding_dim,
         blank_id=params.blank_id,
+        unk_id=params.unk_id,
         context_size=params.context_size,
     )
     return decoder
@@ -756,8 +757,9 @@ def run(rank, world_size, args):
     sp = spm.SentencePieceProcessor()
     sp.load(params.bpe_model)
 
-    # <blk> is defined in local/train_bpe_model.py
+    # <blk> and <unk> are defined in local/train_bpe_model.py
     params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
     params.vocab_size = sp.get_piece_size()
 
     logging.info(params)

View File

@@ -12,7 +12,7 @@ The WERs are
 |------------------------------------|------------|------------|-------------------------------------------|
 | greedy search                      | 7.27       | 6.69       | --epoch 29, --avg 13, --max-duration 100  |
 | beam search (beam size 4)          | 6.70       | 6.04       | --epoch 29, --avg 13, --max-duration 100  |
-| modified beam search (beam size 4) | 6.72       | 6.12       | --epoch 29, --avg 13, --max-duration 100  |
+| modified beam search (beam size 4) | 6.77       | 6.12       | --epoch 29, --avg 13, --max-duration 100  |
 | fast beam search (set as default)  | 7.14       | 6.50       | --epoch 29, --avg 13, --max-duration 1500 |
 
 The training command for reproducing is given below:

View File

@@ -1,5 +1,5 @@
-# Copyright  2020  Xiaomi Corp.  (authors: Fangjun Kuang)
-#
+# Copyright  2020  Xiaomi Corp.  (authors: Fangjun Kuang,
+#                                          Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -369,13 +369,158 @@ class HypothesisList(object):
         return ", ".join(s)
 
 
+def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+    """Return a ragged shape with axes [utt][num_hyps].
+
+    Args:
+      hyps:
+        len(hyps) == batch_size. It contains the current hypothesis for
+        each utterance in the batch.
+    Returns:
+      Return a ragged shape with 2 axes [utt][num_hyps]. Note that
+      the shape is on CPU.
+    """
+    num_hyps = [len(h) for h in hyps]
+
+    # torch.cumsum() is an inclusive sum, so we put a 0 at the beginning
+    # to get an exclusive sum later.
+    num_hyps.insert(0, 0)
+
+    num_hyps = torch.tensor(num_hyps)
+    row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
+    ans = k2.ragged.create_ragged_shape2(
+        row_splits=row_splits, cached_tot_size=row_splits[-1].item()
+    )
+    return ans
+
+
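A toy illustration of the shape this helper builds, assuming k2 and torch are installed (the numbers are made up, not from the commit):

import k2
import torch

# Two utterances currently holding 2 and 3 hypotheses respectively.
row_splits = torch.tensor([0, 2, 5], dtype=torch.int32)  # exclusive cumsum of [2, 3]
shape = k2.ragged.create_ragged_shape2(
    row_splits=row_splits, cached_tot_size=5
)
# row_ids(1) maps each flattened hypothesis back to its utterance, which
# is how the code below gathers encoder frames per hypothesis.
print(shape.row_ids(1))  # expected: [0, 0, 1, 1, 1] (int32)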
 def modified_beam_search(
     model: Transducer,
     encoder_out: torch.Tensor,
     beam: int = 4,
+) -> List[List[int]]:
+    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
+
+    Args:
+      model:
+        The transducer model.
+      encoder_out:
+        Output from the encoder. Its shape is (N, T, C).
+      beam:
+        Number of active paths during the beam search.
+    Returns:
+      Return a list-of-list of token IDs. ans[i] is the decoding result
+      for the i-th utterance.
+    """
+    assert encoder_out.ndim == 3, encoder_out.shape
+
+    batch_size = encoder_out.size(0)
+    T = encoder_out.size(1)
+
+    blank_id = model.decoder.blank_id
+    unk_id = model.decoder.unk_id
+    context_size = model.decoder.context_size
+    device = model.device
+
+    B = [HypothesisList() for _ in range(batch_size)]
+    for i in range(batch_size):
+        B[i].add(
+            Hypothesis(
+                ys=[blank_id] * context_size,
+                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+            )
+        )
+
+    for t in range(T):
+        current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2)  # noqa
+        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+
+        hyps_shape = _get_hyps_shape(B).to(device)
+
+        A = [list(b) for b in B]
+        B = [HypothesisList() for _ in range(batch_size)]
+
+        ys_log_probs = torch.cat(
+            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+        )  # (num_hyps, 1)
+
+        decoder_input = torch.tensor(
+            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+            device=device,
+            dtype=torch.int64,
+        )  # (num_hyps, context_size)
+
+        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+        # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
+
+        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+        # as the index, so we use `to(torch.int64)` below.
+        current_encoder_out = torch.index_select(
+            current_encoder_out,
+            dim=0,
+            index=hyps_shape.row_ids(1).to(torch.int64),
+        )  # (num_hyps, 1, 1, encoder_out_dim)
+
+        logits = model.joiner(
+            current_encoder_out,
+            decoder_out,
+        )  # (num_hyps, 1, 1, vocab_size)
+
+        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)
+
+        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
+
+        log_probs.add_(ys_log_probs)
+
+        vocab_size = log_probs.size(-1)
+
+        log_probs = log_probs.reshape(-1)
+
+        row_splits = hyps_shape.row_splits(1) * vocab_size
+        log_probs_shape = k2.ragged.create_ragged_shape2(
+            row_splits=row_splits, cached_tot_size=log_probs.numel()
+        )
+        ragged_log_probs = k2.RaggedTensor(
+            shape=log_probs_shape, value=log_probs
+        )
+
+        for i in range(batch_size):
+            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+            topk_hyp_indexes = torch.div(
+                topk_indexes, vocab_size, rounding_mode="trunc"
+            )
+            topk_hyp_indexes = topk_hyp_indexes.tolist()
+            topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+            for k in range(len(topk_hyp_indexes)):
+                hyp_idx = topk_hyp_indexes[k]
+                hyp = A[i][hyp_idx]
+
+                new_ys = hyp.ys[:]
+                new_token = topk_token_indexes[k]
+                if new_token != blank_id and new_token != unk_id:
+                    new_ys.append(new_token)
+
+                new_log_prob = topk_log_probs[k]
+                new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
+                B[i].add(new_hyp)
+
+    best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+    ans = [h.ys[context_size:] for h in best_hyps]
+
+    return ans
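
The per-utterance topk above runs on a flattened vector of num_hyps * vocab_size scores; integer division and modulo then recover which hypothesis and which token each winner came from. A toy sketch with made-up numbers (not from the commit):

import torch

vocab_size = 5
# Scores for 2 hypotheses over a 5-token vocabulary, flattened to (10,).
log_probs = torch.tensor(
    [
        [-1.0, -2.0, -0.5, -3.0, -4.0],  # hypothesis 0
        [-0.1, -2.5, -1.5, -0.7, -3.0],  # hypothesis 1
    ]
).reshape(-1)

topk_log_probs, topk_indexes = log_probs.topk(3)
topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
topk_token_indexes = topk_indexes % vocab_size

print(topk_indexes.tolist())        # [5, 2, 8]
print(topk_hyp_indexes.tolist())    # [1, 0, 1]
print(topk_token_indexes.tolist())  # [0, 2, 3]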
+
+
+def _deprecated_modified_beam_search(
+    model: Transducer,
+    encoder_out: torch.Tensor,
+    beam: int = 4,
 ) -> List[int]:
     """It limits the maximum number of symbols per frame to 1.
 
+    It decodes only one utterance at a time. We keep it only for reference.
+    The function :func:`modified_beam_search` should be preferred as it
+    supports batch decoding.
+
     Args:
       model:
         An instance of `Transducer`.

View File

@@ -74,13 +74,9 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
 from icefall.utils import (
     AttributeDict,
     setup_logger,
@@ -182,7 +178,7 @@ def get_parser():
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
-        default=3,
+        default=1,
         help="""Maximum number of symbols per frame.
         Used only when --decoding_method is greedy_search""",
     )
@@ -190,73 +186,6 @@ def get_parser():
     return parser
 
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            # parameters for decoder
-            "embedding_dim": 512,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict):
-    # TODO: We can add an option to switch between Conformer and Transformer
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.vocab_size,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict):
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.embedding_dim,
-        blank_id=params.blank_id,
-        unk_id=params.unk_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict):
-    joiner = Joiner(
-        input_dim=params.vocab_size,
-        inner_dim=params.embedding_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict):
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -329,6 +258,14 @@ def decode_one_batch(
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
+    elif params.decoding_method == "modified_beam_search":
+        hyp_tokens = modified_beam_search(
+            model=model,
+            encoder_out=encoder_out,
+            beam=params.beam_size,
+        )
+        for hyp in sp.decode(hyp_tokens):
+            hyps.append(hyp.split())
     else:
         batch_size = encoder_out.size(0)
@@ -348,12 +285,6 @@ def decode_one_batch(
                     encoder_out=encoder_out_i,
                     beam=params.beam_size,
                 )
-            elif params.decoding_method == "modified_beam_search":
-                hyp = modified_beam_search(
-                    model=model,
-                    encoder_out=encoder_out_i,
-                    beam=params.beam_size,
-                )
             else:
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
@@ -593,8 +524,5 @@ def main():
     logging.info("Done!")
 
 
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()
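
The net effect of this file's changes: decode.py no longer keeps its own copies of the model builders but imports them from train.py, so the two scripts cannot drift apart. A hypothetical sketch of the resulting usage (field values are illustrative; the exact fields depend on the recipe's train.py):

from train import get_params, get_transducer_model

params = get_params()
# These fields are normally filled in from the BPE model and CLI args.
params.vocab_size = 500
params.blank_id = 0
params.unk_id = 2
params.context_size = 2

model = get_transducer_model(params)
print(sum(p.numel() for p in model.parameters()))  # parameter count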

View File

@@ -1,105 +0,0 @@
-# Copyright  2021  Xiaomi Corp.  (authors: Fangjun Kuang,
-#                                          Mingshuang Luo)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class Decoder(nn.Module):
-    """This class modifies the stateless decoder from the following paper:
-
-        RNN-transducer with stateless prediction network
-        https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
-
-    It removes the recurrent connection from the decoder, i.e., the
-    prediction network. Different from the above paper, it adds an extra
-    Conv1d right after the embedding layer.
-
-    TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
-    """
-
-    def __init__(
-        self,
-        vocab_size: int,
-        embedding_dim: int,
-        blank_id: int,
-        unk_id: int,
-        context_size: int,
-    ):
-        """
-        Args:
-          vocab_size:
-            Number of tokens of the modeling unit including blank.
-          embedding_dim:
-            Dimension of the input embedding.
-          blank_id:
-            The ID of the blank symbol.
-          unk_id:
-            The ID of the unk symbol.
-          context_size:
-            Number of previous words to use to predict the next word.
-            1 means bigram; 2 means trigram. n means (n+1)-gram.
-        """
-        super().__init__()
-        self.embedding = nn.Embedding(
-            num_embeddings=vocab_size,
-            embedding_dim=embedding_dim,
-            padding_idx=blank_id,
-        )
-        self.blank_id = blank_id
-        self.unk_id = unk_id
-
-        assert context_size >= 1, context_size
-        self.context_size = context_size
-        self.vocab_size = vocab_size
-        if context_size > 1:
-            self.conv = nn.Conv1d(
-                in_channels=embedding_dim,
-                out_channels=embedding_dim,
-                kernel_size=context_size,
-                padding=0,
-                groups=embedding_dim,
-                bias=False,
-            )
-        self.output_linear = nn.Linear(embedding_dim, vocab_size)
-
-    def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
-        """
-        Args:
-          y:
-            A 2-D tensor of shape (N, U).
-          need_pad:
-            True to left pad the input. Should be True during training.
-            False to not pad the input. Should be False during inference.
-        Returns:
-          Return a tensor of shape (N, U, vocab_size).
-        """
-        embedding_out = self.embedding(y)
-        if self.context_size > 1:
-            embedding_out = embedding_out.permute(0, 2, 1)
-            if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
-            else:
-                # During inference time, there is no need to do extra padding
-                # as we only need one output
-                assert embedding_out.size(-1) == self.context_size
-            embedding_out = self.conv(embedding_out)
-            embedding_out = embedding_out.permute(0, 2, 1)
-        embedding_out = self.output_linear(F.relu(embedding_out))
-        return embedding_out
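
The file deleted above is replaced by a symlink (next hunk) to the identical decoder shared with the librispeech recipe. As a quick shape check of that stateless decoder (an editor's sketch, assuming the Decoder class above is importable; ids and sizes are arbitrary):

import torch

decoder = Decoder(
    vocab_size=500,
    embedding_dim=512,
    blank_id=0,
    unk_id=2,
    context_size=2,
)

# Training: left-padding keeps the output length equal to the input length.
y = torch.randint(low=3, high=500, size=(8, 20))  # (N, U)
assert decoder(y, need_pad=True).shape == (8, 20, 500)

# Inference: feed exactly context_size tokens and get a single frame.
y_last = torch.randint(low=3, high=500, size=(8, 2))  # (N, context_size)
assert decoder(y_last, need_pad=False).shape == (8, 1, 500)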

View File

@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless/decoder.py

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
# #
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
# Mingshuang Luo) # Mingshuang Luo)
# #
# See ../../../../LICENSE for clarification regarding multiple authors # See ../../../../LICENSE for clarification regarding multiple authors
# #
@@ -50,15 +50,10 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
-import torch.nn as nn
-
-from conformer import Conformer
-from decoder import Decoder
-from joiner import Joiner
-from model import Transducer
+from train import get_params, get_transducer_model
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
-from icefall.env import get_env_info
-from icefall.utils import AttributeDict, str2bool
+from icefall.utils import str2bool
 
 
 def get_parser():
@@ -69,7 +64,7 @@ def get_parser():
     parser.add_argument(
         "--epoch",
         type=int,
-        default=20,
+        default=30,
         help="It specifies the checkpoint to use for decoding."
         "Note: Epoch counts from 0.",
     )
@@ -77,7 +72,7 @@ def get_parser():
     parser.add_argument(
         "--avg",
         type=int,
-        default=10,
+        default=13,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch'. ",
@@ -118,73 +113,6 @@ def get_parser():
     return parser
 
 
-def get_params() -> AttributeDict:
-    params = AttributeDict(
-        {
-            # parameters for conformer
-            "feature_dim": 80,
-            "encoder_out_dim": 512,
-            "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            "vgg_frontend": False,
-            # parameters for decoder
-            "embedding_dim": 512,
-            "env_info": get_env_info(),
-        }
-    )
-    return params
-
-
-def get_encoder_model(params: AttributeDict) -> nn.Module:
-    encoder = Conformer(
-        num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
-        subsampling_factor=params.subsampling_factor,
-        d_model=params.attention_dim,
-        nhead=params.nhead,
-        dim_feedforward=params.dim_feedforward,
-        num_encoder_layers=params.num_encoder_layers,
-        vgg_frontend=params.vgg_frontend,
-    )
-    return encoder
-
-
-def get_decoder_model(params: AttributeDict) -> nn.Module:
-    decoder = Decoder(
-        vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
-        blank_id=params.blank_id,
-        unk_id=params.unk_id,
-        context_size=params.context_size,
-    )
-    return decoder
-
-
-def get_joiner_model(params: AttributeDict) -> nn.Module:
-    joiner = Joiner(
-        input_dim=params.encoder_out_dim,
-        inner_dim=params.embedding_dim,
-        output_dim=params.vocab_size,
-    )
-    return joiner
-
-
-def get_transducer_model(params: AttributeDict) -> nn.Module:
-    encoder = get_encoder_model(params)
-    decoder = get_decoder_model(params)
-    joiner = get_joiner_model(params)
-    model = Transducer(
-        encoder=encoder,
-        decoder=decoder,
-        joiner=joiner,
-    )
-    return model
-
-
 def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)