diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained.py b/egs/librispeech/ASR/zipformer/jit_pretrained.py index 4092d165e..87cd5102c 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained.py @@ -97,7 +97,7 @@ def read_sound_files( sample_rate == expected_sample_rate ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel - ans.append(wave[0]) + ans.append(wave[0].contiguous()) return ans diff --git a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py index 15423b449..1ec390d5b 100755 --- a/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/jit_pretrained_ctc.py @@ -159,11 +159,11 @@ def get_parser(): (2) nbest-rescoring. Extract n paths from the decoding lattice, rescore them with an LM, the path with the highest score is the decoding result. - We call it HLG decoding + n-gram LM rescoring. + We call it HLG decoding + nbest n-gram LM rescoring. (3) whole-lattice-rescoring - Use an LM to rescore the decoding lattice and then use 1best to decode the rescored lattice. - We call it HLG decoding + n-gram LM rescoring. + We call it HLG decoding + whole-lattice n-gram LM rescoring. """, ) @@ -210,15 +210,6 @@ def get_parser(): """, ) - parser.add_argument( - "--num-classes", - type=int, - default=500, - help=""" - Vocab size in the BPE model. - """, - ) - parser.add_argument( "--sample-rate", type=int, @@ -258,7 +249,7 @@ def read_sound_files( sample_rate == expected_sample_rate ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel - ans.append(wave[0]) + ans.append(wave[0].contiguous()) return ans @@ -272,6 +263,11 @@ def main(): params.update(get_decoding_params()) params.update(vars(args)) + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + params.vocab_size = sp.get_piece_size() + logging.info(f"{params}") device = torch.device("cpu") @@ -321,9 +317,7 @@ def main(): if params.method == "ctc-decoding": logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) - max_token_id = params.num_classes - 1 + max_token_id = params.vocab_size - 1 H = k2.ctc_topo( max_token=max_token_id, @@ -346,7 +340,7 @@ def main(): lattice=lattice, use_double_scores=params.use_double_scores ) token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) + hyps = sp.decode(token_ids) hyps = [s.split() for s in hyps] elif params.method in [ "1best", diff --git a/egs/librispeech/ASR/zipformer/model.py b/egs/librispeech/ASR/zipformer/model.py index 8d8ec8231..90209b945 100644 --- a/egs/librispeech/ASR/zipformer/model.py +++ b/egs/librispeech/ASR/zipformer/model.py @@ -75,7 +75,7 @@ class AsrModel(nn.Module): assert ( use_transducer or use_ctc - ), f"At least one of them should be True, but gotten use_transducer={use_transducer}, use_ctc={use_ctc}" + ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}" assert isinstance(encoder, EncoderInterface), type(encoder) @@ -98,6 +98,9 @@ class AsrModel(nn.Module): self.simple_lm_proj = ScaledLinear( decoder_dim, vocab_size, initial_scale=0.25 ) + else: + assert decoder is None + assert joiner is None self.use_ctc = use_ctc if use_ctc: @@ -135,7 +138,7 @@ class AsrModel(nn.Module): encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask) encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - assert torch.all(encoder_out_lens > 0) + assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens) return encoder_out, encoder_out_lens @@ -342,10 +345,7 @@ class AsrModel(nn.Module): if self.use_ctc: # Compute CTC loss - targets = [t for tokens in y.tolist() for t in tokens] - # of shape (sum(y_lens),) - targets = torch.tensor(targets, device=x.device, dtype=torch.int64) - + targets = y.values ctc_loss = self.forward_ctc( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, diff --git a/egs/librispeech/ASR/zipformer/pretrained.py b/egs/librispeech/ASR/zipformer/pretrained.py index c0c898f4b..2944f79e3 100755 --- a/egs/librispeech/ASR/zipformer/pretrained.py +++ b/egs/librispeech/ASR/zipformer/pretrained.py @@ -245,7 +245,7 @@ def read_sound_files( sample_rate == expected_sample_rate ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" # We use only the first channel - ans.append(wave[0]) + ans.append(wave[0].contiguous()) return ans diff --git a/egs/librispeech/ASR/zipformer/pretrained_ctc.py b/egs/librispeech/ASR/zipformer/pretrained_ctc.py index 10abb0a32..f10d95449 100755 --- a/egs/librispeech/ASR/zipformer/pretrained_ctc.py +++ b/egs/librispeech/ASR/zipformer/pretrained_ctc.py @@ -167,11 +167,11 @@ def get_parser(): (2) nbest-rescoring. Extract n paths from the decoding lattice, rescore them with an LM, the path with the highest score is the decoding result. - We call it HLG decoding + n-gram LM rescoring. + We call it HLG decoding + nbest n-gram LM rescoring. (3) whole-lattice-rescoring - Use an LM to rescore the decoding lattice and then use 1best to decode the rescored lattice. - We call it HLG decoding + n-gram LM rescoring. + We call it HLG decoding + whole-lattice n-gram LM rescoring. """, ) @@ -218,15 +218,6 @@ def get_parser(): """, ) - parser.add_argument( - "--num-classes", - type=int, - default=500, - help=""" - Vocab size in the BPE model. - """, - ) - parser.add_argument( "--sample-rate", type=int, @@ -268,7 +259,7 @@ def read_sound_files( f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" ) # We use only the first channel - ans.append(wave[0]) + ans.append(wave[0].contiguous()) return ans @@ -281,7 +272,11 @@ def main(): # add decoding params params.update(get_decoding_params()) params.update(vars(args)) - params.vocab_size = params.num_classes + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + params.vocab_size = sp.get_piece_size() params.blank_id = 0 logging.info(f"{params}") @@ -340,9 +335,7 @@ def main(): if params.method == "ctc-decoding": logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) - max_token_id = params.num_classes - 1 + max_token_id = params.vocab_size - 1 H = k2.ctc_topo( max_token=max_token_id, @@ -365,7 +358,7 @@ def main(): lattice=lattice, use_double_scores=params.use_double_scores ) token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) + hyps = sp.decode(token_ids) hyps = [s.split() for s in hyps] elif params.method in [ "1best", diff --git a/egs/librispeech/ASR/zipformer/train.py b/egs/librispeech/ASR/zipformer/train.py index 8553581d5..dd59a6245 100755 --- a/egs/librispeech/ASR/zipformer/train.py +++ b/egs/librispeech/ASR/zipformer/train.py @@ -607,7 +607,7 @@ def get_model(params: AttributeDict) -> nn.Module: assert ( params.use_transducer or params.use_ctc ), (f"At least one of them should be True, " - f"but gotten params.use_transducer={params.use_transducer}, " + f"but got params.use_transducer={params.use_transducer}, " f"params.use_ctc={params.use_ctc}") encoder_embed = get_encoder_embed(params)