diff --git a/.github/scripts/run-pre-trained-conformer-ctc.sh b/.github/scripts/run-pre-trained-conformer-ctc.sh index 19cbd96fc..a82d85fb2 100755 --- a/.github/scripts/run-pre-trained-conformer-ctc.sh +++ b/.github/scripts/run-pre-trained-conformer-ctc.sh @@ -10,16 +10,30 @@ log() { cd egs/librispeech/ASR -repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 -git lfs install - +# repo_url=https://github.com/csukuangfj/icefall-asr-conformer-ctc-bpe-500 +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09 log "Downloading pre-trained model from $repo_url" -git clone $repo_url +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url repo=$(basename $repo_url) +pushd $repo + +git lfs pull --include "exp/pretrained.pt" +git lfs pull --include "data/lang_bpe_500/HLG.pt" +git lfs pull --include "data/lang_bpe_500/L.pt" +git lfs pull --include "data/lang_bpe_500/L_disambig.pt" +git lfs pull --include "data/lang_bpe_500/Linv.pt" +git lfs pull --include "data/lang_bpe_500/bpe.model" +git lfs pull --include "data/lang_bpe_500/lexicon.txt" +git lfs pull --include "data/lang_bpe_500/lexicon_disambig.txt" +git lfs pull --include "data/lang_bpe_500/tokens.txt" +git lfs pull --include "data/lang_bpe_500/words.txt" +git lfs pull --include "data/lm/G_3_gram.fst.txt" + +popd log "Display test files" tree $repo/ -ls -lh $repo/test_wavs/*.flac +ls -lh $repo/test_wavs/*.wav log "CTC decoding" @@ -28,9 +42,9 @@ log "CTC decoding" --num-classes 500 \ --checkpoint $repo/exp/pretrained.pt \ --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.flac \ - $repo/test_wavs/1221-135766-0001.flac \ - $repo/test_wavs/1221-135766-0002.flac + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav log "HLG decoding" @@ -41,9 +55,9 @@ log "HLG decoding" --tokens $repo/data/lang_bpe_500/tokens.txt \ --words-file $repo/data/lang_bpe_500/words.txt \ --HLG 
$repo/data/lang_bpe_500/HLG.pt \ - $repo/test_wavs/1089-134686-0001.flac \ - $repo/test_wavs/1221-135766-0001.flac \ - $repo/test_wavs/1221-135766-0002.flac + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav log "CTC decoding on CPU with kaldi decoders using OpenFst" @@ -65,7 +79,8 @@ ls -lh $repo/exp log "Generating H.fst, HL.fst" -./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 +./local/prepare_lang_fst.py --lang-dir $repo/data/lang_bpe_500 --ngram-G $repo/data/lm/G_3_gram.fst.txt + ls -lh $repo/data/lang_bpe_500 log "Decoding with H on CPU with OpenFst" @@ -74,9 +89,9 @@ log "Decoding with H on CPU with OpenFst" --nn-model $repo/exp/cpu_jit.pt \ --H $repo/data/lang_bpe_500/H.fst \ --tokens $repo/data/lang_bpe_500/tokens.txt \ - $repo/test_wavs/1089-134686-0001.flac \ - $repo/test_wavs/1221-135766-0001.flac \ - $repo/test_wavs/1221-135766-0002.flac + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav log "Decoding with HL on CPU with OpenFst" @@ -84,6 +99,16 @@ log "Decoding with HL on CPU with OpenFst" --nn-model $repo/exp/cpu_jit.pt \ --HL $repo/data/lang_bpe_500/HL.fst \ --words $repo/data/lang_bpe_500/words.txt \ - $repo/test_wavs/1089-134686-0001.flac \ - $repo/test_wavs/1221-135766-0001.flac \ - $repo/test_wavs/1221-135766-0002.flac + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Decoding with HLG on CPU with OpenFst" + +./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model $repo/exp/cpu_jit.pt \ + --HLG $repo/data/lang_bpe_500/HLG.fst \ + --words $repo/data/lang_bpe_500/words.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml 
index e268d840d..54845159d 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-conformer-ctc.yml @@ -23,13 +23,20 @@ on: pull_request: types: [labeled] + workflow_dispatch: + inputs: + test-run: + description: 'Test (y/n)?' + required: true + default: 'y' + concurrency: group: run_pre_trained_conformer_ctc-${{ github.ref }} cancel-in-progress: true jobs: run_pre_trained_conformer_ctc: - if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'ctc' + if: github.event.label.name == 'ready' || github.event_name == 'push' || github.event.inputs.test-run == 'y' runs-on: ${{ matrix.os }} strategy: matrix: diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py index f0326ccdf..3420c4da3 100755 --- a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HL.py @@ -2,12 +2,12 @@ # Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang) """ -This file shows how to use a torchscript model for decoding with H +This file shows how to use a torchscript model for decoding with HL on CPU using OpenFST and decoders from kaldi. Usage: - ./conformer_ctc/jit_pretrained_decode_with_H.py \ + ./conformer_ctc/jit_pretrained_decode_with_HL.py \ --nn-model ./conformer_ctc/exp/cpu_jit.pt \ --HL ./data/lang_bpe_500/HL.fst \ --words ./data/lang_bpe_500/words.txt \ diff --git a/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py new file mode 100755 index 000000000..42129f073 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/jit_pretrained_decode_with_HLG.py @@ -0,0 +1,232 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. 
(authors: Fangjun Kuang) + +""" +This file shows how to use a torchscript model for decoding with HLG +on CPU using OpenFST and decoders from kaldi. + +Usage: + + ./conformer_ctc/jit_pretrained_decode_with_HLG.py \ + --nn-model ./conformer_ctc/exp/cpu_jit.pt \ + --HLG ./data/lang_bpe_500/HLG.fst \ + --words ./data/lang_bpe_500/words.txt \ + ./download/LibriSpeech/test-clean/1089/134686/1089-134686-0002.flac \ + ./download/LibriSpeech/test-clean/1221/135766/1221-135766-0001.flac + +Note that to generate ./conformer_ctc/exp/cpu_jit.pt, +you can use ./conformer_ctc/export.py --jit 1 +""" + +import argparse +import logging +import math +from typing import Dict, List + +import kaldi_hmm_gmm +import kaldifeat +import kaldifst +import torch +import torchaudio +from kaldi_hmm_gmm import DecodableCtc, FasterDecoder, FasterDecoderOptions +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model", + type=str, + required=True, + help="""Path to the torchscript model. + You can use ./conformer_ctc/export.py --jit 1 + to obtain it + """, + ) + + parser.add_argument( + "--words", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.fst") + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
", + ) + + return parser + + +def read_words(words_txt: str) -> Dict[int, str]: + id2word = dict() + with open(words_txt, encoding="utf-8") as f: + for line in f: + word, idx = line.strip().split() + id2word[int(idx)] = word + + return id2word + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list of 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + if sample_rate != expected_sample_rate: + wave = torchaudio.functional.resample( + wave, + orig_freq=sample_rate, + new_freq=expected_sample_rate, + ) + + # We use only the first channel + ans.append(wave[0].contiguous()) + return ans + + +def decode( + filename: str, + nnet_output: torch.Tensor, + HLG: kaldifst, + id2word: Dict[int, str], +) -> List[str]: + """ + Args: + filename: + Path to the filename for decoding. Used for debugging. + nnet_output: + A 2-D float32 tensor of shape (num_frames, vocab_size). It + contains output from log_softmax. + HLG: + The HLG graph (main() passes a kaldifst.StdVectorFst). + id2word: + A map mapping word ID to word string. + Returns: + Return a list of decoded words. 
+ """ + logging.info(f"{filename}, {nnet_output.shape}") + decodable = DecodableCtc(nnet_output.cpu()) + + decoder_opts = FasterDecoderOptions(max_active=3000) + decoder = FasterDecoder(HLG, decoder_opts) + decoder.decode(decodable) + + if not decoder.reached_final(): + print(f"failed to decode {filename}") + return [""] + + ok, best_path = decoder.get_best_path() + + ( + ok, + isymbols_out, + osymbols_out, + total_weight, + ) = kaldifst.get_linear_symbol_sequence(best_path) + if not ok: + print(f"failed to get linear symbol sequence for {filename}") + return [""] + + # osymbols_out holds word IDs, looked up in id2word without any shift + hyps = [id2word[i] for i in osymbols_out] + + return hyps + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + logging.info("Loading torchscript model") + model = torch.jit.load(args.nn_model) + model.eval() + model.to(device) + + logging.info(f"Loading HLG from {args.HLG}") + HLG = kaldifst.StdVectorFst.read(args.HLG) + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, expected_sample_rate=sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.shape[0] for f in features] + feature_lengths = torch.tensor(feature_lengths) + + supervisions = dict() + supervisions["sequence_idx"] = torch.arange(len(features)) + supervisions["start_frame"] = torch.zeros(len(features)) + supervisions["num_frames"] = feature_lengths + + features = pad_sequence(features, batch_first=True, 
padding_value=math.log(1e-10)) + + 
nnet_output, _, _ = model(features, supervisions) + feature_lengths = ((feature_lengths - 1) // 2 - 1) // 2 + + id2word = read_words(args.words) + + hyps = [] + for i in range(nnet_output.shape[0]): + hyp = decode( + filename=args.sound_files[i], + nnet_output=nnet_output[i, : feature_lengths[i]], + HLG=HLG, + id2word=id2word, + ) + hyps.append(hyp) + + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/local/prepare_lang_fst.py b/egs/librispeech/ASR/local/prepare_lang_fst.py index e8401123f..fb1e7f9c0 100755 --- a/egs/librispeech/ASR/local/prepare_lang_fst.py +++ b/egs/librispeech/ASR/local/prepare_lang_fst.py @@ -8,6 +8,7 @@ tokens.txt, and words.txt and generates the following files: - H.fst - HL.fst + - HLG.fst Note that saved files are in OpenFst binary format. @@ -56,9 +57,114 @@ def get_args(): help="True if the lexicon has silence.", ) + parser.add_argument( + "--ngram-G", + type=str, + help="""If not empty, it is the filename of G used to build HLG. 
+ For instance, --ngram-G=./data/lm/G_3_fst.txt + """, + ) + return parser.parse_args() +def build_HL( + H: kaldifst.StdVectorFst, + L: kaldifst.StdVectorFst, + has_silence: bool, + lexicon: Lexicon, +) -> kaldifst.StdVectorFst: + if has_silence: + # We also need to change the input labels of L + add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) + else: + add_one(L, treat_ilabel_zero_specially=False, update_olabel=False) + + # Invoke add_disambig_self_loops() so that it eats the disambig symbols + # from L after composition + add_disambig_self_loops( + H, + start=lexicon.token2id["#0"] + 1, + end=lexicon.max_disambig_id + 1, + ) + + kaldifst.arcsort(H, sort_type="olabel") + kaldifst.arcsort(L, sort_type="ilabel") + + HL = kaldifst.compose(H, L) + kaldifst.determinize_star(HL) + + disambig0 = lexicon.token2id["#0"] + 1 + max_disambig = lexicon.max_disambig_id + 1 + for state in kaldifst.StateIterator(HL): + for arc in kaldifst.ArcIterator(HL, state): + # Reset the input label of every arc whose ilabel lies in + # [disambig0, max_disambig] (the disambiguation symbols) to 0 + if disambig0 <= arc.ilabel <= max_disambig: + arc.ilabel = 0 + + # Note: We are not composing L with G, so there is no need to add + # self-loops to L to handle #0 + + return HL + + +def build_HLG( + H: kaldifst.StdVectorFst, + L: kaldifst.StdVectorFst, + G: kaldifst.StdVectorFst, + has_silence: bool, + lexicon: Lexicon, +) -> kaldifst.StdVectorFst: + if has_silence: + # We also need to change the input labels of L + add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) + else: + add_one(L, treat_ilabel_zero_specially=False, update_olabel=False) + + # add-self-loops + token_disambig0 = lexicon.token2id["#0"] + 1 + word_disambig0 = lexicon.word2id["#0"] + + kaldifst.add_self_loops(L, isyms=[token_disambig0], osyms=[word_disambig0]) + + kaldifst.arcsort(L, sort_type="olabel") + kaldifst.arcsort(G, sort_type="ilabel") + LG = kaldifst.compose(L, G) + 
kaldifst.determinize_star(LG) 
+ kaldifst.minimize_encoded(LG) + + kaldifst.arcsort(LG, sort_type="ilabel") + + # Invoke add_disambig_self_loops() so that it eats the disambig symbols + # from LG after composition + add_disambig_self_loops( + H, + start=lexicon.token2id["#0"] + 1, + end=lexicon.max_disambig_id + 1, + ) + + kaldifst.arcsort(H, sort_type="olabel") + + HLG = kaldifst.compose(H, LG) + kaldifst.determinize_star(HLG) + + disambig0 = lexicon.token2id["#0"] + 1 + max_disambig = lexicon.max_disambig_id + 1 + for state in kaldifst.StateIterator(HLG): + for arc in kaldifst.ArcIterator(HLG, state): + # Reset the input label of every arc whose ilabel lies in + # [disambig0, max_disambig] (the disambiguation symbols) to 0 + if disambig0 <= arc.ilabel <= max_disambig: + arc.ilabel = 0 + return HLG + + +def copy_fst(fst): + # Please don't use fst.copy() + return kaldifst.StdVectorFst(fst) + + def main(): args = get_args() lang_dir = args.lang_dir @@ -82,43 +188,29 @@ def main(): else: L = make_lexicon_fst_no_silence(lexicon, attach_symbol_table=False) - if args.has_silence: - # We also need to change the input labels of L - add_one(L, treat_ilabel_zero_specially=True, update_olabel=False) - else: - add_one(L, treat_ilabel_zero_specially=False, update_olabel=False) - - # Invoke add_disambig_self_loops() so that it eats the disambig symbols - # from L after composition - add_disambig_self_loops( - H, - start=lexicon.token2id["#0"] + 1, - end=lexicon.max_disambig_id + 1, - ) - with open("H_1.fst.txt", "w") as f: - print(H, file=f) - - kaldifst.arcsort(H, sort_type="olabel") - kaldifst.arcsort(L, sort_type="ilabel") - logging.info("Building HL") - HL = kaldifst.compose(H, L) - kaldifst.determinize_star(HL) - - disambig0 = lexicon.token2id["#0"] + 1 - max_disambig = lexicon.max_disambig_id + 1 - for state in kaldifst.StateIterator(HL): - for arc in kaldifst.ArcIterator(HL, state): - # If treat_ilabel_zero_specially is False, we always change it - # Otherwise, we only change non-zero input labels - if 
disambig0 <= arc.ilabel <= max_disambig: - arc.ilabel = 0 - - # Note: We are not composing L with G, so there is no need to add - # self-loops to L to handle #0 - + HL = build_HL( + H=copy_fst(H), + L=copy_fst(L), + has_silence=args.has_silence, + lexicon=lexicon, + ) HL.write(f"{lang_dir}/HL.fst") + if not args.ngram_G: + logging.info("Skip building HLG") + return + + logging.info("Building HLG") + with open(args.ngram_G) as f: + G = kaldifst.compile( + s=f.read(), + acceptor=False, + ) + + HLG = build_HLG(H=H, L=L, G=G, has_silence=args.has_silence, lexicon=lexicon) + HLG.write(f"{lang_dir}/HLG.fst") + if __name__ == "__main__": formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index fca2c6cc4..93d010ea8 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -244,7 +244,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then fi if [ ! -f $lang_dir/HL.fst ]; then - ./local/prepare_lang_fst.py --lang-dir $lang_dir + ./local/prepare_lang_fst.py --lang-dir $lang_dir --ngram-G ./data/lm/G_3_gram.fst.txt fi done fi