diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index 75fc6326e..f4a59e552 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -1,7 +1,7 @@ # Introduction -Please refer to +Please refer to for how to run models in this recipe. diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 7394f164e..5917668a1 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -31,7 +31,7 @@ stop_stage=10 # - noise # - speech -dl_dir=/home/work/workspace/aishell +dl_dir=$PWD/download . shared/parse_options.sh || exit 1 diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index cb0d6e04d..52d2eda3b 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -44,7 +44,8 @@ class LabelSmoothingLoss(torch.nn.Module): mean of the output is taken. (3) "sum": the output will be summed. """ super().__init__() - assert 0.0 <= label_smoothing < 1.0 + assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}" + assert reduction in ("none", "sum", "mean"), reduction self.ignore_index = ignore_index self.label_smoothing = label_smoothing self.reduction = reduction diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index df6c609bb..08dac6a7b 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from Caution: We use a lexicon that contains disambiguation symbols - - G, the LM, built from data/lm/G_3_gram.fst.txt + - G, the LM, built from data/lm/G_n_gram.fst.txt The generated HLG is saved in $lang_dir/HLG.pt """ diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 4a4093ae4..62036467e 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -28,7 +28,7 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -41,6 +41,10 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) +def is_cut_long(c: MonoCut) -> bool: + return c.duration > 5 + + def compute_fbank_musan(): src_dir = Path("data/manifests") output_dir = Path("data/fbank") @@ -86,7 +90,7 @@ def compute_fbank_musan(): recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) - .filter(lambda c: c.duration > 5) + .filter(is_cut_long) .compute_and_store_features( extractor=extractor, storage_path=f"{output_dir}/musan_feats", diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index b324cc9b7..7388af389 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -26,7 +26,9 @@ from model import Transducer from icefall import NgramLm, NgramLmStateCost from icefall.decode import Nbest, one_best_decoding +from icefall.lm_wrapper import LmScorer from icefall.rnn_lm.model import RnnLmModel +from icefall.transformer_lm.model import TransformerLM from icefall.utils import ( DecodingResults, add_eos, @@ -1846,254 +1848,14 @@ def modified_beam_search_ngram_rescoring( return ans -def modified_beam_search_rnnlm_shallow_fusion( - model: Transducer, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - sp: spm.SentencePieceProcessor, - rnnlm: RnnLmModel, - rnnlm_scale: float, - beam: int = 4, - return_timestamps: bool = False, -) -> List[List[int]]: - """Modified_beam_search + RNNLM shallow fusion - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - sp: - Sentence piece generator. - rnnlm (RnnLmModel): - RNNLM - rnnlm_scale (float): - scale of RNNLM in shallow fusion - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert rnnlm is not None - lm_scale = rnnlm_scale - vocab_size = rnnlm.vocab_size - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("") - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - init_score, init_states = rnnlm.score_token(sos_token) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, - lm_score=init_score.reshape(-1), - timestamp=[], - ) - ) - - rnnlm.clean_cache() - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - The RNNLM will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - assert new_token != 0, new_token - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - - # forward RNNLM to get new states and scores - if len(token_list) != 0: - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - - ys.append(new_token) - new_timestamp.append(t) - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score - - lm_score = scores[count] - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - timestamp=new_timestamp, - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - tokens=ans, - timestamps=ans_timestamps, - ) - - -def modified_beam_search_rnnlm_LODR( +def modified_beam_search_LODR( model: Transducer, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor, sp: spm.SentencePieceProcessor, LODR_lm: NgramLm, LODR_lm_scale: float, - rnnlm: RnnLmModel, - rnnlm_scale: float, + LM: LmScorer, beam: int = 4, ) -> List[List[int]]: """This function implements LODR (https://arxiv.org/abs/2203.16776) with @@ -2113,13 +1875,11 @@ def modified_beam_search_rnnlm_LODR( sp: Sentence piece generator. LODR_lm: - A low order n-gram LM + A low order n-gram LM, whose score will be subtracted during shallow fusion LODR_lm_scale: The scale of the LODR_lm - rnnlm (RnnLmModel): - RNNLM, the external language model - rnnlm_scale (float): - scale of RNNLM in shallow fusion + LM: + A neural net LM, e.g an RNNLM or transformer LM beam (int, optional): Beam size. Defaults to 4. @@ -2130,9 +1890,8 @@ def modified_beam_search_rnnlm_LODR( """ assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert rnnlm is not None - lm_scale = rnnlm_scale - vocab_size = rnnlm.vocab_size + assert LM is not None + lm_scale = LM.lm_scale packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( input=encoder_out, @@ -2154,7 +1913,8 @@ def modified_beam_search_rnnlm_LODR( # get initial lm score and lm state by scoring the "sos" token sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - init_score, init_states = rnnlm.score_token(sos_token) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) B = [HypothesisList() for _ in range(N)] for i in range(N): @@ -2162,7 +1922,7 @@ def modified_beam_search_rnnlm_LODR( Hypothesis( ys=[blank_id] * context_size, log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, # state of the RNNLM + state=init_states, # state of the NN LM lm_score=init_score.reshape(-1), state_cost=NgramLmStateCost( LODR_lm @@ -2170,7 +1930,6 @@ def modified_beam_search_rnnlm_LODR( ) ) - rnnlm.clean_cache() encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) offset = 0 @@ -2236,7 +1995,7 @@ def modified_beam_search_rnnlm_LODR( It is a little confusing here because this for-loop looks very similar to the one below. Here, we go through all top-k tokens and only add the non-blanks ones to the token_list. - The RNNLM will score those tokens given the LM states. Note that + LM will score those tokens given the LM states. Note that the variable `scores` is the LM score after seeing the new non-blank token. """ @@ -2256,21 +2015,41 @@ def modified_beam_search_rnnlm_LODR( new_token = topk_token_indexes[k] if new_token not in (blank_id, unk_id): - assert new_token != 0, new_token - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) - # forward RNNLM to get new states and scores + # forward NN LM to get new states and scores if len(token_list) != 0: - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) count = 0 # index, used to locate score and lm states for i in range(batch_size): @@ -2305,18 +2084,19 @@ def modified_beam_search_rnnlm_LODR( state_cost.lm_score, hyp.state_cost.lm_score, ) - # score = score + RNNLM_score - LODR_score - # LODR_LM_scale is a negative number here + # score = score + TDLM_score - LODR_score + # LODR_LM_scale should be a negative number here hyp_log_prob += ( lm_score[new_token] * lm_scale + LODR_lm_scale * current_ngram_score ) # add the lm score lm_score = scores[count] - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) count += 1 else: state_cost = hyp.state_cost @@ -2340,3 +2120,263 @@ def modified_beam_search_rnnlm_LODR( ans.append(sorted_ans[unsorted_indices[i]]) return ans + + +def modified_beam_search_lm_shallow_fusion( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + sp: spm.SentencePieceProcessor, + LM: LmScorer, + beam: int = 4, + return_timestamps: bool = False, +) -> List[List[int]]: + """Modified_beam_search + NN LM shallow fusion + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + LM (LmScorer): + A neural net LM, e.g RNN or Transformer + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert LM is not None + lm_scale = LM.lm_scale + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("") + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + lens = torch.tensor([1]).to(device) + init_score, init_states = LM.score_token(sos_token, lens) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, + lm_score=init_score.reshape(-1), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + lm_scores = torch.cat( + [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + `LM` will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. + """ + token_list = [] # a list of list + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + if LM.lm_type == "rnn": + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + else: + # for transformer LM + token_list.append( + [sos_id] + hyp.ys[context_size:] + [new_token] + ) + + if len(token_list) != 0: + x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) + if LM.lm_type == "rnn": + tokens_to_score = ( + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + ) + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + state = (hs, cs) + else: + # for transformer LM + tokens_list = [torch.tensor(tokens) for tokens in token_list] + tokens_to_score = ( + torch.nn.utils.rnn.pad_sequence( + tokens_list, batch_first=True, padding_value=0.0 + ) + .to(device) + .to(torch.int64) + ) + + state = None + + scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + + ys.append(new_token) + new_timestamp.append(t) + + hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score + + lm_score = scores[count] + if LM.lm_type == "rnn": + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + + new_hyp = Hypothesis( + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score, + timestamp=new_timestamp, + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] + ans = [] + ans_timestamps = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + tokens=ans, + timestamps=ans_timestamps, + ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 212a00dcf..b3a7d71bc 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -33,7 +33,6 @@ from scaling import ( from torch import Tensor, nn from icefall.utils import make_pad_mask, subsequent_chunk_mask -import random class Conformer(EncoderInterface): @@ -693,10 +692,7 @@ class ConformerEncoder(nn.Module): output = src outputs = [] - residual = None - ''' - for i, mod in enumerate(self.layers): output = mod( output, @@ -705,33 +701,10 @@ class ConformerEncoder(nn.Module): src_key_padding_mask=src_key_padding_mask, warmup=warmup, ) - ''' + if i in self.aux_layers: + outputs.append(output) - for i, mod in enumerate(self.layers): - if i == 0: - residual = output - elif i in [2,5,8,11,14,17]: - output = mod( - output, - pos_emb, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) - output += residual - residual = output - else: - output = mod( - output, - pos_emb, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) - #if i in self.aux_layers: - # outputs.append(output) - - #output = self.combiner(outputs) + output = self.combiner(outputs) return output diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py new file mode 100755 index 000000000..29a2cd7f7 --- /dev/null +++ b/icefall/shared/convert-k2-to-openfst.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes as input an FST in k2 format and convert it +to an FST in OpenFST format. + +The generated FST is saved into a binary file and its type is +StdVectorFst. + +Usage examples: +(1) Convert an acceptor + + ./convert-k2-to-openfst.py in.pt binary.fst + +(2) Convert a transducer + + ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import kaldifst.utils +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--olabels", + type=str, + default=None, + help="""If not empty, the input FST is assumed to be a transducer + and we use its attribute specified by "olabels" as the output labels. + """, + ) + parser.add_argument( + "input_filename", + type=str, + help="Path to the input FST in k2 format", + ) + + parser.add_argument( + "output_filename", + type=str, + help="Path to the output FST in OpenFst format", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + logging.info(f"{vars(args)}") + + input_filename = args.input_filename + output_filename = args.output_filename + olabels = args.olabels + + if Path(output_filename).is_file(): + logging.info(f"{output_filename} already exists - skipping") + return + + assert Path(input_filename).is_file(), f"{input_filename} does not exist" + logging.info(f"Loading {input_filename}") + k2_fst = k2.Fsa.from_dict(torch.load(input_filename)) + if olabels: + assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}" + + p = Path(output_filename).parent + if not p.is_dir(): + logging.info(f"Creating {p}") + p.mkdir(parents=True) + + logging.info("Converting (May take some time if the input FST is large)") + fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels) + logging.info(f"Saving to {output_filename}") + fst.write(output_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main()