Support decoding multiple files at the same time.
Also, use kaldifeat for feature extraction.

parent f731996abe
commit a73d3ed917
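The gist of the change, condensed into a standalone sketch (not part of the commit; the wav names are placeholders): `kaldifeat` can compute fbank features for a whole list of waveforms in one call, which is what makes decoding several files at once straightforward.

```
import kaldifeat
import torchaudio

# 80-dim log-mel fbank at 16 kHz, matching the options used in pretrained.py
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
fbank = kaldifeat.Fbank(opts)

# Placeholder file names; keep only the first channel of each file.
waves = [torchaudio.load(f)[0][0] for f in ["a.wav", "b.wav"]]

# kaldifeat accepts a list of 1-D waveforms and returns a list of
# 2-D feature matrices, one per utterance.
features = fbank(waves)
print([f.shape for f in features])
```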
@@ -10,15 +10,30 @@ You need to prepare 4 files:
 
 Supported formats are those supported by `torchaudio.load()`,
 e.g., wav and flac.
 
+Also, you need to install `kaldifeat`. Please refer to
+<https://github.com/csukuangfj/kaldifeat> for installation.
+
-Once you have the above files ready, you can run:
+Once you have the above files ready and have `kaldifeat` installed,
+you can run:
 
 ```
 ./conformer_ctc/pretrained.py \
   --checkpoint /path/to/your/checkpoint.pt \
   --words-file /path/to/words.txt \
   --hlg /path/to/HLG.pt \
-  --sound-file /path/to/your/sound.wav
+  /path/to/your/sound.wav
 ```
 
 and you will see the transcribed result.
 
+If you want to transcribe multiple files at the same time, you can use:
+
+```
+./conformer_ctc/pretrained.py \
+  --checkpoint /path/to/your/checkpoint.pt \
+  --words-file /path/to/words.txt \
+  --hlg /path/to/HLG.pt \
+  /path/to/your/sound1.wav \
+  /path/to/your/sound2.wav \
+  /path/to/your/sound3.wav
+```
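The script expects audio sampled at 16 kHz (it asserts on the sample rate), so recordings at another rate have to be resampled first. One way to do that, not part of this commit and assuming a reasonably recent torchaudio; the paths are placeholders:

```
import torchaudio

wave, sr = torchaudio.load("recording_48k.wav")  # placeholder path
if sr != 16000:
    # Resample to the 16 kHz expected by pretrained.py
    wave = torchaudio.functional.resample(wave, orig_freq=sr, new_freq=16000)
torchaudio.save("recording_16k.wav", wave, sample_rate=16000)
```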
@@ -2,11 +2,15 @@
 
 import argparse
 import logging
+import math
+from typing import List
 
 import k2
+import kaldifeat
 import torch
 import torchaudio
 from conformer import Conformer
+from torch.nn.utils.rnn import pad_sequence
 
 from icefall.decode import get_lattice, one_best_decoding
 from icefall.utils import AttributeDict, get_texts
@@ -38,10 +42,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--sound-file",
+        "sound_files",
         type=str,
-        required=True,
-        help="The input sound file to transcribe. "
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
         "Supported formats are those supported by torchaudio.load(). "
         "For example, wav and flac are supported. "
         "The sample rate has to be 16kHz.",
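The single required `--sound-file` option becomes a positional `sound_files` argument with `nargs="+"`, so one or more paths can be given directly on the command line. A tiny self-contained illustration of that argparse behaviour:

```
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("sound_files", type=str, nargs="+")

# All positional paths are collected into a single list.
args = parser.parse_args(["a.wav", "b.wav", "c.wav"])
print(args.sound_files)  # ['a.wav', 'b.wav', 'c.wav']
```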
@@ -56,7 +60,7 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "nhead": 8,
             "num_classes": 5000,
-            "sample_freq": 16000,
+            "sample_rate": 16000,
             "attention_dim": 512,
             "subsampling_factor": 4,
             "num_decoder_layers": 6,
@@ -74,6 +78,30 @@ def get_params() -> AttributeDict:
     return params
 
 
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
 def main():
     parser = get_parser()
     args = parser.parse_args()
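A usage sketch for the new helper (the file names below are hypothetical; any 16 kHz files readable by `torchaudio.load()` would do):

```
waves = read_sound_files(
    filenames=["utt1.wav", "utt2.wav"], expected_sample_rate=16000
)
# Each element is a 1-D float32 tensor holding the first channel of one file.
print([w.shape for w in waves])
```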
@@ -87,6 +115,7 @@ def main():
 
     logging.info(f"device: {device}")
 
+    logging.info("Create model")
     model = Conformer(
         num_features=params.feature_dim,
         nhead=params.nhead,
@@ -103,28 +132,39 @@ def main():
     checkpoint = torch.load(args.checkpoint, map_location="cpu")
     model.load_state_dict(checkpoint["model"])
+    model.to(device)
     model.eval()
 
     HLG = k2.Fsa.from_dict(torch.load(params.hlg))
     HLG = HLG.to(device)
 
-    model.to(device)
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
 
-    wave, sample_freq = torchaudio.load(params.sound_file)
-    assert sample_freq == params.sample_freq
-    wave = wave.to(device)
+    fbank = kaldifeat.Fbank(opts)
 
-    features = torchaudio.compliance.kaldi.fbank(
-        waveform=wave,
-        num_mel_bins=params.feature_dim,
-        snip_edges=False,
-        sample_frequency=params.sample_freq,
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
     )
+    waves = [w.to(device) for w in waves]
+
+    logging.info(f"Decoding started")
+    features = fbank(waves)
 
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
+    )
 
-    features = features.unsqueeze(0)
+    with torch.no_grad():
+        nnet_output, _, _ = model(features)
 
-    nnet_output, _, _ = model(features)
+    batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[0, 0, nnet_output.shape[1]]], dtype=torch.int32
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
     )
 
     lattice = get_lattice(
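Two details of the new batching code are worth spelling out: the features are log-mel energies, so the batch is padded with `math.log(1e-10)` (the log of an almost-zero energy) rather than with zeros, and `supervision_segments` now carries one row `[sequence_index, start_frame, num_frames]` per utterance instead of a single row. A small sketch of both on made-up shapes:

```
import math

import torch
from torch.nn.utils.rnn import pad_sequence

# Two fake utterances with different numbers of frames, 80-dim features each.
features = [torch.randn(25, 80), torch.randn(40, 80)]

# Pad to a common length with log(1e-10), i.e. (nearly) zero energy in log space.
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
print(features.shape)  # torch.Size([2, 40, 80])

# One supervision row per utterance: [sequence_index, start_frame, num_frames].
# pretrained.py uses the padded length for every row.
supervision_segments = torch.tensor(
    [[i, 0, features.shape[1]] for i in range(features.shape[0])],
    dtype=torch.int32,
)
print(supervision_segments)
```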
@@ -145,7 +185,14 @@ def main():
     hyps = get_texts(best_path)
     word_sym_table = k2.SymbolTable.from_file(params.words_file)
     hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
-    logging.info(hyps)
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
 
     logging.info(f"Decoding Done")
 
 
 if __name__ == "__main__":