add more comments

This commit is contained in:
Fangjun Kuang 2023-09-26 16:31:38 +08:00
parent 0d7d3b13d1
commit 8f63bf0bfd
4 changed files with 44 additions and 16 deletions

View File

@ -8,13 +8,13 @@ on CPU using OpenFST and decoders from kaldi.
Usage:
./conformer_ctc/jit_pretrained_decode_with_H.py \
--nn-model ./cpu_jit.pt \
--nn-model ./conformer_ctc/exp/cpu_jit.pt \
--H ./data/lang_bpe_500/H.fst \
--tokens ./data/lang_bpe_500/tokens.txt \
./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac
Note that to generate ./tdnn/exp/cpu_jit.pt,
Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
you can use ./export.py --jit 1
"""
@ -42,7 +42,7 @@ def get_parser():
type=str,
required=True,
help="""Path to the torchscript model.
You can use ./tdnn/export.py --jit 1
You can use ./conformer_ctc/export.py --jit 1
to obtain it
""",
)
@ -111,8 +111,22 @@ def decode(
H: kaldifst,
id2token: Dict[int, str],
) -> List[str]:
"""
Args:
filename:
Path to the file being decoded. Used only in log and error messages.
nnet_output:
A 2-D float32 tensor of shape (num_frames, vocab_size). It
contains output from log_softmax.
H:
The H graph.
id2token:
A map mapping token ID to token string.
Returns:
Return a list of decoded tokens.
"""
logging.info(f"{filename}, {nnet_output.shape}")
decodable = DecodableCtc(nnet_output)
decodable = DecodableCtc(nnet_output.cpu())
decoder_opts = FasterDecoderOptions(max_active=3000)
decoder = FasterDecoder(H, decoder_opts)
@ -120,7 +134,7 @@ def decode(
if not decoder.reached_final():
print(f"failed to decode {filename}")
return ""
return [""]
ok, best_path = decoder.get_best_path()
@ -132,7 +146,7 @@ def decode(
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
return ""
return [""]
# tokens are incremented during graph construction
# so they need to be decremented

View File

@ -8,13 +8,13 @@ on CPU using OpenFST and decoders from kaldi.
Usage:
./conformer_ctc/jit_pretrained_decode_with_HL.py \
--nn-model ./cpu_jit.pt \
--nn-model ./conformer_ctc/exp/cpu_jit.pt \
--HL ./data/lang_bpe_500/HL.fst \
--words ./data/lang_bpe_500/words.txt \
./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac
Note that to generate ./tdnn/exp/cpu_jit.pt,
Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
you can use ./export.py --jit 1
"""
@ -42,7 +42,7 @@ def get_parser():
type=str,
required=True,
help="""Path to the torchscript model.
You can use ./tdnn/export.py --jit 1
You can use ./conformer_ctc/export.py --jit 1
to obtain it
""",
)
@ -111,8 +111,22 @@ def decode(
HL: kaldifst,
id2word: Dict[int, str],
) -> List[str]:
"""
Args:
filename:
Path to the file being decoded. Used only in log and error messages.
nnet_output:
A 2-D float32 tensor of shape (num_frames, vocab_size). It
contains output from log_softmax.
HL:
The HL graph.
id2word:
A map mapping word ID to word string.
Returns:
Return a list of decoded words.
"""
logging.info(f"{filename}, {nnet_output.shape}")
decodable = DecodableCtc(nnet_output)
decodable = DecodableCtc(nnet_output.cpu())
decoder_opts = FasterDecoderOptions(max_active=3000)
decoder = FasterDecoder(HL, decoder_opts)
@ -120,7 +134,7 @@ def decode(
if not decoder.reached_final():
print(f"failed to decode {filename}")
return ""
return [""]
ok, best_path = decoder.get_best_path()
@ -132,7 +146,7 @@ def decode(
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
return ""
return [""]
# are shifted by 1 during graph construction
hyps = [id2word[i] for i in osymbols_out]

View File

@ -118,7 +118,7 @@ def decode(
if not decoder.reached_final():
print(f"failed to decode {filename}")
return ""
return [""]
ok, best_path = decoder.get_best_path()
@ -130,7 +130,7 @@ def decode(
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
return ""
return [""]
# are shifted by 1 during graph construction
hyps = [id2token[i - 1] for i in osymbols_out if id2token[i - 1] != "SIL"]

View File

@ -118,7 +118,7 @@ def decode(
if not decoder.reached_final():
print(f"failed to decode {filename}")
return ""
return [""]
ok, best_path = decoder.get_best_path()
@ -130,7 +130,7 @@ def decode(
) = kaldifst.get_linear_symbol_sequence(best_path)
if not ok:
print(f"failed to get linear symbol sequence for {filename}")
return ""
return [""]
hyps = [id2word[i] for i in osymbols_out if id2word[i] != "<SIL>"]