Use torchaudio to extract features.

This commit is contained in:
Fangjun Kuang 2021-08-18 19:31:06 +08:00
parent 0fa4875a9a
commit f731996abe
2 changed files with 12 additions and 20 deletions

View File

@ -10,11 +10,8 @@ You need to prepare 4 files:
Supported formats are those supported by `torchaudio.load()`,
e.g., wav and flac.
Also, you need to install `kaldifeat`. Please refer to
<https://github.com/csukuangfj/kaldifeat> for installation.
Once you have the above files ready and have `kaldifeat` installed,
you can run:
Once you have the above files ready, you can run:
```
./conformer_ctc/pretrained.py \

View File

@ -4,15 +4,11 @@ import argparse
import logging
import k2
import kaldifeat
import torch
import torchaudio
from conformer import Conformer
from icefall.decode import (
get_lattice,
one_best_decoding,
)
from icefall.decode import get_lattice, one_best_decoding
from icefall.utils import AttributeDict, get_texts
@ -60,6 +56,7 @@ def get_params() -> AttributeDict:
"feature_dim": 80,
"nhead": 8,
"num_classes": 5000,
"sample_freq": 16000,
"attention_dim": 512,
"subsampling_factor": 4,
"num_decoder_layers": 6,
@ -112,19 +109,17 @@ def main():
model.to(device)
wave, samp_freq = torchaudio.load(params.sound_file)
wave = wave.squeeze().to(device)
wave, sample_freq = torchaudio.load(params.sound_file)
assert sample_freq == params.sample_freq
wave = wave.to(device)
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = samp_freq
opts.mel_opts.num_bins = 80
features = torchaudio.compliance.kaldi.fbank(
waveform=wave,
num_mel_bins=params.feature_dim,
snip_edges=False,
sample_frequency=params.sample_freq,
)
fbank = kaldifeat.Fbank(opts)
features = fbank(wave)
features = features.unsqueeze(0)
nnet_output, _, _ = model(features)