Use torchaudio to extract features.

This commit is contained in:
Fangjun Kuang 2021-08-18 19:31:06 +08:00
parent 0fa4875a9a
commit f731996abe
2 changed files with 12 additions and 20 deletions

View File

@ -10,11 +10,8 @@ You need to prepare 4 files:
Supported formats are those supported by `torchaudio.load()`, Supported formats are those supported by `torchaudio.load()`,
e.g., wav and flac. e.g., wav and flac.
Also, you need to install `kaldifeat`. Please refer to
<https://github.com/csukuangfj/kaldifeat> for installation.
Once you have the above files ready and have `kaldifeat` installed, Once you have the above files ready, you can run:
you can run:
``` ```
./conformer_ctc/pretrained.py \ ./conformer_ctc/pretrained.py \

View File

@ -4,15 +4,11 @@ import argparse
import logging import logging
import k2 import k2
import kaldifeat
import torch import torch
import torchaudio import torchaudio
from conformer import Conformer from conformer import Conformer
from icefall.decode import ( from icefall.decode import get_lattice, one_best_decoding
get_lattice,
one_best_decoding,
)
from icefall.utils import AttributeDict, get_texts from icefall.utils import AttributeDict, get_texts
@ -60,6 +56,7 @@ def get_params() -> AttributeDict:
"feature_dim": 80, "feature_dim": 80,
"nhead": 8, "nhead": 8,
"num_classes": 5000, "num_classes": 5000,
"sample_freq": 16000,
"attention_dim": 512, "attention_dim": 512,
"subsampling_factor": 4, "subsampling_factor": 4,
"num_decoder_layers": 6, "num_decoder_layers": 6,
@ -112,19 +109,17 @@ def main():
model.to(device) model.to(device)
wave, samp_freq = torchaudio.load(params.sound_file) wave, sample_freq = torchaudio.load(params.sound_file)
wave = wave.squeeze().to(device) assert sample_freq == params.sample_freq
wave = wave.to(device)
opts = kaldifeat.FbankOptions() features = torchaudio.compliance.kaldi.fbank(
opts.device = device waveform=wave,
opts.frame_opts.dither = 0 num_mel_bins=params.feature_dim,
opts.frame_opts.snip_edges = False snip_edges=False,
opts.frame_opts.samp_freq = samp_freq sample_frequency=params.sample_freq,
opts.mel_opts.num_bins = 80 )
fbank = kaldifeat.Fbank(opts)
features = fbank(wave)
features = features.unsqueeze(0) features = features.unsqueeze(0)
nnet_output, _, _ = model(features) nnet_output, _, _ = model(features)