diff --git a/egs/tedlium3/ASR/zipformer/decode.py b/egs/tedlium3/ASR/zipformer/decode.py
index b7c8132fa..6f109780a 100755
--- a/egs/tedlium3/ASR/zipformer/decode.py
+++ b/egs/tedlium3/ASR/zipformer/decode.py
@@ -374,6 +374,7 @@ def decode_one_batch(
     encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

     hyps = []
+    unk = sp.decode(sp.unk_id()).strip()

     if params.decoding_method == "fast_beam_search":
         hyp_tokens = fast_beam_search_one_best(
@@ -386,7 +387,8 @@
             max_states=params.max_states,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyp = [w for w in hyp.split() if w != unk]
+            hyps.append(hyp)
     elif params.decoding_method == "fast_beam_search_nbest_LG":
         hyp_tokens = fast_beam_search_nbest_LG(
             model=model,
@@ -400,7 +402,8 @@
             nbest_scale=params.nbest_scale,
         )
         for hyp in hyp_tokens:
-            hyps.append([word_table[i] for i in hyp])
+            hyp = [word_table[i] for i in hyp if word_table[i] != unk]
+            hyps.append(hyp)
     elif params.decoding_method == "fast_beam_search_nbest":
         hyp_tokens = fast_beam_search_nbest(
             model=model,
@@ -414,7 +417,8 @@
             nbest_scale=params.nbest_scale,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyp = [w for w in hyp.split() if w != unk]
+            hyps.append(hyp)
     elif params.decoding_method == "fast_beam_search_nbest_oracle":
         hyp_tokens = fast_beam_search_nbest_oracle(
             model=model,
@@ -429,7 +433,8 @@
             nbest_scale=params.nbest_scale,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyp = [w for w in hyp.split() if w != unk]
+            hyps.append(hyp)
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
         hyp_tokens = greedy_search_batch(
             model=model,
@@ -437,7 +442,8 @@
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyp = [w for w in hyp.split() if w != unk]
+            hyps.append(hyp)
     elif params.decoding_method == "modified_beam_search":
         hyp_tokens = modified_beam_search(
             model=model,
@@ -446,7 +452,8 @@
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
             beam=params.beam_size,
         )
         for hyp in sp.decode(hyp_tokens):
-            hyps.append(hyp.split())
+            hyp = [w for w in hyp.split() if w != unk]
+            hyps.append(hyp)
     else:
         batch_size = encoder_out.size(0)
@@ -470,7 +477,8 @@
                 raise ValueError(
                     f"Unsupported decoding method: {params.decoding_method}"
                 )
-            hyps.append(sp.decode(hyp).split())
+            hyp = [w for w in sp.decode(hyp).split() if w != unk]
+            hyps.append(hyp)

     if params.decoding_method == "greedy_search":
         return {"greedy_search": hyps}
diff --git a/egs/tedlium3/ASR/zipformer/train.py b/egs/tedlium3/ASR/zipformer/train.py
index 402388e38..aef89c734 100755
--- a/egs/tedlium3/ASR/zipformer/train.py
+++ b/egs/tedlium3/ASR/zipformer/train.py
@@ -67,6 +67,7 @@ from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids
 from model import Transducer
 from optim import Eden, ScaledAdam
 from scaling import ScheduledFloat
@@ -415,7 +416,7 @@ def get_parser():
     parser.add_argument(
         "--keep-last-k",
         type=int,
-        default=5,
+        default=1,
         help="""Only keep this number of checkpoints on disk.
         For instance, if it is 3, there are only 3 checkpoints
         in the exp-dir with filenames `checkpoint-xxx.pt`.
@@ -751,7 +752,7 @@ def compute_loss(
     warm_step = params.warm_step

     texts = batch["supervisions"]["text"]
-    y = sp.encode(texts, out_type=int)
+    y = convert_texts_into_ids(texts, sp)
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
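Note on the decode.py hunks above: every decoding branch now drops the BPE model's unknown token from the hypothesis word list before scoring, and the token's surface form is recovered with `sp.decode(sp.unk_id()).strip()` rather than hard-coded. A minimal self-contained sketch of the filtering idiom; the `"<unk>"` surface form and the example strings are assumptions for illustration only:

```python
# Sketch only: assume the BPE model's unk token decodes to "<unk>";
# in decode.py it is obtained via sp.decode(sp.unk_id()).strip().
unk = "<unk>"

decoded = ["the <unk> cat sat", "hello <unk>"]

hyps = []
for hyp in decoded:
    # Keep every word except the unknown-token placeholder.
    hyps.append([w for w in hyp.split() if w != unk])

print(hyps)  # [['the', 'cat', 'sat'], ['hello']]
```

Filtering at the word level (after `sp.decode`) keeps the change uniform across the decoding methods; in the `fast_beam_search_nbest_LG` branch the same comparison is applied to `word_table` lookups instead.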
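Note on the train.py change: `compute_loss` now encodes transcripts with `convert_texts_into_ids` from `local/convert_transcript_words_to_bpe_ids.py` instead of calling `sp.encode` directly, presumably so that the literal `<unk>` words present in TEDLIUM transcripts map to the single BPE unk id rather than being split into ordinary pieces such as `<`, `unk`, `>`. A simplified sketch of that assumed behavior, not the helper's exact implementation:

```python
from typing import List

import sentencepiece as spm


def convert_texts_into_ids_sketch(
    texts: List[str],
    sp: spm.SentencePieceProcessor,
) -> List[List[int]]:
    """Encode transcripts to BPE ids, mapping each literal "<unk>"
    word to sp.unk_id() instead of letting SentencePiece split it.

    Hypothetical stand-in for the imported helper, for illustration.
    """
    ids: List[List[int]] = []
    for text in texts:
        y: List[int] = []
        for word in text.split():
            if word == "<unk>":
                y.append(sp.unk_id())
            else:
                # Encoding word by word is a simplification; the real
                # helper may encode longer spans at once.
                y.extend(sp.encode(word, out_type=int))
        ids.append(y)
    return ids
```

The returned list-of-lists feeds directly into `k2.RaggedTensor(y)`, as in the hunk above.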