From 8f63bf0bfdfe9f4141cf3ae0bd415b78d31d1908 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 26 Sep 2023 16:31:38 +0800 Subject: [PATCH] add more comments --- .../jit_pretrained_decode_with_H.py | 26 ++++++++++++++----- .../jit_pretrained_decode_with_HL.py | 26 ++++++++++++++----- .../ASR/tdnn/jit_pretrained_decode_with_H.py | 4 +-- .../ASR/tdnn/jit_pretrained_decode_with_HL.py | 4 +-- 4 files changed, 44 insertions(+), 16 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py index 0309ea873..b52c7cfed 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py @@ -8,13 +8,13 @@ on CPU using OpenFST and decoders from kaldi. Usage: ./conformer_ctc/jit_pretrained_decode_with_H.py \ - --nn-model ./cpu_jit.pt \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ --H ./data/lang_bpe_500/H.fst \ --tokens ./data/lang_bpe_500/tokens.txt \ ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac -Note that to generate ./tdnn/exp/cpu_jit.pt, +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, you can use ./export.py --jit 1 """ @@ -42,7 +42,7 @@ def get_parser(): type=str, required=True, help="""Path to the torchscript model. - You can use ./tdnn/export.py --jit 1 + You can use ./conformer_ctc/export.py --jit 1 to obtain it """, ) @@ -111,8 +111,22 @@ def decode( H: kaldifst, id2token: Dict[int, str], ) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + H: + The H graph. + id2token: + A map mapping token ID to token string. + Returns: + Return a list of decoded tokens. 
+ """ logging.info(f"{filename}, {nnet_output.shape}") - decodable = DecodableCtc(nnet_output) + decodable = DecodableCtc(nnet_output.cpu()) decoder_opts = FasterDecoderOptions(max_active=3000) decoder = FasterDecoder(H, decoder_opts) @@ -120,7 +134,7 @@ def decode( if not decoder.reached_final(): print(f"failed to decode {filename}") - return "" + return [""] ok, best_path = decoder.get_best_path() @@ -132,7 +146,7 @@ def decode( ) = kaldifst.get_linear_symbol_sequence(best_path) if not ok: print(f"failed to get linear symbol sequence for {filename}") - return "" + return [""] # tokens are incremented during graph construction # so they need to be decremented diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py index e018feac1..f0326ccdf 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -8,13 +8,13 @@ on CPU using OpenFST and decoders from kaldi. Usage: ./conformer_ctc/jit_pretrained_decode_with_H.py \ - --nn-model ./cpu_jit.pt \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ --HL ./data/lang_bpe_500/HL.fst \ --words ./data/lang_bpe_500/words.txt \ ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac -Note that to generate ./tdnn/exp/cpu_jit.pt, +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, you can use ./export.py --jit 1 """ @@ -42,7 +42,7 @@ def get_parser(): type=str, required=True, help="""Path to the torchscript model. - You can use ./tdnn/export.py --jit 1 + You can use ./conformer_ctc/export.py --jit 1 to obtain it """, ) @@ -111,8 +111,22 @@ def decode( HL: kaldifst, id2word: Dict[int, str], ) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). 
It + contains output from log_softmax. + HL: + The HL graph. + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. + """ logging.info(f"{filename}, {nnet_output.shape}") - decodable = DecodableCtc(nnet_output) + decodable = DecodableCtc(nnet_output.cpu()) decoder_opts = FasterDecoderOptions(max_active=3000) decoder = FasterDecoder(HL, decoder_opts) @@ -120,7 +134,7 @@ def decode( if not decoder.reached_final(): print(f"failed to decode {filename}") - return "" + return [""] ok, best_path = decoder.get_best_path() @@ -132,7 +146,7 @@ def decode( ) = kaldifst.get_linear_symbol_sequence(best_path) if not ok: print(f"failed to get linear symbol sequence for {filename}") - return "" + return [""] # are shifted by 1 during graph construction hyps = [id2word[i] for i in osymbols_out] diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py index d1b6fe748..209ab477a 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py @@ -118,7 +118,7 @@ def decode( if not decoder.reached_final(): print(f"failed to decode {filename}") - return "" + return [""] ok, best_path = decoder.get_best_path() @@ -130,7 +130,7 @@ def decode( ) = kaldifst.get_linear_symbol_sequence(best_path) if not ok: print(f"failed to get linear symbol sequence for {filename}") - return "" + return [""] # are shifted by 1 during graph construction hyps = [id2token[i - 1] for i in osymbols_out if id2token[i - 1] != "SIL"] diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py index bf59ff762..74864e17d 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py @@ -118,7 +118,7 @@ def decode( if not decoder.reached_final(): print(f"failed to decode {filename}") - return "" + return [""] ok, best_path = 
decoder.get_best_path() @@ -130,7 +130,7 @@ def decode( ) = kaldifst.get_linear_symbol_sequence(best_path) if not ok: print(f"failed to get linear symbol sequence for {filename}") - return "" + return [""] hyps = [id2word[i] for i in osymbols_out if id2word[i] != ""]