do some changes for pruned RNN-T

luomingshuang 2022-04-11 16:25:12 +08:00
parent b7e7629168
commit cd6a2d903b
8 changed files with 172 additions and 268 deletions

View File

@@ -483,8 +483,9 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)

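For context, both IDs registered here are looked up by piece name in the trained BPE model; a minimal sketch (the model path is hypothetical):

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path

blank_id = sp.piece_to_id("<blk>")      # both pieces are defined in local/train_bpe_model.py
unk_id = sp.piece_to_id("<unk>")
vocab_size = sp.get_piece_size()
print(blank_id, unk_id, vocab_size)     # e.g. 0 2 500 for a 500-piece model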
View File

@@ -37,6 +37,7 @@ class Decoder(nn.Module):
vocab_size: int,
embedding_dim: int,
blank_id: int,
unk_id: int,
context_size: int,
):
"""
@@ -47,6 +48,8 @@
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
unk_id:
The ID of the unk symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
@@ -58,6 +61,7 @@
padding_idx=blank_id,
)
self.blank_id = blank_id
self.unk_id = unk_id
assert context_size >= 1, context_size
self.context_size = context_size

View File

@@ -319,6 +319,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
unk_id=params.unk_id,
context_size=params.context_size,
)
return decoder
@@ -756,8 +757,9 @@ def run(rank, world_size, args):
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)

View File

@@ -12,7 +12,7 @@ The WERs are
|------------------------------------|------------|------------|------------------------------------------|
| greedy search | 7.27 | 6.69 | --epoch 29, --avg 13, --max-duration 100 |
| beam search (beam size 4) | 6.70 | 6.04 | --epoch 29, --avg 13, --max-duration 100 |
| modified beam search (beam size 4) | 6.72 | 6.12 | --epoch 29, --avg 13, --max-duration 100 |
| modified beam search (beam size 4) | 6.77 | 6.12 | --epoch 29, --avg 13, --max-duration 100 |
| fast beam search (set as default) | 7.14 | 6.50 | --epoch 29, --avg 13, --max-duration 1500 |
The training command for reproducing is given below:

View File

@@ -1,5 +1,5 @@
# Copyright 2020 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -369,13 +369,158 @@ class HypothesisList(object):
return ", ".join(s)
def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
"""Return a ragged shape with axes [utt][num_hyps].
Args:
hyps:
len(hyps) == batch_size. It contains the current hypothesis for
each utterance in the batch.
Returns:
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
the shape is on CPU.
"""
num_hyps = [len(h) for h in hyps]
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
# to get exclusive sum later.
num_hyps.insert(0, 0)
num_hyps = torch.tensor(num_hyps)
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
ans = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
)
return ans
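# A worked example of the exclusive-sum trick above (illustration only,
# not part of the commit): three utterances holding 2, 3 and 1 hypotheses.
#
#   num_hyps = [2, 3, 1]; num_hyps.insert(0, 0)   # [0, 2, 3, 1]
#   row_splits = torch.cumsum(torch.tensor(num_hyps), dim=0, dtype=torch.int32)
#   # row_splits is now [0, 2, 5, 6]: hyps 0..1 belong to utterance 0,
#   # hyps 2..4 to utterance 1, hyp 5 to utterance 2, and row_splits[-1] == 6
#   # is the total. row_ids(1) of the resulting shape is [0, 0, 1, 1, 1, 2],
#   # the hyp-to-utterance map used by index_select below.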
def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
beam:
Number of active paths during the beam search.
Returns:
Return a list-of-list of token IDs. ans[i] is the decoding results
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
blank_id = model.decoder.blank_id
unk_id = model.decoder.unk_id
context_size = model.decoder.context_size
device = model.device
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
hyps_shape = _get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
# decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(
shape=log_probs_shape, value=log_probs
)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id and new_token != unk_id:
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
ans = [h.ys[context_size:] for h in best_hyps]
return ans
def _deprecated_modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
It decodes only one utterance at a time. We keep it only for reference.
The function :func:`modified_beam_search` should be preferred as it
supports batch decoding.
Args:
model:
An instance of `Transducer`.

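To make the per-utterance topk concrete: the scores of all live hypotheses for one utterance are flattened across the vocabulary, so each flat index encodes both which hypothesis to extend and which token extends it. A small self-contained sketch (toy sizes, not from the commit):

import torch

vocab_size = 5  # toy vocabulary
beam = 2
# scores for 2 live hypotheses of one utterance, flattened to 10 entries
log_probs = torch.randn(2 * vocab_size)

topk_log_probs, topk_indexes = log_probs.topk(beam)
hyp_indexes = torch.div(topk_indexes, vocab_size, rounding_mode="trunc")
token_indexes = topk_indexes % vocab_size
# e.g. flat index 7 -> hypothesis 1, token 2: the new hypothesis extends
# A[i][1] with token 2, unless the token is blank or unk, in which case
# only the score is updated and ys stays unchanged.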
View File

@@ -74,13 +74,9 @@ from beam_search import (
greedy_search_batch,
modified_beam_search,
)
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import (
AttributeDict,
setup_logger,
@@ -182,7 +178,7 @@ def get_parser():
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
@@ -190,73 +186,6 @@ def get_parser():
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"embedding_dim": 512,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict):
# TODO: We can add an option to switch between Conformer and Transformer
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.vocab_size,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict):
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.embedding_dim,
blank_id=params.blank_id,
unk_id=params.unk_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict):
joiner = Joiner(
input_dim=params.vocab_size,
inner_dim=params.embedding_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
@@ -329,6 +258,14 @@ def decode_one_batch(
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
@@ -348,12 +285,6 @@
encoder_out=encoder_out_i,
beam=params.beam_size,
)
elif params.decoding_method == "modified_beam_search":
hyp = modified_beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
@@ -593,8 +524,5 @@ def main():
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

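For reference, the new batch branch works because sp.decode accepts a list of token-ID lists and returns one decoded string per utterance, as the diff itself relies on; a minimal sketch with a hypothetical model path:

import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path

hyp_tokens = [[25, 310, 7], [98, 4]]    # toy IDs, one list per utterance
for hyp in sp.decode(hyp_tokens):       # one decoded string per utterance
    print(hyp.split())                  # the word list appended to hyps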
View File

@@ -1,105 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
embedding_dim: int,
blank_id: int,
unk_id: int,
context_size: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
embedding_dim:
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
unk_id:
The ID of the unk symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
padding_idx=blank_id,
)
self.blank_id = blank_id
self.unk_id = unk_id
assert context_size >= 1, context_size
self.context_size = context_size
self.vocab_size = vocab_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=embedding_dim,
out_channels=embedding_dim,
kernel_size=context_size,
padding=0,
groups=embedding_dim,
bias=False,
)
self.output_linear = nn.Linear(embedding_dim, vocab_size)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, vocab_size).
"""
embedding_out = self.embedding(y)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
embedding_out = self.output_linear(F.relu(embedding_out))
return embedding_out

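A short usage sketch of the decoder above (illustration only; the toy sizes and IDs are assumed), showing how need_pad changes the expected input between training and inference:

import torch

decoder = Decoder(
    vocab_size=500, embedding_dim=512, blank_id=0, unk_id=2, context_size=2
)

# Training: a full label sequence (N, U); forward left-pads internally,
# so the output keeps length U.
y = torch.randint(low=3, high=500, size=(4, 10))
print(decoder(y, need_pad=True).shape)   # torch.Size([4, 10, 500])

# Inference: exactly context_size previous tokens, no padding, one step out.
y = torch.randint(low=3, high=500, size=(4, 2))
print(decoder(y, need_pad=False).shape)  # torch.Size([4, 1, 500])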
View File

@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless/decoder.py

View File

@@ -1,7 +1,7 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang
# Mingshuang Luo)
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -50,15 +50,10 @@ from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.env import get_env_info
from icefall.utils import AttributeDict, str2bool
from icefall.utils import str2bool
def get_parser():
@@ -69,7 +64,7 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=20,
default=30,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
@@ -77,7 +72,7 @@
parser.add_argument(
"--avg",
type=int,
default=10,
default=13,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
@@ -118,73 +113,6 @@ def get_parser():
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"embedding_dim": 512,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
unk_id=params.unk_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
inner_dim=params.embedding_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
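For context, --epoch and --avg select which checkpoints get averaged. The loading code falls below the cutoff here, but the usual icefall pattern looks like this (a sketch: exp_dir and the epoch-{n}.pt naming are assumptions, and model is an already-built Transducer):

from icefall.checkpoint import average_checkpoints

epoch, avg = 30, 13                   # the new defaults above
exp_dir = "transducer_stateless/exp"  # hypothetical
filenames = [
    f"{exp_dir}/epoch-{i}.pt" for i in range(epoch - avg + 1, epoch + 1)
]
# average_checkpoints returns the element-wise mean of the saved "model"
# state_dicts of the selected checkpoints.
model.load_state_dict(average_checkpoints(filenames))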