Support decoding multiple files at the same time.

Also, use kaldifeat for feature extraction.
Fangjun Kuang 2021-08-18 21:20:42 +08:00
parent f731996abe
commit a73d3ed917
2 changed files with 81 additions and 19 deletions
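
The gist of the change: `kaldifeat.Fbank` takes a list of waveforms and returns one feature matrix per utterance, and the matrices are padded into a single batch before the network runs. A minimal sketch of just the padding step (shapes are illustrative, not taken from the diff):

```
import math

import torch
from torch.nn.utils.rnn import pad_sequence

# Two utterances of different lengths, each (num_frames, feature_dim).
feats = [torch.randn(120, 80), torch.randn(95, 80)]

# Pad to a common length; log(1e-10) approximates silence in log-mel space.
batch = pad_sequence(feats, batch_first=True, padding_value=math.log(1e-10))
print(batch.shape)  # torch.Size([2, 120, 80])
```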


@@ -10,15 +10,30 @@ You need to prepare 4 files:
 Supported formats are those supported by `torchaudio.load()`,
 e.g., wav and flac.
 
+Also, you need to install `kaldifeat`. Please refer to
+<https://github.com/csukuangfj/kaldifeat> for installation.
+
-Once you have the above files ready, you can run:
+Once you have the above files ready and have `kaldifeat` installed,
+you can run:
 
 ```
 ./conformer_ctc/pretrained.py \
   --checkpoint /path/to/your/checkpoint.pt \
   --words-file /path/to/words.txt \
   --hlg /path/to/HLG.pt \
-  --sound-file /path/to/your/sound.wav
+  /path/to/your/sound.wav
 ```
 
 and you will see the transcribed result.
+
+If you want to transcribe multiple files at the same time, you can use:
+
+```
+./conformer_ctc/pretrained.py \
+  --checkpoint /path/to/your/checkpoint.pt \
+  --words-file /path/to/words.txt \
+  --hlg /path/to/HLG.pt \
+  /path/to/your/sound1.wav \
+  /path/to/your/sound2.wav \
+  /path/to/your/sound3.wav
+```


@@ -2,11 +2,15 @@
 
 import argparse
 import logging
+import math
+from typing import List
 
 import k2
+import kaldifeat
 import torch
 import torchaudio
 from conformer import Conformer
+from torch.nn.utils.rnn import pad_sequence
 
 from icefall.decode import get_lattice, one_best_decoding
 from icefall.utils import AttributeDict, get_texts
@@ -38,10 +42,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--sound-file",
+        "sound_files",
         type=str,
-        required=True,
-        help="The input sound file to transcribe. "
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
         "Supported formats are those supported by torchaudio.load(). "
         "For example, wav, flac are supported. "
         "The sample rate has to be 16kHz.",
@@ -56,7 +60,7 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "nhead": 8,
             "num_classes": 5000,
-            "sample_freq": 16000,
+            "sample_rate": 16000,
             "attention_dim": 512,
             "subsampling_factor": 4,
             "num_decoder_layers": 6,
@@ -74,6 +78,30 @@ def get_params() -> AttributeDict:
     return params
 
 
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
 def main():
     parser = get_parser()
     args = parser.parse_args()
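
A hypothetical call to the new helper (the file names here are invented for illustration):

```
# Returns one 1-D float32 tensor per file, keeping only the first channel.
waves = read_sound_files(
    filenames=["test1.wav", "test2.wav"], expected_sample_rate=16000
)
assert all(w.ndim == 1 for w in waves)
```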
@@ -87,6 +115,7 @@ def main():
 
     logging.info(f"device: {device}")
 
+    logging.info("Create model")
     model = Conformer(
         num_features=params.feature_dim,
         nhead=params.nhead,
@@ -103,28 +132,39 @@ def main():
     checkpoint = torch.load(args.checkpoint, map_location="cpu")
     model.load_state_dict(checkpoint["model"])
+    model.to(device)
     model.eval()
 
     HLG = k2.Fsa.from_dict(torch.load(params.hlg))
     HLG = HLG.to(device)
-    model.to(device)
 
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
-    wave, sample_freq = torchaudio.load(params.sound_file)
-    assert sample_freq == params.sample_freq
-    wave = wave.to(device)
+    fbank = kaldifeat.Fbank(opts)
 
-    features = torchaudio.compliance.kaldi.fbank(
-        waveform=wave,
-        num_mel_bins=params.feature_dim,
-        snip_edges=False,
-        sample_frequency=params.sample_freq,
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
     )
+    waves = [w.to(device) for w in waves]
+
+    logging.info(f"Decoding started")
+    features = fbank(waves)
+
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
-    features = features.unsqueeze(0)
 
     with torch.no_grad():
         nnet_output, _, _ = model(features)
 
+    batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[0, 0, nnet_output.shape[1]]], dtype=torch.int32
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
     )
 
     lattice = get_lattice(
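
For context (my reading of the k2 convention, not something stated in the diff): each row of `supervision_segments` is `[sequence_index, start_frame, num_frames]`. Because every utterance is padded to the same length, each row spans the full `nnet_output.shape[1]` frames:

```
import torch

# Illustrative values: a batch of 3 utterances, each padded to T frames.
T = 100
batch_size = 3
supervision_segments = torch.tensor(
    [[i, 0, T] for i in range(batch_size)], dtype=torch.int32
)
print(supervision_segments)
# tensor([[  0,   0, 100],
#         [  1,   0, 100],
#         [  2,   0, 100]], dtype=torch.int32)
```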
@@ -145,7 +185,14 @@ def main():
     hyps = get_texts(best_path)
     word_sym_table = k2.SymbolTable.from_file(params.words_file)
     hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
-    logging.info(hyps)
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info(f"Decoding Done")
 
 
 if __name__ == "__main__":
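
Putting the feature-extraction side of the commit together, a minimal sketch assuming `kaldifeat` is installed (the option names are the ones used in the diff; the waveforms are fake):

```
import kaldifeat
import torch

opts = kaldifeat.FbankOptions()
opts.device = torch.device("cpu")
opts.frame_opts.dither = 0          # deterministic features at test time
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80

fbank = kaldifeat.Fbank(opts)

# A list of waveforms in, a list of (num_frames, 80) feature matrices out.
waves = [torch.randn(16000), torch.randn(8000)]
features = fbank(waves)
print([f.shape for f in features])
```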