diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh
index a4959aa01..c4767b723 100755
--- a/.github/scripts/run-pre-trained-conformer-ctc.sh
+++ b/.github/scripts/run-pre-trained-conformer-ctc.sh
@@ -44,3 +44,46 @@ log "HLG decoding"
   $repo/test_wavs/1089-134686-0001.flac \
   $repo/test_wavs/1221-135766-0001.flac \
   $repo/test_wavs/1221-135766-0002.flac
+
+log "CTC decoding on CPU with kaldi decoders using OpenFst"
+
+log "Exporting model with torchscript"
+
+pushd $repo/exp
+ln -s pretrained.pt epoch-99.pt
+popd
+
+./conformer_ctc/export.py \
+  --epoch 99 \
+  --avg 1 \
+  --exp-dir $repo/exp \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  --jit 1
+
+ls -lh $repo/exp
+
+
+log "Generating H.fst, HL.fst"
+
+./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500
+ls -lh $repo/data/lang_bpe_500
+
+log "Decoding with H on CPU with OpenFst"
+
+./conformer_ctc/jit_pretrained_decode_with_H.py \
+  --nn-model $repo/exp/cpu_jit.pt \
+  --H $repo/data/lang_bpe_500/H.fst \
+  --tokens $repo/data/lang_bpe_500/tokens.txt \
+  $repo/test_wavs/1089-134686-0001.flac \
+  $repo/test_wavs/1221-135766-0001.flac \
+  $repo/test_wavs/1221-135766-0002.flac
+
+log "Decoding with HL on CPU with OpenFst"
+
+./conformer_ctc/jit_pretrained_decode_with_HL.py \
+  --nn-model $repo/exp/cpu_jit.pt \
+  --HL $repo/data/lang_bpe_500/HL.fst \
+  --words $repo/data/lang_bpe_500/words.txt \
+  $repo/test_wavs/1089-134686-0001.flac \
+  $repo/test_wavs/1221-135766-0001.flac \
+  $repo/test_wavs/1221-135766-0002.flac
diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml
index 6151a5a14..e268d840d 100644
--- a/.github/workflows/run-pretrained-conformer-ctc.yml
+++ b/.github/workflows/run-pretrained-conformer-ctc.yml
@@ -29,7 +29,7 @@ concurrency:
 
 jobs:
   run_pre_trained_conformer_ctc:
-    if: github.event.label.name == 'ready' || github.event_name == 'push'
+    if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'ctc'
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py
new file mode 100755
index 000000000..0309ea873
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_H.py
@@ -0,0 +1,221 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+"""
+This file shows how to use a torchscript model for decoding with H
+on CPU using OpenFST and decoders from kaldi.
+
+Usage:
+
+  ./conformer_ctc/jit_pretrained_decode_with_H.py \
+    --nn-model ./cpu_jit.pt \
+    --H ./data/lang_bpe_500/H.fst \
+    --tokens ./data/lang_bpe_500/tokens.txt \
+    ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
+    ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac
+
+Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
+you can use ./conformer_ctc/export.py --jit 1
+"""
+
+import argparse
+import logging
+import math
+from typing import Dict, List
+
+import kaldifeat
+import kaldifst
+import torch
+import torchaudio
+from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions
+from torch.nn.utils.rnn import pad_sequence
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--nn-model",
+        type=str,
+        required=True,
+        help="""Path to the torchscript model.
+        You can use ./conformer_ctc/export.py --jit 1
+        to obtain it.
+        """,
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=str,
+        required=True,
+        help="Path to tokens.txt",
+    )
+
+    parser.add_argument("--H", type=str, required=True, help="Path to H.fst")
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported.",
+    )
+
+    return parser
+
+
+def read_tokens(tokens_txt: str) -> Dict[int, str]:
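+    """Read tokens.txt and return a dict mapping token ID to token.
+
+    Each line of tokens.txt is expected to hold a token and its integer
+    ID separated by whitespace, e.g., "<blk> 0" or "▁THE 4".
+    """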
+    id2token = dict()
+    with open(tokens_txt, encoding="utf-8") as f:
+        for line in f:
+            token, idx = line.strip().split()
+            id2token[int(idx)] = token
+
+    return id2token
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        if sample_rate != expected_sample_rate:
+            wave = torchaudio.functional.resample(
+                wave,
+                orig_freq=sample_rate,
+                new_freq=expected_sample_rate,
+            )
+
+        # We use only the first channel
+        ans.append(wave[0].contiguous())
+    return ans
+
+
+def decode(
+    filename: str,
+    nnet_output: torch.Tensor,
+    H: kaldifst.StdVectorFst,
+    id2token: Dict[int, str],
+) -> List[str]:
+    logging.info(f"{filename}, {nnet_output.shape}")
+    decodable = DecodableCtc(nnet_output)
+
+    decoder_opts = FasterDecoderOptions(max_active=3000)
+    decoder = FasterDecoder(H, decoder_opts)
+    decoder.decode(decodable)
+
+    if not decoder.reached_final():
+        logging.warning(f"failed to decode {filename}")
+        return []
+
+    ok, best_path = decoder.get_best_path()
+
+    (
+        ok,
+        isymbols_out,
+        osymbols_out,
+        total_weight,
+    ) = kaldifst.get_linear_symbol_sequence(best_path)
+    if not ok:
+        logging.warning(f"failed to get linear symbol sequence for {filename}")
+        return []
+
+    # Token IDs were incremented by 1 during graph construction,
+    # so decrement them here before looking up id2token.
+    hyps = [id2token[i - 1] for i in osymbols_out]
+    hyps = "".join(hyps).split("\u2581")  # unicode codepoint of ▁
+
+    return hyps
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+
+    device = torch.device("cpu")
+
+    logging.info(f"device: {device}")
+
+    logging.info("Loading torchscript model")
+    model = torch.jit.load(args.nn_model)
+    model.eval()
+    model.to(device)
+
+    logging.info(f"Loading H from {args.H}")
+    H = kaldifst.StdVectorFst.read(args.H)
+
+    sample_rate = 16000
+
+    logging.info("Constructing Fbank computer")
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = sample_rate
+    opts.mel_opts.num_bins = 80
+
+    fbank = kaldifeat.Fbank(opts)
+
+    logging.info(f"Reading sound files: {args.sound_files}")
+    waves = read_sound_files(
+        filenames=args.sound_files, expected_sample_rate=sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+    feature_lengths = [f.shape[0] for f in features]
+    feature_lengths = torch.tensor(feature_lengths)
+
+    supervisions = dict()
+    supervisions["sequence_idx"] = torch.arange(len(features))
+    supervisions["start_frame"] = torch.zeros(len(features))
+    supervisions["num_frames"] = feature_lengths
+
+    features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
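+    # Note: math.log(1e-10) is the log of a tiny value, so the padded
+    # frames resemble (near) silence in the log-mel domain rather than
+    # the non-silent frames that zero padding would suggest.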
+
+    nnet_output, _, _ = model(features, supervisions)
+    # The conformer encoder subsamples the features by a factor of 4
+    feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2
+
+    id2token = read_tokens(args.tokens)
+
+    hyps = []
+    for i in range(nnet_output.shape[0]):
+        hyp = decode(
+            filename=args.sound_files[i],
+            nnet_output=nnet_output[i, : feature_lengths[i]],
+            H=H,
+            id2token=id2token,
+        )
+        hyps.append(hyp)
+
+    s = "\n"
+    for filename, hyp in zip(args.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py
new file mode 100755
index 000000000..e018feac1
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py
@@ -0,0 +1,218 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
+
+"""
+This file shows how to use a torchscript model for decoding with HL
+on CPU using OpenFST and decoders from kaldi.
+
+Usage:
+
+  ./conformer_ctc/jit_pretrained_decode_with_HL.py \
+    --nn-model ./cpu_jit.pt \
+    --HL ./data/lang_bpe_500/HL.fst \
+    --words ./data/lang_bpe_500/words.txt \
+    ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \
+    ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac
+
+Note that to generate ./conformer_ctc/exp/cpu_jit.pt,
+you can use ./conformer_ctc/export.py --jit 1
+"""
+
+import argparse
+import logging
+import math
+from typing import Dict, List
+
+import kaldifeat
+import kaldifst
+import torch
+import torchaudio
+from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions
+from torch.nn.utils.rnn import pad_sequence
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--nn-model",
+        type=str,
+        required=True,
+        help="""Path to the torchscript model.
+        You can use ./conformer_ctc/export.py --jit 1
+        to obtain it.
+        """,
+    )
+
+    parser.add_argument(
+        "--words",
+        type=str,
+        required=True,
+        help="Path to words.txt",
+    )
+
+    parser.add_argument("--HL", type=str, required=True, help="Path to HL.fst")
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported.",
+    )
+
+    return parser
+
+
+def read_words(words_txt: str) -> Dict[int, str]:
+    id2word = dict()
+    with open(words_txt, encoding="utf-8") as f:
+        for line in f:
+            word, idx = line.strip().split()
+            id2word[int(idx)] = word
+
+    return id2word
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        if sample_rate != expected_sample_rate:
+            wave = torchaudio.functional.resample(
+                wave,
+                orig_freq=sample_rate,
+                new_freq=expected_sample_rate,
+            )
+
+        # We use only the first channel
+        ans.append(wave[0].contiguous())
+    return ans
+
+
+def decode(
+    filename: str,
+    nnet_output: torch.Tensor,
+    HL: kaldifst.StdVectorFst,
+    id2word: Dict[int, str],
+) -> List[str]:
+    logging.info(f"{filename}, {nnet_output.shape}")
+    decodable = DecodableCtc(nnet_output)
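+    # DecodableCtc exposes the network output to the Kaldi decoder.
+    # Like the removed Python CtcDecodable it stands in for, it is
+    # expected to score label i on a frame as nnet_output[frame][i - 1],
+    # matching the +1 label shift applied when H and HL were built.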
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + HL: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HL, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return "" + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return "" + + # are shifted by 1 during graph construction + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HL from {args.HL}") + HL = kaldifst.StdVectorFst.read(args.HL) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + nnet_output, _, _ = model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 + + id2word = read_words(args.words) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], + HL=HL, + id2word=id2word, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/local/prepare_lang_fst.py b/egs/librispeech/ASR/local/prepare_lang_fst.py new file mode 100755 index 000000000..e8401123f --- /dev/null +++ b/egs/librispeech/ASR/local/prepare_lang_fst.py @@ -0,0 +1,127 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang) + +""" +This script takes as input lang_dir containing lexicon_disambig.txt, +tokens.txt, and 
+    H.write(f"{lang_dir}/H.fst")
+
+    logging.info("Building L")
+    # Now for HL
+
+    if args.has_silence:
+        L = make_lexicon_fst_with_silence(lexicon, attach_symbol_table=False)
+    else:
+        L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False)
+
+    # We also need to change the input labels of L
+    if args.has_silence:
+        add_one(L, treat_ilabel_zero_specially=True, update_olabel=False)
+    else:
+        add_one(L, treat_ilabel_zero_specially=False, update_olabel=False)
+
+    # Invoke add_disambig_self_loops() so that it eats the disambig symbols
+    # from L after composition
+    add_disambig_self_loops(
+        H,
+        start=lexicon.token2id["#0"] + 1,
+        end=lexicon.max_disambig_id + 1,
+    )
+
+    kaldifst.arcsort(H, sort_type="olabel")
+    kaldifst.arcsort(L, sort_type="ilabel")
+
+    logging.info("Building HL")
+    HL = kaldifst.compose(H, L)
+    kaldifst.determinize_star(HL)
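+    # Kaldi's determinize-star removes the epsilons introduced by the
+    # composition and determinizes the result in a single pass.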
+
+    # Remove the disambig symbols from HL's input side now that they
+    # have served their purpose; label 0 means epsilon here.
+    disambig0 = lexicon.token2id["#0"] + 1
+    max_disambig = lexicon.max_disambig_id + 1
+    for state in kaldifst.StateIterator(HL):
+        for arc in kaldifst.ArcIterator(HL, state):
+            if disambig0 <= arc.ilabel <= max_disambig:
+                arc.ilabel = 0
+
+    # Note: We are not composing L with G, so there is no need to add
+    # self-loops to L to handle #0
+
+    HL.write(f"{lang_dir}/HL.fst")
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+    main()
diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh
index 8ce1eb478..fca2c6cc4 100755
--- a/egs/librispeech/ASR/prepare.sh
+++ b/egs/librispeech/ASR/prepare.sh
@@ -242,6 +242,10 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
         $lang_dir/L_disambig.pt \
         $lang_dir/L_disambig.fst
     fi
+
+    if [ ! -f $lang_dir/HL.fst ]; then
+      ./local/prepare_lang_fst.py --lang-dir $lang_dir
+    fi
   done
 fi
diff --git a/egs/yesno/ASR/local/prepare_lang_fst.py b/egs/yesno/ASR/local/prepare_lang_fst.py
deleted file mode 100755
index e1c35e842..000000000
--- a/egs/yesno/ASR/local/prepare_lang_fst.py
+++ /dev/null
@@ -1,67 +0,0 @@
-#!/usr/bin/env python3
-
-# Copyright (c) 2023 Xiaomi Corporation (authors: Fangjun Kuang)
-
-"""
-This script takes as input data/lang_phone containing lexicon_disambig.txt,
-tokens.txt, and words.txt and generates the following files:
-
- - H.fst
- - HL.fst
-
-TODO(fangjun): Generate HLG.fst
-
-Note that saved files are in OpenFst binary format.
-"""
-
-from pathlib import Path
-
-import kaldifst
-
-from icefall.ctc import (
-    Lexicon,
-    add_disambig_self_loops,
-    add_one,
-    build_standard_ctc_topo,
-    make_lexicon_fst_with_silence,
-)
-
-
-def main():
-    lang_dir = Path("data/lang_phone")
-    lexicon = Lexicon(lang_dir)
-
-    max_token_id = max(lexicon.tokens)
-    H = build_standard_ctc_topo(max_token_id=max_token_id)
-
-    # We need to add one to all tokens since we want to use ID 0
-    # for epsilon
-    add_one(H, treat_ilabel_zero_specially=False, update_olabel=True)
-    H.write(f"{lang_dir}/H.fst")
-
-    # Now for HL
-    L = make_lexicon_fst_with_silence(lexicon, attach_symbol_table=False)
-
-    # We also need to change the input labels of L
-    add_one(L, treat_ilabel_zero_specially=True, update_olabel=False)
-
-    # Invoke add_disambig_self_loops() so that it eats the disambig symbols
-    # from L after composition
-    add_disambig_self_loops(
-        H,
-        start=lexicon.token2id["#0"] + 1,
-        end=lexicon.max_disambig_id,
-    )
-
-    kaldifst.arcsort(H, sort_type="olabel")
-    kaldifst.arcsort(L, sort_type="ilabel")
-    HL = kaldifst.compose(H, L)
-
-    # Note: We are not composing L with G, so there is no need to add
-    # self-loops to L to handle #0
-
-    HL.write(f"{lang_dir}/HL.fst")
-
-
-if __name__ == "__main__":
-    main()
diff --git a/egs/yesno/ASR/local/prepare_lang_fst.py b/egs/yesno/ASR/local/prepare_lang_fst.py
new file mode 120000
index 000000000..c5787c534
--- /dev/null
+++ b/egs/yesno/ASR/local/prepare_lang_fst.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/local/prepare_lang_fst.py
\ No newline at end of file
diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh
index e2b6b5daa..cfb7515f5 100755
--- a/egs/yesno/ASR/prepare.sh
+++ b/egs/yesno/ASR/prepare.sh
@@ -60,7 +60,7 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   ) > $lang_dir/lexicon.txt
 
   ./local/prepare_lang.py
-  ./local/prepare_lang_fst.py
+  ./local/prepare_lang_fst.py --lang-dir ./data/lang_phone --has-silence 1
 fi
diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py
index eb7d4da7d..d1b6fe748 100755
--- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py
+++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_H.py
@@ -28,11 +28,9 @@ import kaldifeat
 import kaldifst
 import torch
 import torchaudio
-from kaldi_hmm_gmm import FasterDecoder, FasterDecoderOptions
+from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions
 from torch.nn.utils.rnn import pad_sequence
 
-from icefall.ctc import CtcDecodable
-
 
 def get_parser():
     parser = argparse.ArgumentParser(
@@ -113,8 +111,8 @@ def decode(
     H: kaldifst,
     id2token: Dict[int, str],
 ) -> List[str]:
-    decodable = CtcDecodable(nnet_output)
-    decoder_opts = FasterDecoderOptions()
+    decodable = DecodableCtc(nnet_output)
+    decoder_opts = FasterDecoderOptions(max_active=3000)
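+    # max_active caps the number of active states per frame during the
+    # Viterbi search (beam pruning); larger values are slower but less
+    # likely to prune away the best path.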
decoder = FasterDecoder(H, decoder_opts) decoder.decode(decodable) diff --git a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py index f31a918ef..bf59ff762 100755 --- a/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py +++ b/egs/yesno/ASR/tdnn/jit_pretrained_decode_with_HL.py @@ -28,11 +28,9 @@ import kaldifeat import kaldifst import torch import torchaudio -from kaldi_hmm_gmm import FasterDecoder, FasterDecoderOptions +from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions from torch.nn.utils.rnn import pad_sequence -from icefall.ctc import CtcDecodable - def get_parser(): parser = argparse.ArgumentParser( @@ -113,8 +111,8 @@ def decode( HL: kaldifst, id2word: Dict[int, str], ) -> List[str]: - decodable = CtcDecodable(nnet_output) - decoder_opts = FasterDecoderOptions() + decodable = DecodableCtc(nnet_output) + decoder_opts = FasterDecoderOptions(max_active=3000) decoder = FasterDecoder(HL, decoder_opts) decoder.decode(decodable) diff --git a/icefall/ctc/__init__.py b/icefall/ctc/__init__.py index db26f912e..b546b31af 100644 --- a/icefall/ctc/__init__.py +++ b/icefall/ctc/__init__.py @@ -1,4 +1,3 @@ -from .decodable import CtcDecodable from .prepare_lang import ( Lexicon, make_lexicon_fst_no_silence, diff --git a/icefall/ctc/decodable.py b/icefall/ctc/decodable.py deleted file mode 100644 index 7d7caa5a5..000000000 --- a/icefall/ctc/decodable.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright 2023 Xiaomi Corp. (author: Fangjun Kuang) - -import torch -from kaldi_hmm_gmm import DecodableInterface - - -class CtcDecodable(DecodableInterface): - """This class implements the interface - https://github.com/kaldi-asr/kaldi/blob/master/src/itf/decodable-itf.h - """ - - def __init__(self, nnet_output: torch.Tensor): - DecodableInterface.__init__(self) - assert nnet_output.ndim == 2, nnet_output.shape - self.nnet_output = nnet_output - - def log_likelihood(self, frame: int, index: int) -> float: - # Note: We need to use index - 1 here since - # all the input labels of the H are incremented during graph - # construction - return self.nnet_output[frame][index - 1].item() - - def is_last_frame(self, frame: int) -> bool: - return frame == self.nnet_output.shape[0] - 1 - - def num_frames_ready(self) -> int: - return self.nnet_output.shape[0] - - def num_indices(self) -> int: - return self.nnet_output.shape[1] diff --git a/icefall/ctc/test_ctc_topo.py b/icefall/ctc/test_ctc_topo.py index da3b22b18..4d4667209 100755 --- a/icefall/ctc/test_ctc_topo.py +++ b/icefall/ctc/test_ctc_topo.py @@ -5,7 +5,12 @@ from pathlib import Path import graphviz import kaldifst -from prepare_lang import Lexicon, make_lexicon_fst_with_silence +import sentencepiece as spm +from prepare_lang import ( + Lexicon, + make_lexicon_fst_no_silence, + make_lexicon_fst_with_silence, +) from topo import add_disambig_self_loops, add_one, build_standard_ctc_topo @@ -85,8 +90,50 @@ def test_yesno(): source.render(outfile="HL_yesno.pdf") +def test_librispeech(): + lang_dir = ( + "/star-fj/fangjun/open-source/icefall-2/egs/librispeech/ASR/data/lang_bpe_500" + ) + + if not Path(lang_dir).is_dir(): + print(f"{lang_dir} does not exist! 
Skip testing") + return + + lexicon = Lexicon(lang_dir) + HL = kaldifst.StdVectorFst.read(lang_dir + "/HL.fst") + + sp = spm.SentencePieceProcessor() + sp.load(lang_dir + "/bpe.model") + + i = lexicon.word2id["HELLOA"] + k = lexicon.word2id["WORLD"] + print(i, k) + s = f""" + 0 1 {i} {i} + 1 2 {k} {k} + 2 + """ + fst = kaldifst.compile( + s=s, + acceptor=False, + ) + + L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False) + kaldifst.arcsort(L, sort_type="olabel") + with open("L.fst.txt", "w") as f: + print(L, file=f) + + fst = kaldifst.compose(L, fst) + print(fst) + fst_dot = kaldifst.draw(fst, acceptor=False, portrait=True) + source = graphviz.Source(fst_dot) + source.render(outfile="a.pdf") + print(sp.encode(["HELLOA", "WORLD"])) + + def main(): test_yesno() + test_librispeech() if __name__ == "__main__": diff --git a/icefall/ctc/topo.py b/icefall/ctc/topo.py index 9b1839f69..6a96dd038 100644 --- a/icefall/ctc/topo.py +++ b/icefall/ctc/topo.py @@ -107,9 +107,8 @@ def add_one( def add_disambig_self_loops(fst: kaldifst.StdVectorFst, start: int, end: int): """Add self-loops to each state. - For each disambig symbol, we add a self-loop with input label 0 and output - label diambig_id of that disambig symbol. Note that input label 0 here - represents an epsilon. + For each disambig symbol, we add a self-loop with input label disambig_id + and output label diambig_id of that disambig symbol. Args: fst: @@ -119,14 +118,14 @@ def add_disambig_self_loops(fst: kaldifst.StdVectorFst, start: int, end: int): end: The ID of the last disambig symbol. For instance if there are 3 disambig symbols ``#0``, ``#1``, and ``#2``, then ``end`` is the ID - of ``#3``. + of ``#2``. """ for state in kaldifst.StateIterator(fst): for i in range(start, end + 1): fst.add_arc( state=state, arc=kaldifst.StdArc( - ilabel=0, + ilabel=i, olabel=i, weight=0, nextstate=state,