diff --git a/egs/aishell/ASR/transducer_stateless/decoder.py b/egs/aishell/ASR/transducer_stateless/decoder.py index dca084477..c2c6552a9 100644 --- a/egs/aishell/ASR/transducer_stateless/decoder.py +++ b/egs/aishell/ASR/transducer_stateless/decoder.py @@ -82,17 +82,17 @@ class Decoder(nn.Module): Returns: Return a tensor of shape (N, U, embedding_dim). """ - embeding_out = self.embedding(y) + embedding_out = self.embedding(y) if self.context_size > 1: - embeding_out = embeding_out.permute(0, 2, 1) + embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embeding_out = F.pad( - embeding_out, pad=(self.context_size - 1, 0) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) ) else: # During inference time, there is no need to do extra padding # as we only need one output - assert embeding_out.size(-1) == self.context_size - embeding_out = self.conv(embeding_out) - embeding_out = embeding_out.permute(0, 2, 1) - return embeding_out + assert embedding_out.size(-1) == self.context_size + embedding_out = self.conv(embedding_out) + embedding_out = embedding_out.permute(0, 2, 1) + return embedding_out diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py index 641555bdb..5687260df 100755 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ b/egs/aishell/ASR/transducer_stateless/export.py @@ -48,6 +48,7 @@ from pathlib import Path import sentencepiece as spm import torch +import torch.nn as nn from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -133,7 +134,7 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Conformer( num_features=params.feature_dim, output_dim=params.encoder_out_dim, @@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict): return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, @@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict): return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.encoder_out_dim, output_dim=params.vocab_size, @@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict): return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) diff --git a/egs/aishell/ASR/transducer_stateless/pretrained.py b/egs/aishell/ASR/transducer_stateless/pretrained.py index 65ac5f3ff..db89c4d67 100755 --- a/egs/aishell/ASR/transducer_stateless/pretrained.py +++ b/egs/aishell/ASR/transducer_stateless/pretrained.py @@ -44,11 +44,12 @@ Note: ./transducer_stateless/exp/pretrained.pt is generated by import argparse import logging import math -from typing import List from pathlib import Path +from typing import List import kaldifeat import torch +import torch.nn as nn import torchaudio from beam_search import beam_search, greedy_search from conformer import Conformer @@ -57,10 +58,10 @@ from joiner import Joiner from model import Transducer from torch.nn.utils.rnn import pad_sequence -from icefall.env import get_env_info -from icefall.utils import AttributeDict -from icefall.lexicon import Lexicon from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict def get_parser(): @@ -150,7 +151,7 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Conformer( num_features=params.feature_dim, output_dim=params.encoder_out_dim, @@ -164,7 +165,7 @@ def get_encoder_model(params: AttributeDict): return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, @@ -174,7 +175,7 @@ def get_decoder_model(params: AttributeDict): return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.encoder_out_dim, output_dim=params.vocab_size, @@ -182,7 +183,7 @@ def get_joiner_model(params: AttributeDict): return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) diff --git a/egs/aishell/ASR/transducer_stateless/train.py b/egs/aishell/ASR/transducer_stateless/train.py index 7da8e28a1..0c180b260 100755 --- a/egs/aishell/ASR/transducer_stateless/train.py +++ b/egs/aishell/ASR/transducer_stateless/train.py @@ -204,7 +204,7 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Conformer and Transformer encoder = Conformer( num_features=params.feature_dim, @@ -219,7 +219,7 @@ def get_encoder_model(params: AttributeDict): return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, @@ -229,7 +229,7 @@ def get_decoder_model(params: AttributeDict): return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.encoder_out_dim, output_dim=params.vocab_size, @@ -237,7 +237,7 @@ def get_joiner_model(params: AttributeDict): return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index c1fa814c0..cb0bd5c2d 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -41,6 +41,7 @@ from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info +from icefall.graph_compiler import CtcTrainingGraphCompiler from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -123,6 +124,15 @@ def get_parser(): """, ) + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + parser.add_argument( "--lr-factor", type=float, @@ -210,7 +220,6 @@ def get_params() -> AttributeDict: "use_feat_batchnorm": True, "attention_dim": 512, "nhead": 8, - "num_decoder_layers": 6, # parameters for loss "beam_size": 10, "reduction": "sum", @@ -357,9 +366,17 @@ def compute_loss( supervisions, subsampling_factor=params.subsampling_factor ) - token_ids = graph_compiler.texts_to_ids(texts) - - decoding_graph = graph_compiler.compile(token_ids) + if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler): + # Works with a BPE model + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + elif isinstance(graph_compiler, CtcTrainingGraphCompiler): + # Works with a phone lexicon + decoding_graph = graph_compiler.compile(texts) + else: + raise ValueError( + f"Unsupported type of graph compiler: {type(graph_compiler)}" + ) dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -584,12 +601,38 @@ def run(rank, world_size, args): if torch.cuda.is_available(): device = torch.device("cuda", rank) - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) + if "lang_bpe" in params.lang_dir: + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + elif "lang_phone" in params.lang_dir: + assert params.att_rate == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. Set --att-rate=0 " + "for pure CTC training when using a phone-based lang dir." + ) + assert params.num_decoder_layers == 0, ( + "Attention decoder training does not support phone lang dirs " + "at this time due to a missing symbol. " + "Set --num-decoder-layers=0 for pure CTC training when using " + "a phone-based lang dir." + ) + graph_compiler = CtcTrainingGraphCompiler( + lexicon, + device=device, + ) + # Manually add the sos/eos ID with their default values + # from the BPE recipe which we're adapting here. + graph_compiler.sos_id = 1 + graph_compiler.eos_id = 1 + else: + raise ValueError( + f"Unsupported type of lang dir (we expected it to have " + f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" + ) logging.info("About to create model") model = Conformer( @@ -607,7 +650,9 @@ def run(rank, world_size, args): model.to(device) if world_size > 1: - model = DDP(model, device_ids=[rank]) + # Note: find_unused_parameters=True is needed in case we + # want to set params.att_rate = 0 (i.e. att decoder is not trained) + model = DDP(model, device_ids=[rank], find_unused_parameters=True) optimizer = Noam( model.parameters(), diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py index f45d06ce9..11032f31a 100644 --- a/egs/librispeech/ASR/transducer/beam_search.py +++ b/egs/librispeech/ASR/transducer/beam_search.py @@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: blank_id = model.decoder.blank_id device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) + sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( + 1, 1 + ) decoder_out, (h, c) = model.decoder(sos) T = encoder_out.size(1) t = 0 diff --git a/egs/librispeech/ASR/transducer/model.py b/egs/librispeech/ASR/transducer/model.py index fa0b2dd68..8305248c9 100644 --- a/egs/librispeech/ASR/transducer/model.py +++ b/egs/librispeech/ASR/transducer/model.py @@ -99,6 +99,7 @@ class Transducer(nn.Module): sos_y = add_sos(y, sos_id=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) decoder_out, _ = self.decoder(sos_y_padded) diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py index dfc22fcf8..3531a9633 100644 --- a/egs/librispeech/ASR/transducer_lstm/beam_search.py +++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py @@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: blank_id = model.decoder.blank_id device = model.device - sos = torch.tensor([blank_id], device=device).reshape(1, 1) + sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( + 1, 1 + ) decoder_out, (h, c) = model.decoder(sos) T = encoder_out.size(1) t = 0 diff --git a/egs/librispeech/ASR/transducer_lstm/model.py b/egs/librispeech/ASR/transducer_lstm/model.py index cb9afd8a2..31843b60e 100644 --- a/egs/librispeech/ASR/transducer_lstm/model.py +++ b/egs/librispeech/ASR/transducer_lstm/model.py @@ -101,6 +101,7 @@ class Transducer(nn.Module): sos_y = add_sos(y, sos_id=sos_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) decoder_out, _ = self.decoder(sos_y_padded) diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 4f883092c..c5efb733d 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -47,7 +47,7 @@ def greedy_search( device = model.device decoder_input = torch.tensor( - [blank_id] * context_size, device=device + [blank_id] * context_size, device=device, dtype=torch.int64 ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py index 641555bdb..5687260df 100755 --- a/egs/librispeech/ASR/transducer_stateless/export.py +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -48,6 +48,7 @@ from pathlib import Path import sentencepiece as spm import torch +import torch.nn as nn from conformer import Conformer from decoder import Decoder from joiner import Joiner @@ -133,7 +134,7 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Conformer( num_features=params.feature_dim, output_dim=params.encoder_out_dim, @@ -147,7 +148,7 @@ def get_encoder_model(params: AttributeDict): return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, @@ -157,7 +158,7 @@ def get_decoder_model(params: AttributeDict): return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.encoder_out_dim, output_dim=params.vocab_size, @@ -165,7 +166,7 @@ def get_joiner_model(params: AttributeDict): return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) diff --git a/egs/librispeech/ASR/transducer_stateless/model.py b/egs/librispeech/ASR/transducer_stateless/model.py index 76279a2eb..8281e1fb5 100644 --- a/egs/librispeech/ASR/transducer_stateless/model.py +++ b/egs/librispeech/ASR/transducer_stateless/model.py @@ -98,6 +98,7 @@ class Transducer(nn.Module): sos_y = add_sos(y, sos_id=blank_id) sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + sos_y_padded = sos_y_padded.to(torch.int64) decoder_out = self.decoder(sos_y_padded) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index f2fb8d908..ad8d89918 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -59,6 +59,7 @@ from typing import List import kaldifeat import sentencepiece as spm import torch +import torch.nn as nn import torchaudio from beam_search import beam_search, greedy_search, modified_beam_search from conformer import Conformer @@ -159,7 +160,7 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: encoder = Conformer( num_features=params.feature_dim, output_dim=params.encoder_out_dim, @@ -173,7 +174,7 @@ def get_encoder_model(params: AttributeDict): return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, @@ -183,7 +184,7 @@ def get_decoder_model(params: AttributeDict): return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.encoder_out_dim, output_dim=params.vocab_size, @@ -191,7 +192,7 @@ def get_joiner_model(params: AttributeDict): return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index 1accda09a..544f6e9b1 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -224,7 +224,7 @@ def get_params() -> AttributeDict: return params -def get_encoder_model(params: AttributeDict): +def get_encoder_model(params: AttributeDict) -> nn.Module: # TODO: We can add an option to switch between Conformer and Transformer encoder = Conformer( num_features=params.feature_dim, @@ -239,7 +239,7 @@ def get_encoder_model(params: AttributeDict): return encoder -def get_decoder_model(params: AttributeDict): +def get_decoder_model(params: AttributeDict) -> nn.Module: decoder = Decoder( vocab_size=params.vocab_size, embedding_dim=params.encoder_out_dim, @@ -249,7 +249,7 @@ def get_decoder_model(params: AttributeDict): return decoder -def get_joiner_model(params: AttributeDict): +def get_joiner_model(params: AttributeDict) -> nn.Module: joiner = Joiner( input_dim=params.encoder_out_dim, output_dim=params.vocab_size, @@ -257,7 +257,7 @@ def get_joiner_model(params: AttributeDict): return joiner -def get_transducer_model(params: AttributeDict): +def get_transducer_model(params: AttributeDict) -> nn.Module: encoder = get_encoder_model(params) decoder = get_decoder_model(params) joiner = get_joiner_model(params) diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py index b4c87d964..570ed7d7a 100644 --- a/icefall/graph_compiler.py +++ b/icefall/graph_compiler.py @@ -89,6 +89,29 @@ class CtcTrainingGraphCompiler(object): return decoding_graph + def texts_to_ids(self, texts: List[str]) -> List[List[int]]: + """Convert a list of texts to a list-of-list of word IDs. + + Args: + texts: + It is a list of strings. Each string consists of space(s) + separated words. An example containing two strings is given below: + + ['HELLO ICEFALL', 'HELLO k2'] + Returns: + Return a list-of-list of word IDs. + """ + word_ids_list = [] + for text in texts: + word_ids = [] + for word in text.split(): + if word in self.word_table: + word_ids.append(self.word_table[word]) + else: + word_ids.append(self.oov_id) + word_ids_list.append(word_ids) + return word_ids_list + def convert_transcript_to_fsa(self, texts: List[str]) -> k2.Fsa: """Convert a list of transcript texts to an FsaVec.