Support decoding multiple files at the same time.

Also, use kaldifeat for feature extraction.
Fangjun Kuang 2021-08-18 21:20:42 +08:00
parent f731996abe
commit a73d3ed917
2 changed files with 81 additions and 19 deletions

View File

@@ -10,15 +10,30 @@ You need to prepare 4 files:
 Supported formats are those supported by `torchaudio.load()`,
 e.g., wav and flac.
 
-Once you have the above files ready, you can run:
+Also, you need to install `kaldifeat`. Please refer to
+<https://github.com/csukuangfj/kaldifeat> for installation.
+
+Once you have the above files ready and have `kaldifeat` installed,
+you can run:
 
 ```
 ./conformer_ctc/pretrained.py \
   --checkpoint /path/to/your/checkpoint.pt \
   --words-file /path/to/words.txt \
   --hlg /path/to/HLG.pt \
-  --sound-file /path/to/your/sound.wav
+  /path/to/your/sound.wav
 ```
 
 and you will see the transcribed result.
+
+If you want to transcribe multiple files at the same time, you can use:
+
+```
+./conformer_ctc/pretrained.py \
+  --checkpoint /path/to/your/checkpoint.pt \
+  --words-file /path/to/words.txt \
+  --hlg /path/to/HLG.pt \
+  /path/to/your/sound1.wav \
+  /path/to/your/sound2.wav \
+  /path/to/your/sound3.wav
+```

View File

@@ -2,11 +2,15 @@
 import argparse
 import logging
+import math
+from typing import List
 
 import k2
+import kaldifeat
 import torch
 import torchaudio
 from conformer import Conformer
+from torch.nn.utils.rnn import pad_sequence
 
 from icefall.decode import get_lattice, one_best_decoding
 from icefall.utils import AttributeDict, get_texts
@@ -38,10 +42,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--sound-file",
+        "sound_files",
         type=str,
-        required=True,
-        help="The input sound file to transcribe. "
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
         "Supported formats are those supported by torchaudio.load(). "
         "For example, wav, flac are supported. "
         "The sample rate has to be 16kHz.",
@@ -56,7 +60,7 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "nhead": 8,
             "num_classes": 5000,
-            "sample_freq": 16000,
+            "sample_rate": 16000,
             "attention_dim": 512,
             "subsampling_factor": 4,
             "num_decoder_layers": 6,
@@ -74,6 +78,30 @@ def get_params() -> AttributeDict:
     return params
 
 
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
 def main():
     parser = get_parser()
     args = parser.parse_args()
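
For reference, a minimal sketch of how this helper is used later in `main()` (the file names are hypothetical and must be 16 kHz recordings):

```
# Hypothetical usage of read_sound_files defined above.
waves = read_sound_files(["a.wav", "b.wav"], expected_sample_rate=16000)

# Each entry is a 1-D float32 tensor; only the first channel is kept,
# so stereo input is reduced to channel 0.
for w in waves:
    assert w.ndim == 1
```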
@@ -87,6 +115,7 @@ def main():
 
     logging.info(f"device: {device}")
 
+    logging.info("Create model")
     model = Conformer(
         num_features=params.feature_dim,
         nhead=params.nhead,
@@ -103,28 +132,39 @@ def main():
     checkpoint = torch.load(args.checkpoint, map_location="cpu")
     model.load_state_dict(checkpoint["model"])
     model.to(device)
+    model.eval()
 
     HLG = k2.Fsa.from_dict(torch.load(params.hlg))
     HLG = HLG.to(device)
-    model.to(device)
 
-    wave, sample_freq = torchaudio.load(params.sound_file)
-    assert sample_freq == params.sample_freq
-    wave = wave.to(device)
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
 
-    features = torchaudio.compliance.kaldi.fbank(
-        waveform=wave,
-        num_mel_bins=params.feature_dim,
-        snip_edges=False,
-        sample_frequency=params.sample_freq,
-    )
-    features = features.unsqueeze(0)
+    fbank = kaldifeat.Fbank(opts)
 
-    nnet_output, _, _ = model(features)
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info("Decoding started")
+    features = fbank(waves)
+
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
+
+    with torch.no_grad():
+        nnet_output, _, _ = model(features)
+
+    batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[0, 0, nnet_output.shape[1]]], dtype=torch.int32
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
     )
 
     lattice = get_lattice(
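
Two things are worth spelling out in this hunk. Each row of `supervision_segments` is `[utterance_index, start_frame, num_frames]`, telling k2 which rows of the padded batch belong to which utterance. And `kaldifeat.Fbank` maps a list of variable-length 1-D waveforms to a list of per-utterance `(num_frames, num_bins)` feature matrices, which `pad_sequence` then pads with `math.log(1e-10)`, i.e. approximately silence in the log-mel domain, into a single batch. A self-contained sketch of that feature pipeline, assuming `kaldifeat` is installed and using random tensors as stand-in waveforms:

```
import math

import kaldifeat
import torch
from torch.nn.utils.rnn import pad_sequence

opts = kaldifeat.FbankOptions()
opts.frame_opts.samp_freq = 16000
opts.frame_opts.dither = 0
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)

# Two random "waveforms" of different lengths: 1 s and 0.5 s at 16 kHz.
waves = [torch.randn(16000), torch.randn(8000)]

features = fbank(waves)  # a list of (num_frames, 80) tensors
features = pad_sequence(
    features, batch_first=True, padding_value=math.log(1e-10)
)
print(features.shape)  # (2, max_num_frames, 80)
```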
@@ -145,7 +185,14 @@ def main():
     hyps = get_texts(best_path)
     word_sym_table = k2.SymbolTable.from_file(params.words_file)
     hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
-    logging.info(hyps)
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info("Decoding Done")
 
 
 if __name__ == "__main__":