Support decoding multiple files at the same time.
Also, use kaldifeat for feature extraction.
commit a73d3ed917
parent f731996abe
@@ -10,15 +10,30 @@ You need to prepare 4 files:
 Supported formats are those supported by `torchaudio.load()`,
 e.g., wav and flac.
 
-Once you have the above files ready, you can run:
+Also, you need to install `kaldifeat`. Please refer to
+<https://github.com/csukuangfj/kaldifeat> for installation.
+
+Once you have the above files ready and have `kaldifeat` installed,
+you can run:
 
 ```
 ./conformer_ctc/pretrained.py \
 --checkpoint /path/to/your/checkpoint.pt \
 --words-file /path/to/words.txt \
 --hlg /path/to/HLG.pt \
---sound-file /path/to/your/sound.wav
+/path/to/your/sound.wav
 ```
 
 and you will see the transcribed result.
 
+If you want to transcribe multiple files at the same time, you can use:
+
+```
+./conformer_ctc/pretrained.py \
+--checkpoint /path/to/your/checkpoint.pt \
+--words-file /path/to/words.txt \
+--hlg /path/to/HLG.pt \
+/path/to/your/sound1.wav \
+/path/to/your/sound2.wav \
+/path/to/your/sound3.wav \
+```
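The instructions above require 16 kHz input. If a recording is at another rate, it can be resampled before decoding; a minimal sketch using torchaudio (the file names are placeholders, and `torchaudio.functional.resample` is assumed to be available in your torchaudio version):

```
import torchaudio
import torchaudio.functional as F

# Placeholder input file; replace with your own recording.
wave, sr = torchaudio.load("speech_48k.wav")
if sr != 16000:
    # Bring the waveform to the 16 kHz rate expected by pretrained.py.
    wave = F.resample(wave, orig_freq=sr, new_freq=16000)
torchaudio.save("speech_16k.wav", wave, sample_rate=16000)
```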
@@ -2,11 +2,15 @@
 
 import argparse
 import logging
+import math
+from typing import List
 
 import k2
+import kaldifeat
 import torch
 import torchaudio
 from conformer import Conformer
+from torch.nn.utils.rnn import pad_sequence
 
 from icefall.decode import get_lattice, one_best_decoding
 from icefall.utils import AttributeDict, get_texts
@@ -38,10 +42,10 @@ def get_parser():
     )
 
     parser.add_argument(
-        "--sound-file",
+        "sound_files",
         type=str,
-        required=True,
-        help="The input sound file to transcribe. "
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
         "Supported formats are those that supported by torchaudio.load(). "
         "For example, wav, flac are supported. "
         "The sample rate has to be 16kHz.",
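The switch from a required `--sound-file` option to a positional `sound_files` argument with `nargs="+"` is what lets the script accept one or more filenames. A standalone sketch of that argparse pattern (the parser description, help text, and file names here are illustrative):

```
import argparse

parser = argparse.ArgumentParser(description="illustrative parser")
# A positional argument declared with nargs="+" collects one or more
# values into a list, so no --sound-file flag is needed any more.
parser.add_argument(
    "sound_files",
    type=str,
    nargs="+",
    help="One or more sound files to transcribe.",
)

args = parser.parse_args(["a.wav", "b.wav", "c.wav"])
print(args.sound_files)  # ['a.wav', 'b.wav', 'c.wav']
```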
@@ -56,7 +60,7 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "nhead": 8,
             "num_classes": 5000,
-            "sample_freq": 16000,
+            "sample_rate": 16000,
             "attention_dim": 512,
             "subsampling_factor": 4,
             "num_decoder_layers": 6,
@@ -74,6 +78,30 @@ def get_params() -> AttributeDict:
     return params
 
 
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. "
+            f"Given: {sample_rate}"
+        )
+        # We use only the first channel
+        ans.append(wave[0])
+    return ans
+
+
 def main():
     parser = get_parser()
     args = parser.parse_args()
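A quick usage sketch for the `read_sound_files` helper added above; the file names are placeholders for 16 kHz recordings you supply, and the function is assumed to be importable from (or pasted alongside) `pretrained.py`:

```
# Placeholder file names; each must be a 16 kHz recording.
waves = read_sound_files(
    filenames=["test1.wav", "test2.wav"], expected_sample_rate=16000
)
for w in waves:
    # Each entry is a 1-D float32 tensor holding the first channel
    # of one recording.
    print(w.shape, w.dtype)
```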
@@ -87,6 +115,7 @@ def main():
 
     logging.info(f"device: {device}")
 
+    logging.info("Create model")
     model = Conformer(
         num_features=params.feature_dim,
         nhead=params.nhead,
@@ -103,28 +132,39 @@ def main():
     checkpoint = torch.load(args.checkpoint, map_location="cpu")
     model.load_state_dict(checkpoint["model"])
     model.to(device)
+    model.eval()
 
     HLG = k2.Fsa.from_dict(torch.load(params.hlg))
     HLG = HLG.to(device)
 
-    model.to(device)
-
-    wave, sample_freq = torchaudio.load(params.sound_file)
-    assert sample_freq == params.sample_freq
-    wave = wave.to(device)
-
-    features = torchaudio.compliance.kaldi.fbank(
-        waveform=wave,
-        num_mel_bins=params.feature_dim,
-        snip_edges=False,
-        sample_frequency=params.sample_freq,
+    opts = kaldifeat.FbankOptions()
+    opts.device = device
+    opts.frame_opts.dither = 0
+    opts.frame_opts.snip_edges = False
+    opts.frame_opts.samp_freq = params.sample_rate
+    opts.mel_opts.num_bins = params.feature_dim
+
+    fbank = kaldifeat.Fbank(opts)
+
+    waves = read_sound_files(
+        filenames=params.sound_files, expected_sample_rate=params.sample_rate
+    )
+    waves = [w.to(device) for w in waves]
+
+    logging.info(f"Decoding started")
+    features = fbank(waves)
+
+    features = pad_sequence(
+        features, batch_first=True, padding_value=math.log(1e-10)
     )
 
-    features = features.unsqueeze(0)
-
-    nnet_output, _, _ = model(features)
+    with torch.no_grad():
+        nnet_output, _, _ = model(features)
+
+    batch_size = nnet_output.shape[0]
     supervision_segments = torch.tensor(
-        [[0, 0, nnet_output.shape[1]]], dtype=torch.int32
+        [[i, 0, nnet_output.shape[1]] for i in range(batch_size)],
+        dtype=torch.int32,
     )
 
     lattice = get_lattice(
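The batching above is the heart of the multi-file change: kaldifeat returns one `(num_frames, feature_dim)` matrix per utterance, `pad_sequence` stacks them into a single padded batch, and `supervision_segments` gets one row per utterance. A self-contained sketch with dummy tensors (the shapes are made up; the real script builds the rows from `nnet_output`, i.e. after subsampling):

```
import math

import torch
from torch.nn.utils.rnn import pad_sequence

# Pretend the Fbank computer produced features for three utterances
# of different lengths; each is (num_frames, feature_dim).
features = [torch.randn(n, 80) for n in (95, 60, 120)]

# Pad to a common length; log(1e-10) approximates the log-energy of silence.
batch = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
print(batch.shape)  # torch.Size([3, 120, 80])

# One supervision row per utterance: (sequence_index, start_frame, num_frames).
# pretrained.py uses the same (padded) frame count for every row; the true
# per-utterance frame counts could be used instead.
batch_size = batch.shape[0]
supervision_segments = torch.tensor(
    [[i, 0, batch.shape[1]] for i in range(batch_size)], dtype=torch.int32
)
print(supervision_segments)
```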
@@ -145,7 +185,14 @@ def main():
     hyps = get_texts(best_path)
     word_sym_table = k2.SymbolTable.from_file(params.words_file)
     hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
-    logging.info(hyps)
+
+    s = "\n"
+    for filename, hyp in zip(params.sound_files, hyps):
+        words = " ".join(hyp)
+        s += f"{filename}:\n{words}\n\n"
+    logging.info(s)
+
+    logging.info(f"Decoding Done")
 
 
 if __name__ == "__main__":