Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-07 08:04:18 +00:00)

Commit 8f63bf0bfd ("add more comments"), parent 0d7d3b13d1.
@@ -8,13 +8,13 @@ on CPU using OpenFST and decoders from kaldi.
 Usage:
 
   ./conformer_ctc/jit_pretrained_decode_with_H.py \
-    --nn-model ./cpu_jit.pt \
+    --nn-model ./conformer_ctc/exp/cpu_jit.pt \
     --H ./data/lang_bpe_500/H.fst \
     --tokens ./data/lang_bpe_500/tokens.txt \
     ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
     ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac
 
-Note that to generate ./tdnn/exp/cpu_jit.pt,
+Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
 you can use ./export.py --jit 1
 """
 
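For context, the usage above assumes the H graph and the token table are loaded before decoding starts. A minimal sketch of that setup, assuming kaldifst's StdVectorFst.read and a tokens.txt with one "token id" pair per line (the loop below is illustrative, not the script's actual code):

    import kaldifst

    # Assumption: H.fst is an OpenFST/kaldifst FST that can be read back as a StdVectorFst.
    H = kaldifst.StdVectorFst.read("./data/lang_bpe_500/H.fst")

    # Assumption: each line of tokens.txt is "<token> <id>".
    id2token = {}
    with open("./data/lang_bpe_500/tokens.txt", encoding="utf-8") as f:
        for line in f:
            token, idx = line.split()
            id2token[int(idx)] = token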
@@ -42,7 +42,7 @@ def get_parser():
         type=str,
         required=True,
         help="""Path to the torchscript model.
-        You can use ./tdnn/export.py --jit 1
+        You can use ./conformer_ctc/export.py --jit 1
         to obtain it
         """,
     )
@@ -111,8 +111,22 @@ def decode(
     H: kaldifst,
     id2token: Dict[int, str],
 ) -> List[str]:
+    """
+    Args:
+      filename:
+        Path to the filename for decoding. Used for debugging.
+      nnet_output:
+        A 2-D float32 tensor of shape (num_frames, vocab_size). It
+        contains output from log_softmax.
+      H:
+        The H graph.
+      id2token:
+        A map mapping token ID to token string.
+    Returns:
+      Return a list of decoded tokens.
+    """
     logging.info(f"{filename}, {nnet_output.shape}")
-    decodable = DecodableCtc(nnet_output)
+    decodable = DecodableCtc(nnet_output.cpu())
 
     decoder_opts = FasterDecoderOptions(max_active=3000)
     decoder = FasterDecoder(H, decoder_opts)
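The change from DecodableCtc(nnet_output) to DecodableCtc(nnet_output.cpu()) makes the call safe when the log-softmax output was produced on a GPU, since the decoder consumes a CPU float32 matrix. A minimal sketch, assuming the class comes from the kaldi_decoder package as the surrounding code suggests:

    from kaldi_decoder import DecodableCtc

    # .cpu() is a no-op if the tensor already lives on the CPU and copies it
    # off the GPU otherwise, so the same decoding code serves both cases.
    decodable = DecodableCtc(nnet_output.cpu())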
@@ -120,7 +134,7 @@ def decode(
 
     if not decoder.reached_final():
         print(f"failed to decode {filename}")
-        return ""
+        return [""]
 
     ok, best_path = decoder.get_best_path()
 
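The switch from return "" to return [""] makes the failure path agree with the declared -> List[str] return type: a caller that iterates or joins the hypothesis list behaves predictably instead of walking the characters of a bare string. A small illustration (the join is a hypothetical caller, not code from this file):

    hyps = [""]               # failure path after this change
    print(" ".join(hyps))     # empty transcript, but the type stays List[str]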
@@ -132,7 +146,7 @@ def decode(
     ) = kaldifst.get_linear_symbol_sequence(best_path)
     if not ok:
         print(f"failed to get linear symbol sequence for {filename}")
-        return ""
+        return [""]
 
     # tokens are incremented during graph construction
     # so they need to be decremented
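The two comments closing this hunk refer to an off-by-one between FST output symbols and token IDs: symbol 0 is reserved for epsilon in OpenFST, so token IDs are stored incremented by one when the H graph is built and must be shifted back when reading the best path. A sketch consistent with the code shown further down (osymbols_out comes from get_linear_symbol_sequence; the "▁"-based join is an assumption about how BPE tokens are turned into text):

    # Output symbol k on the best path corresponds to token id k - 1.
    hyps = [id2token[i - 1] for i in osymbols_out]
    text = "".join(hyps).replace("▁", " ").strip()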
@@ -8,13 +8,13 @@ on CPU using OpenFST and decoders from kaldi.
 Usage:
 
   ./conformer_ctc/jit_pretrained_decode_with_H.py \
-    --nn-model ./cpu_jit.pt \
+    --nn-model ./conformer_ctc/exp/cpu_jit.pt \
     --HL ./data/lang_bpe_500/HL.fst \
     --words ./data/lang_bpe_500/words.txt \
     ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
     ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac
 
-Note that to generate ./tdnn/exp/cpu_jit.pt,
+Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
 you can use ./export.py --jit 1
 """
 
@@ -42,7 +42,7 @@ def get_parser():
         type=str,
         required=True,
         help="""Path to the torchscript model.
-        You can use ./tdnn/export.py --jit 1
+        You can use ./conformer_ctc/export.py --jit 1
         to obtain it
         """,
     )
@@ -111,8 +111,22 @@ def decode(
     HL: kaldifst,
     id2word: Dict[int, str],
 ) -> List[str]:
+    """
+    Args:
+      filename:
+        Path to the filename for decoding. Used for debugging.
+      nnet_output:
+        A 2-D float32 tensor of shape (num_frames, vocab_size). It
+        contains output from log_softmax.
+      HL:
+        The HL graph.
+      id2word:
+        A map mapping word ID to word string.
+    Returns:
+      Return a list of decoded words.
+    """
     logging.info(f"{filename}, {nnet_output.shape}")
-    decodable = DecodableCtc(nnet_output)
+    decodable = DecodableCtc(nnet_output.cpu())
 
     decoder_opts = FasterDecoderOptions(max_active=3000)
     decoder = FasterDecoder(HL, decoder_opts)
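For readers following the HL variant, the calls that appear as context in this and the next hunks assemble into the sketch below (decode_with_hl is a placeholder name). Everything except decoder.decode(decodable) is visible in the diff; that call is an assumption about the step that actually runs the search, and the trailing ellipsis stands for the best-path extraction shown in the following hunks:

    from typing import Dict, List

    import torch
    from kaldi_decoder import DecodableCtc, FasterDecoder, FasterDecoderOptions


    def decode_with_hl(filename: str, nnet_output: torch.Tensor, HL, id2word: Dict[int, str]) -> List[str]:
        # Mirrors the calls visible in this diff.
        decodable = DecodableCtc(nnet_output.cpu())
        decoder_opts = FasterDecoderOptions(max_active=3000)
        decoder = FasterDecoder(HL, decoder_opts)
        decoder.decode(decodable)  # assumption: the decoding step not shown in these hunks
        if not decoder.reached_final():
            print(f"failed to decode {filename}")
            return [""]
        ok, best_path = decoder.get_best_path()
        ...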
@@ -120,7 +134,7 @@ def decode(
 
     if not decoder.reached_final():
         print(f"failed to decode {filename}")
-        return ""
+        return [""]
 
     ok, best_path = decoder.get_best_path()
 
@@ -132,7 +146,7 @@ def decode(
     ) = kaldifst.get_linear_symbol_sequence(best_path)
     if not ok:
         print(f"failed to get linear symbol sequence for {filename}")
-        return ""
+        return [""]
 
     # are shifted by 1 during graph construction
     hyps = [id2word[i] for i in osymbols_out]
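Note the contrast with the token-level (H) decoding above: here the output symbols of the best path are word IDs, presumably taken directly from words.txt, so id2word is indexed without the i - 1 shift used for tokens. A minimal sketch of turning them into a transcript (the join is an assumption about how the script prints results):

    hyps = [id2word[i] for i in osymbols_out]  # word IDs map straight to words
    text = " ".join(hyps)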
@@ -118,7 +118,7 @@ def decode(
 
     if not decoder.reached_final():
         print(f"failed to decode {filename}")
-        return ""
+        return [""]
 
     ok, best_path = decoder.get_best_path()
 
@@ -130,7 +130,7 @@ def decode(
     ) = kaldifst.get_linear_symbol_sequence(best_path)
     if not ok:
         print(f"failed to get linear symbol sequence for {filename}")
-        return ""
+        return [""]
 
     # are shifted by 1 during graph construction
     hyps = [id2token[i - 1] for i in osymbols_out if id2token[i - 1] != "SIL"]
@@ -118,7 +118,7 @@ def decode(
 
     if not decoder.reached_final():
         print(f"failed to decode {filename}")
-        return ""
+        return [""]
 
     ok, best_path = decoder.get_best_path()
 
@@ -130,7 +130,7 @@ def decode(
     ) = kaldifst.get_linear_symbol_sequence(best_path)
     if not ok:
         print(f"failed to get linear symbol sequence for {filename}")
-        return ""
+        return [""]
 
     hyps = [id2word[i] for i in osymbols_out if id2word[i] != "<SIL>"]
 