add more comments

This commit is contained in:
Fangjun Kuang 2023-09-26 16:31:38 +08:00
parent 0d7d3b13d1
commit 8f63bf0bfd
4 changed files with 44 additions and 16 deletions

View File

@ -8,13 +8,13 @@ on CPU using OpenFST and decoders from kaldi.
Usage: Usage:
./conformer_ctc/jit_pretrained_decode_with_H.py \ ./conformer_ctc/jit_pretrained_decode_with_H.py \
--nn-model ./cpu_jit.pt \ --nn-model ./conformer_ctc/exp/cpu_jit.pt \
--H ./data/lang_bpe_500/H.fst \ --H ./data/lang_bpe_500/H.fst \
--tokens ./data/lang_bpe_500/tokens.txt \ --tokens ./data/lang_bpe_500/tokens.txt \
./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac
Note that to generate ./tdnn/exp/cpu_jit.pt, Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
you can use ./export.py --jit 1 you can use ./export.py --jit 1
""" """
@ -42,7 +42,7 @@ def get_parser():
type=str, type=str,
required=True, required=True,
help="""Path to the torchscript model. help="""Path to the torchscript model.
You can use ./tdnn/export.py --jit 1 You can use ./conformer_ctc/export.py --jit 1
to obtain it to obtain it
""", """,
) )
@ -111,8 +111,22 @@ def decode(
H: kaldifst, H: kaldifst,
id2token: Dict[int, str], id2token: Dict[int, str],
) -> List[str]: ) -> List[str]:
"""
Args:
filename:
Path to the file being decoded. Used for debugging.
nnet_output:
A 2-D float32 tensor of shape (num_frames, vocab_size). It
contains output from log_softmax.
H:
The H graph.
id2token:
A map mapping token ID to token string.
Returns:
Return a list of decoded tokens.
"""
logging.info(f"{filename}, {nnet_output.shape}") logging.info(f"{filename}, {nnet_output.shape}")
decodable = DecodableCtc(nnet_output) decodable = DecodableCtc(nnet_output.cpu())
decoder_opts = FasterDecoderOptions(max_active=3000) decoder_opts = FasterDecoderOptions(max_active=3000)
decoder = FasterDecoder(H, decoder_opts) decoder = FasterDecoder(H, decoder_opts)
@ -120,7 +134,7 @@ def decode(
if not decoder.reached_final(): if not decoder.reached_final():
print(f"failed to decode {filename}") print(f"failed to decode {filename}")
return "" return [""]
ok, best_path = decoder.get_best_path() ok, best_path = decoder.get_best_path()
@ -132,7 +146,7 @@ def decode(
) = kaldifst.get_linear_symbol_sequence(best_path) ) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok: if not ok:
print(f"failed to get linear symbol sequence for {filename}") print(f"failed to get linear symbol sequence for {filename}")
return "" return [""]
# tokens are incremented during graph construction # tokens are incremented during graph construction
# so they need to be decremented # so they need to be decremented

View File

@ -8,13 +8,13 @@ on CPU using OpenFST and decoders from kaldi.
Usage: Usage:
./conformer_ctc/jit_pretrained_decode_with_HL.py \ ./conformer_ctc/jit_pretrained_decode_with_HL.py \
--nn-model ./cpu_jit.pt \ --nn-model ./conformer_ctc/exp/cpu_jit.pt \
--HL ./data/lang_bpe_500/HL.fst \ --HL ./data/lang_bpe_500/HL.fst \
--words ./data/lang_bpe_500/words.txt \ --words ./data/lang_bpe_500/words.txt \
./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac
Note that to generate ./tdnn/exp/cpu_jit.pt, Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
you can use ./export.py --jit 1 you can use ./export.py --jit 1
""" """
@ -42,7 +42,7 @@ def get_parser():
type=str, type=str,
required=True, required=True,
help="""Path to the torchscript model. help="""Path to the torchscript model.
You can use ./tdnn/export.py --jit 1 You can use ./conformer_ctc/export.py --jit 1
to obtain it to obtain it
""", """,
) )
@ -111,8 +111,22 @@ def decode(
HL: kaldifst, HL: kaldifst,
id2word: Dict[int, str], id2word: Dict[int, str],
) -> List[str]: ) -> List[str]:
"""
Args:
filename:
Path to the file being decoded. Used for debugging.
nnet_output:
A 2-D float32 tensor of shape (num_frames, vocab_size). It
contains output from log_softmax.
HL:
The HL graph.
id2word:
A map mapping word ID to word string.
Returns:
Return a list of decoded words.
"""
logging.info(f"{filename}, {nnet_output.shape}") logging.info(f"{filename}, {nnet_output.shape}")
decodable = DecodableCtc(nnet_output) decodable = DecodableCtc(nnet_output.cpu())
decoder_opts = FasterDecoderOptions(max_active=3000) decoder_opts = FasterDecoderOptions(max_active=3000)
decoder = FasterDecoder(HL, decoder_opts) decoder = FasterDecoder(HL, decoder_opts)
@ -120,7 +134,7 @@ def decode(
if not decoder.reached_final(): if not decoder.reached_final():
print(f"failed to decode {filename}") print(f"failed to decode {filename}")
return "" return [""]
ok, best_path = decoder.get_best_path() ok, best_path = decoder.get_best_path()
@ -132,7 +146,7 @@ def decode(
) = kaldifst.get_linear_symbol_sequence(best_path) ) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok: if not ok:
print(f"failed to get linear symbol sequence for {filename}") print(f"failed to get linear symbol sequence for {filename}")
return "" return [""]
# are shifted by 1 during graph construction # are shifted by 1 during graph construction
hyps = [id2word[i] for i in osymbols_out] hyps = [id2word[i] for i in osymbols_out]

View File

@ -118,7 +118,7 @@ def decode(
if not decoder.reached_final(): if not decoder.reached_final():
print(f"failed to decode {filename}") print(f"failed to decode {filename}")
return "" return [""]
ok, best_path = decoder.get_best_path() ok, best_path = decoder.get_best_path()
@ -130,7 +130,7 @@ def decode(
) = kaldifst.get_linear_symbol_sequence(best_path) ) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok: if not ok:
print(f"failed to get linear symbol sequence for {filename}") print(f"failed to get linear symbol sequence for {filename}")
return "" return [""]
# are shifted by 1 during graph construction # are shifted by 1 during graph construction
hyps = [id2token[i - 1] for i in osymbols_out if id2token[i - 1] != "SIL"] hyps = [id2token[i - 1] for i in osymbols_out if id2token[i - 1] != "SIL"]

View File

@ -118,7 +118,7 @@ def decode(
if not decoder.reached_final(): if not decoder.reached_final():
print(f"failed to decode {filename}") print(f"failed to decode {filename}")
return "" return [""]
ok, best_path = decoder.get_best_path() ok, best_path = decoder.get_best_path()
@ -130,7 +130,7 @@ def decode(
) = kaldifst.get_linear_symbol_sequence(best_path) ) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok: if not ok:
print(f"failed to get linear symbol sequence for {filename}") print(f"failed to get linear symbol sequence for {filename}")
return "" return [""]
hyps = [id2word[i] for i in osymbols_out if id2word[i] != "<SIL>"] hyps = [id2word[i] for i in osymbols_out if id2word[i] != "<SIL>"]