diff --git a/egs/aishell/ASR/transducer_stateless/beam_search.py b/egs/aishell/ASR/transducer_stateless/beam_search.py
index f347f552f..3441bd20c 100644
--- a/egs/aishell/ASR/transducer_stateless/beam_search.py
+++ b/egs/aishell/ASR/transducer_stateless/beam_search.py
@@ -296,7 +296,7 @@ def beam_search(
 
         if cached_key not in joint_cache:
             logits = model.joiner(current_encoder_out, decoder_out)
 
-            # TODO(fangjun): Ccale the blank posterior
+            # TODO(fangjun): Scale the blank posterior
             log_prob = logits.log_softmax(dim=-1)
             # log_prob is (1, 1, 1, vocab_size)
diff --git a/egs/aishell/ASR/transducer_stateless/decode.py b/egs/aishell/ASR/transducer_stateless/decode.py
index 11228375d..9a1d578c5 100755
--- a/egs/aishell/ASR/transducer_stateless/decode.py
+++ b/egs/aishell/ASR/transducer_stateless/decode.py
@@ -31,7 +31,6 @@
 from decoder import Decoder
 from joiner import Joiner
 from model import Transducer
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
@@ -39,8 +38,8 @@
 from icefall.utils import (
     AttributeDict,
     setup_logger,
     store_transcripts,
-    write_error_stats,
     str2bool,
+    write_error_stats,
 )
 
@@ -130,9 +129,9 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "embedding_dim": 256,
             "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
+            "attention_dim": 256,
+            "nhead": 4,
+            "dim_feedforward": 1024,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "env_info": get_env_info(),
@@ -141,7 +140,7 @@ def get_params() -> AttributeDict:
     return params
 
 
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
@@ -156,7 +155,7 @@
     return encoder
 
 
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.embedding_dim,
@@ -166,16 +165,16 @@
     return decoder
 
 
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.vocab_size,
-        output_dim=params.vocab_size,
         inner_dim=params.embedding_dim,
+        output_dim=params.vocab_size,
     )
     return joiner
 
 
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -404,10 +403,6 @@ def main():
     logging.info(f"Device: {device}")
 
     lexicon = Lexicon(params.lang_dir)
-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon,
-        device=device,
-    )
 
     params.blank_id = 0
     params.vocab_size = max(lexicon.tokens) + 1
diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py
index 7c9c8201c..4653192ec 100644
--- a/egs/aishell/ASR/transducer_stateless/decoder.py
+++ b/egs/aishell/ASR/transducer_stateless/decoder.py
@@ -96,5 +96,5 @@ class Decoder(nn.Module):
             assert embeding_out.size(-1) == self.context_size
             embeding_out = self.conv(embeding_out)
             embeding_out = embeding_out.permute(0, 2, 1)
-        embeding_out = self.output_linear(embeding_out)
+        embeding_out = self.output_linear(F.relu(embeding_out))
         return embeding_out
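The decoder.py hunk above adds a non-linearity to the stateless decoder: the convolution output now passes through F.relu before the final projection. A minimal runnable sketch of the resulting forward tail (all sizes are illustrative placeholders, not values taken from this diff, and the conv here is a plain Conv1d rather than the grouped convolution the real Decoder may use):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Illustrative sizes only (assumed): N = batch, U = label length.
    N, U, embedding_dim, vocab_size, context_size = 8, 10, 256, 500, 2

    conv = nn.Conv1d(embedding_dim, embedding_dim, kernel_size=context_size)
    output_linear = nn.Linear(embedding_dim, vocab_size)

    x = torch.randn(N, embedding_dim, U + context_size - 1)  # (N, C, U') after left padding
    y = conv(x).permute(0, 2, 1)          # (N, U, C)
    y = output_linear(F.relu(y))          # the ReLU added by this diff
    assert y.shape == (N, U, vocab_size)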
diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py
index 641555bdb..0d2b5a6bf 100755
--- a/egs/aishell/ASR/transducer_stateless/export.py
+++ b/egs/aishell/ASR/transducer_stateless/export.py
@@ -22,7 +22,7 @@ Usage:
 
 ./transducer_stateless/export.py \
     --exp-dir ./transducer_stateless/exp \
-    --bpe-model data/lang_bpe_500/bpe.model \
+    --lang-dir data/lang_char \
     --epoch 20 \
     --avg 10
 
@@ -39,15 +39,15 @@ To use the generated file with `transducer_stateless/decode.py`, you can do:
         --epoch 9999 \
         --avg 1 \
         --max-duration 1 \
-        --bpe-model data/lang_bpe_500/bpe.model
+        --lang-dir data/lang_char
 """
 
 import argparse
 import logging
 from pathlib import Path
 
-import sentencepiece as spm
 import torch
+import torch.nn as nn
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
@@ -55,6 +55,7 @@ from model import Transducer
 
 from icefall.checkpoint import average_checkpoints, load_checkpoint
 from icefall.env import get_env_info
+from icefall.lexicon import Lexicon
 from icefall.utils import AttributeDict, str2bool
 
 
@@ -90,10 +91,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--bpe-model",
+        "--lang-dir",
         type=str,
-        default="data/lang_bpe_500/bpe.model",
-        help="Path to the BPE model",
+        default="data/lang_char",
+        help="Path to the tokens.txt",
     )
 
     parser.add_argument(
@@ -120,11 +121,11 @@ def get_params() -> AttributeDict:
         {
             # parameters for conformer
             "feature_dim": 80,
-            "encoder_out_dim": 512,
+            "embedding_dim": 256,
             "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
+            "attention_dim": 256,
+            "nhead": 4,
+            "dim_feedforward": 1024,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             "env_info": get_env_info(),
@@ -133,10 +134,10 @@ def get_params() -> AttributeDict:
     return params
 
 
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
+        output_dim=params.vocab_size,
         subsampling_factor=params.subsampling_factor,
         d_model=params.attention_dim,
         nhead=params.nhead,
@@ -147,25 +148,26 @@ def get_encoder_model(params: AttributeDict):
     return encoder
 
 
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
+        embedding_dim=params.embedding_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
     )
     return decoder
 
 
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.encoder_out_dim,
+        input_dim=params.vocab_size,
+        inner_dim=params.embedding_dim,
         output_dim=params.vocab_size,
     )
     return joiner
 
 
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -193,12 +195,9 @@ def main():
 
     logging.info(f"device: {device}")
 
-    sp = spm.SentencePieceProcessor()
-    sp.load(params.bpe_model)
-
-    # <blk> is defined in local/train_bpe_model.py
-    params.blank_id = sp.piece_to_id("<blk>")
-    params.vocab_size = sp.get_piece_size()
+    lexicon = Lexicon(params.lang_dir)
+    params.blank_id = 0
+    params.vocab_size = max(lexicon.tokens) + 1
 
     logging.info(params)
 
diff --git a/egs/aishell/ASR/transducer_stateless/joiner.py b/egs/aishell/ASR/transducer_stateless/joiner.py
index 9371aec5a..e1b2861d1 100644
--- a/egs/aishell/ASR/transducer_stateless/joiner.py
+++ b/egs/aishell/ASR/transducer_stateless/joiner.py
@@ -22,9 +22,8 @@ class Joiner(nn.Module):
     def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
         super().__init__()
-        self.output_linear = nn.Sequential(
-            nn.Linear(input_dim, inner_dim), nn.Linear(inner_dim, output_dim)
-        )
+        self.inner_linear = nn.Linear(input_dim, inner_dim)
+        self.output_linear = nn.Linear(inner_dim, output_dim)
 
     def forward(
         self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
     ) -> torch.Tensor:
@@ -32,16 +31,19 @@ class Joiner(nn.Module):
         """
         Args:
           encoder_out:
-            Output from the encoder. Its shape is (N, T, C).
+            The pruned output from the encoder. Its shape is (N, T, s_range, C).
           decoder_out:
-            Output from the decoder. Its shape is (N, U, C).
+            The pruned output from the decoder. Its shape is (N, T, s_range, C).
         Returns:
-          Return a tensor of shape (N, T, U, C).
+          Return a tensor of shape (N, T, s_range, C).
         """
         assert encoder_out.ndim == decoder_out.ndim == 4
         assert encoder_out.shape == decoder_out.shape
 
         logit = encoder_out + decoder_out
+
+        logit = self.inner_linear(logit)
+        logit = torch.tanh(logit)
 
         output = self.output_linear(logit)
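The joiner change above is more than a refactor: the old nn.Sequential stacked two Linear layers back to back, which collapses to a single affine map, while the new code applies torch.tanh between inner_linear and output_linear; the docstring now also reflects the pruned (N, T, s_range, C) inputs used by the pruned RNN-T loss. A self-contained sketch of the new joiner under assumed shapes:

    import torch
    import torch.nn as nn

    class Joiner(nn.Module):
        # Sketch of the joiner after this diff (dims are illustrative).
        def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
            super().__init__()
            self.inner_linear = nn.Linear(input_dim, inner_dim)
            self.output_linear = nn.Linear(inner_dim, output_dim)

        def forward(self, encoder_out: torch.Tensor,
                    decoder_out: torch.Tensor) -> torch.Tensor:
            assert encoder_out.shape == decoder_out.shape  # both (N, T, s_range, C)
            logit = encoder_out + decoder_out
            logit = torch.tanh(self.inner_linear(logit))   # non-linearity added here
            return self.output_linear(logit)

    N, T, s_range, C = 2, 50, 5, 500  # assumed pruned shapes
    joiner = Joiner(input_dim=C, inner_dim=256, output_dim=C)
    out = joiner(torch.randn(N, T, s_range, C), torch.randn(N, T, s_range, C))
    assert out.shape == (N, T, s_range, C)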
diff --git a/egs/aishell/ASR/transducer_stateless/model.py b/egs/aishell/ASR/transducer_stateless/model.py
index d69330368..a187bfce1 100644
--- a/egs/aishell/ASR/transducer_stateless/model.py
+++ b/egs/aishell/ASR/transducer_stateless/model.py
@@ -32,7 +32,7 @@ class Transducer(nn.Module):
         encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
-        prune_range: int = 5,
+        prune_range: int = 3,
         lm_scale: float = 0.0,
         am_scale: float = 0.0,
     ):
@@ -51,6 +51,20 @@ class Transducer(nn.Module):
             It has two inputs with shapes: (N, T, C) and (N, U, C). Its
             output shape is (N, T, U, C). Note that its output contains
             unnormalized probs, i.e., not processed by log-softmax.
+          prune_range:
+            The prune range for rnnt loss, it means how many symbols(context)
+            we are considering for each frame to compute the loss.
+          am_scale:
+            The scale to smooth the loss with am (output of encoder network)
+            part
+          lm_scale:
+            The scale to smooth the loss with lm (output of predictor network)
+            part
+          Note:
+            Regarding am_scale & lm_scale, it will make the loss-function one of
+            the form:
+              lm_scale * lm_probs + am_scale * am_probs +
+              (1-lm_scale-am_scale) * combined_probs
         """
         super().__init__()
         assert isinstance(encoder, EncoderInterface), type(encoder)
diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py
index 65ac5f3ff..94d0ac60d 100755
--- a/egs/aishell/ASR/transducer_stateless/pretrained.py
+++ b/egs/aishell/ASR/transducer_stateless/pretrained.py
@@ -20,7 +20,7 @@ Usage:
 (1) greedy search
 ./transducer_stateless/pretrained.py \
     --checkpoint ./transducer_stateless/exp/pretrained.pt \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --lang-dir ./data/lang_char \
     --method greedy_search \
     /path/to/foo.wav \
     /path/to/bar.wav \
@@ -28,7 +28,7 @@ Usage:
 (2) beam search
 ./transducer_stateless/pretrained.py \
     --checkpoint ./transducer_stateless/exp/pretrained.pt \
-    --bpe-model ./data/lang_bpe_500/bpe.model \
+    --lang-dir ./data/lang_char \
     --method beam_search \
     --beam-size 4 \
     /path/to/foo.wav \
@@ -44,11 +44,12 @@ Note: ./transducer_stateless/exp/pretrained.pt is generated by
 import argparse
 import logging
 import math
-from typing import List
 from pathlib import Path
+from typing import List
 
 import kaldifeat
 import torch
+import torch.nn as nn
 import torchaudio
 from beam_search import beam_search, greedy_search
 from conformer import Conformer
@@ -58,9 +59,8 @@ from model import Transducer
 from torch.nn.utils.rnn import pad_sequence
 
 from icefall.env import get_env_info
-from icefall.utils import AttributeDict
 from icefall.lexicon import Lexicon
-from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.utils import AttributeDict
 
 
 def get_parser():
@@ -137,11 +137,11 @@ def get_params() -> AttributeDict:
         "sample_rate": 16000,
         # parameters for conformer
         "feature_dim": 80,
-        "encoder_out_dim": 512,
+        "embedding_dim": 256,
         "subsampling_factor": 4,
-        "attention_dim": 512,
-        "nhead": 8,
-        "dim_feedforward": 2048,
+        "attention_dim": 256,
+        "nhead": 4,
+        "dim_feedforward": 1024,
         "num_encoder_layers": 12,
         "vgg_frontend": False,
         "env_info": get_env_info(),
@@ -150,10 +150,10 @@ def get_params() -> AttributeDict:
     return params
 
 
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Conformer(
         num_features=params.feature_dim,
-        output_dim=params.encoder_out_dim,
+        output_dim=params.vocab_size,
         subsampling_factor=params.subsampling_factor,
         d_model=params.attention_dim,
         nhead=params.nhead,
@@ -164,25 +164,26 @@ def get_encoder_model(params: AttributeDict):
     return encoder
 
 
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
-        embedding_dim=params.encoder_out_dim,
+        embedding_dim=params.embedding_dim,
         blank_id=params.blank_id,
         context_size=params.context_size,
     )
     return decoder
 
 
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        input_dim=params.encoder_out_dim,
+        input_dim=params.vocab_size,
+        inner_dim=params.embedding_dim,
         output_dim=params.vocab_size,
     )
     return joiner
 
 
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
@@ -235,12 +236,8 @@ def main():
     logging.info(f"device: {device}")
 
     lexicon = Lexicon(params.lang_dir)
-    graph_compiler = CharCtcTrainingGraphCompiler(
-        lexicon=lexicon,
-        device=device,
-    )
 
-    params.blank_id = graph_compiler.texts_to_ids("<blk>")[0][0]
+    params.blank_id = 0
     params.vocab_size = max(lexicon.tokens) + 1
 
     logging.info("Creating model")
diff --git a/egs/aishell/ASR/transducer_stateless/test_decoder.py b/egs/aishell/ASR/transducer_stateless/test_decoder.py
index fe0bdee70..0d34cd672 100755
--- a/egs/aishell/ASR/transducer_stateless/test_decoder.py
+++ b/egs/aishell/ASR/transducer_stateless/test_decoder.py
@@ -42,12 +42,12 @@ def test_decoder():
     U = 20
     x = torch.randint(low=0, high=vocab_size, size=(N, U))
     y = decoder(x)
-    assert y.shape == (N, U, embedding_dim)
+    assert y.shape == (N, U, vocab_size)
 
     # for inference
     x = torch.randint(low=0, high=vocab_size, size=(N, context_size))
     y = decoder(x, need_pad=False)
-    assert y.shape == (N, 1, embedding_dim)
+    assert y.shape == (N, 1, vocab_size)
 
 
 def main():
diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py
index 7336c4312..422b3ac25 100755
--- a/egs/aishell/ASR/transducer_stateless/train.py
+++ b/egs/aishell/ASR/transducer_stateless/train.py
@@ -131,7 +131,7 @@ def get_parser():
     parser.add_argument(
         "--prune-range",
         type=int,
-        default=5,
+        default=3,
         help="The prune range for rnnt loss, it means how many symbols(context)"
         "we are using to compute the loss",
     )
@@ -139,7 +139,7 @@
     parser.add_argument(
         "--lm-scale",
         type=float,
-        default=0.0,
+        default=0.5,
         help="The scale to smooth the loss with lm "
         "(output of prediction network) part.",
     )
@@ -212,9 +212,9 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            "attention_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
+            "attention_dim": 256,
+            "nhead": 4,
+            "dim_feedforward": 1024,
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             # parameters for decoder
@@ -228,7 +228,7 @@
     return params
 
 
-def get_encoder_model(params: AttributeDict):
+def get_encoder_model(params: AttributeDict) -> nn.Module:
     # TODO: We can add an option to switch between Conformer and Transformer
     encoder = Conformer(
         num_features=params.feature_dim,
@@ -243,7 +243,7 @@
     return encoder
 
 
-def get_decoder_model(params: AttributeDict):
+def get_decoder_model(params: AttributeDict) -> nn.Module:
     decoder = Decoder(
         vocab_size=params.vocab_size,
         embedding_dim=params.embedding_dim,
@@ -253,7 +253,7 @@
     return decoder
 
 
-def get_joiner_model(params: AttributeDict):
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.vocab_size,
         inner_dim=params.embedding_dim,
@@ -262,7 +262,7 @@
     return joiner
 
 
-def get_transducer_model(params: AttributeDict):
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
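The model.py and train.py hunks wire in the smoothing terms for the pruned RNN-T loss: --prune-range drops from 5 to 3 and --lm-scale rises from 0.0 to 0.5. Per the docstring added in model.py, the final loss is the convex combination lm_scale * lm_probs + am_scale * am_probs + (1 - lm_scale - am_scale) * combined_probs. A toy illustration of just that combination (scalar stand-ins with assumed values, not the actual pruned loss computation):

    import torch

    lm_scale, am_scale = 0.5, 0.0  # new train.py defaults

    # Scalar stand-ins (assumed values) for the three log-prob terms.
    lm_probs = torch.tensor(-3.2)        # prediction-network-only path
    am_probs = torch.tensor(-2.9)        # encoder-only path
    combined_probs = torch.tensor(-2.5)  # full joint path

    loss = (
        lm_scale * lm_probs
        + am_scale * am_probs
        + (1 - lm_scale - am_scale) * combined_probs
    )
    # With lm_scale=0.5 and am_scale=0.0 this averages the lm-only and
    # combined terms, regularizing the prediction network during training.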