From a73d3ed91796b3c3f47b87859e1bb55b14577138 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 18 Aug 2021 21:20:42 +0800
Subject: [PATCH] Support decoding multiple files at the same time.

Also, use kaldifeat for feature extraction.
---
 egs/librispeech/ASR/conformer_ctc/README.md  | 19 ++++-
 .../ASR/conformer_ctc/pretrained.py          | 81 +++++++++++++++----
 2 files changed, 81 insertions(+), 19 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md
index 8f43dac34..f2fd18cd4 100644
--- a/egs/librispeech/ASR/conformer_ctc/README.md
+++ b/egs/librispeech/ASR/conformer_ctc/README.md
@@ -10,15 +10,30 @@ You need to prepare 4 files:
 
 Supported formats are those supported by `torchaudio.load()`, e.g., wav and flac.
 
+Also, you need to install `kaldifeat`. Please refer to
+<https://github.com/csukuangfj/kaldifeat> for installation.
 
-Once you have the above files ready, you can run:
+Once you have the above files ready and have `kaldifeat` installed,
+you can run:
 
 ```
 ./conformer_ctc/pretrained.py \
   --checkpoint /path/to/your/checkpoint.pt \
   --words-file /path/to/words.txt \
   --hlg /path/to/HLG.pt \
-  --sound-file /path/to/your/sound.wav
+  /path/to/your/sound.wav
 ```
 
 and you will see the transcribed result.
+
+If you want to transcribe multiple files at the same time, you can use:
+
+```
+./conformer_ctc/pretrained.py \
+  --checkpoint /path/to/your/checkpoint.pt \
+  --words-file /path/to/words.txt \
+  --hlg /path/to/HLG.pt \
+  /path/to/your/sound1.wav \
+  /path/to/your/sound2.wav \
+  /path/to/your/sound3.wav
+```
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index c02c2cca7..27d9ccc4c 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -2,11 +2,15 @@
 
 import argparse
 import logging
+import math
+from typing import List
 
 import k2
+import kaldifeat
 import torch
 import torchaudio
 from conformer import Conformer
+from torch.nn.utils.rnn import pad_sequence
 
 from icefall.decode import get_lattice, one_best_decoding
 from icefall.utils import AttributeDict, get_texts
@@ -38,10 +42,10 @@ def get_parser():
     )
     parser.add_argument(
-        "--sound-file",
+        "sound_files",
         type=str,
-        required=True,
-        help="The input sound file to transcribe. "
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
         "Supported formats are those that supported by torchaudio.load(). "
         "For example, wav, flac are supported. "
        "The sample rate has to be 16kHz.",
     )
@@ -56,7 +60,7 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "nhead": 8,
             "num_classes": 5000,
-            "sample_freq": 16000,
+            "sample_rate": 16000,
             "attention_dim": 512,
             "subsampling_factor": 4,
             "num_decoder_layers": 6,
@@ -74,6 +78,30 @@
     return params
 
 
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
 def main():
     parser = get_parser()
     args = parser.parse_args()
@@ -87,6 +115,7 @@ def main():
 
     logging.info(f"device: {device}")
 
+    logging.info("Create model")
     model = Conformer(
         num_features=params.feature_dim,
         nhead=params.nhead,
@@ -103,28 +132,39 @@
     checkpoint = torch.load(args.checkpoint, map_location="cpu")
     model.load_state_dict(checkpoint["model"])
     model.to(device)
+    model.eval()
 
     HLG = k2.Fsa.from_dict(torch.load(params.hlg))
     HLG = HLG.to(device)
-    model.to(device)
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
 
-    wave, sample_freq = torchaudio.load(params.sound_file)
-    assert sample_freq == params.sample_freq
-    wave = wave.to(device)
+    fbank = kaldifeat.Fbank(opts)
 
-    features = torchaudio.compliance.kaldi.fbank(
-        waveform=wave,
-        num_mel_bins=params.feature_dim,
-        snip_edges=False,
-        sample_frequency=params.sample_freq,
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
     )
 
-    features = features.unsqueeze(0)
+    with torch.no_grad():
+        nnet_output, _, _ = model(features)
 
-    nnet_output, _, _ = model(features)
+    batch_size = nnet_output.shape[0]
 
     supervision_segments = torch.tensor(
-        [[0, 0, nnet_output.shape[1]]], dtype=torch.int32
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
    )
 
     lattice = get_lattice(
@@ -145,7 +185,14 @@
     hyps = get_texts(best_path)
     word_sym_table = k2.SymbolTable.from_file(params.words_file)
     hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
-    logging.info(hyps)
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
 
 
 if __name__ == "__main__":