From f731996abe2ac07c9acabb1b75fda569cd71ffdf Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 18 Aug 2021 19:31:06 +0800
Subject: [PATCH] Use torchaudio to extract features.

---
 egs/librispeech/ASR/conformer_ctc/README.md  |  5 +---
 .../ASR/conformer_ctc/pretrained.py          | 27 ++++++++-----------
 2 files changed, 12 insertions(+), 20 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/README.md b/egs/librispeech/ASR/conformer_ctc/README.md
index b9e9b7f52..8f43dac34 100644
--- a/egs/librispeech/ASR/conformer_ctc/README.md
+++ b/egs/librispeech/ASR/conformer_ctc/README.md
@@ -10,11 +10,8 @@ You need to prepare 4 files:
 Supported formats are those supported by `torchaudio.load()`,
 e.g., wav and flac.
 
-Also, you need to install `kaldifeat`. Please refer to
-<https://github.com/csukuangfj/kaldifeat> for installation.
 
-Once you have the above files ready and have `kaldifeat` installed,
-you can run:
+Once you have the above files ready, you can run:
 
 ```
 ./conformer_ctc/pretrained.py \
diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py
index d16b07166..c02c2cca7 100755
--- a/egs/librispeech/ASR/conformer_ctc/pretrained.py
+++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py
@@ -4,15 +4,11 @@ import argparse
 import logging
 
 import k2
-import kaldifeat
 import torch
 import torchaudio
 from conformer import Conformer
 
-from icefall.decode import (
-    get_lattice,
-    one_best_decoding,
-)
+from icefall.decode import get_lattice, one_best_decoding
 from icefall.utils import AttributeDict, get_texts
 
 
@@ -60,6 +56,7 @@ def get_params() -> AttributeDict:
             "feature_dim": 80,
             "nhead": 8,
             "num_classes": 5000,
+            "sample_freq": 16000,
             "attention_dim": 512,
             "subsampling_factor": 4,
             "num_decoder_layers": 6,
@@ -112,19 +109,17 @@ def main():
     model.to(device)
 
-    wave, samp_freq = torchaudio.load(params.sound_file)
-    wave = wave.squeeze().to(device)
+    wave, sample_freq = torchaudio.load(params.sound_file)
+    assert sample_freq == params.sample_freq
+    wave = wave.to(device)
 
-    opts = kaldifeat.FbankOptions()
-    opts.device = device
-    opts.frame_opts.dither = 0
-    opts.frame_opts.snip_edges = False
-    opts.frame_opts.samp_freq = samp_freq
-    opts.mel_opts.num_bins = 80
+    features = torchaudio.compliance.kaldi.fbank(
+        waveform=wave,
+        num_mel_bins=params.feature_dim,
+        snip_edges=False,
+        sample_frequency=params.sample_freq,
+    )
 
-    fbank = kaldifeat.Fbank(opts)
-
-    features = fbank(wave)
     features = features.unsqueeze(0)
 
     nnet_output, _, _ = model(features)
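For reference, below is a minimal standalone sketch of the torchaudio-based fbank extraction that `pretrained.py` performs after this patch, assuming a 16 kHz recording. The file name `test.wav` and the two module-level constants are placeholders introduced for illustration; the `torchaudio.compliance.kaldi.fbank` call mirrors the one added in the diff above.

```
#!/usr/bin/env python3
# Standalone sketch of the feature extraction used by the patched
# pretrained.py. "test.wav" is a placeholder file name; SAMPLE_FREQ and
# FEATURE_DIM correspond to params.sample_freq and params.feature_dim.

import torchaudio

SAMPLE_FREQ = 16000
FEATURE_DIM = 80

# torchaudio.load returns a float tensor of shape (channels, num_samples)
# and the sample rate of the file.
wave, sample_freq = torchaudio.load("test.wav")
assert sample_freq == SAMPLE_FREQ, (
    f"Expected a {SAMPLE_FREQ} Hz recording, got {sample_freq} Hz"
)

# Kaldi-compatible log-mel filterbank features, shape (num_frames, 80).
features = torchaudio.compliance.kaldi.fbank(
    waveform=wave,
    num_mel_bins=FEATURE_DIM,
    snip_edges=False,
    sample_frequency=SAMPLE_FREQ,
)

# Add a batch dimension: (1, num_frames, 80).
features = features.unsqueeze(0)
print(features.shape)
```

The resulting `(1, num_frames, 80)` tensor is the layout that `main()` feeds to `model(features)` in the patched script.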