mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Use torchaudio to extract features.
This commit is contained in:
parent
0fa4875a9a
commit
f731996abe
@@ -10,11 +10,8 @@ You need to prepare 4 files:
|
||||
Supported formats are those supported by `torchaudio.load()`,
|
||||
e.g., wav and flac.
|
||||
|
||||
Also, you need to install `kaldifeat`. Please refer to
|
||||
<https://github.com/csukuangfj/kaldifeat> for installation.
|
||||
|
||||
Once you have the above files ready and have `kaldifeat` installed,
|
||||
you can run:
|
||||
Once you have the above files ready, you can run:
|
||||
|
||||
```
|
||||
./conformer_ctc/pretrained.py \
|
||||
|
@@ -4,15 +4,11 @@ import argparse
|
||||
import logging
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import torch
|
||||
import torchaudio
|
||||
from conformer import Conformer
|
||||
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
)
|
||||
from icefall.decode import get_lattice, one_best_decoding
|
||||
from icefall.utils import AttributeDict, get_texts
|
||||
|
||||
|
||||
@@ -60,6 +56,7 @@ def get_params() -> AttributeDict:
|
||||
"feature_dim": 80,
|
||||
"nhead": 8,
|
||||
"num_classes": 5000,
|
||||
"sample_freq": 16000,
|
||||
"attention_dim": 512,
|
||||
"subsampling_factor": 4,
|
||||
"num_decoder_layers": 6,
|
||||
@@ -112,19 +109,17 @@ def main():
|
||||
|
||||
model.to(device)
|
||||
|
||||
wave, samp_freq = torchaudio.load(params.sound_file)
|
||||
wave = wave.squeeze().to(device)
|
||||
wave, sample_freq = torchaudio.load(params.sound_file)
|
||||
assert sample_freq == params.sample_freq
|
||||
wave = wave.to(device)
|
||||
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = samp_freq
|
||||
opts.mel_opts.num_bins = 80
|
||||
features = torchaudio.compliance.kaldi.fbank(
|
||||
waveform=wave,
|
||||
num_mel_bins=params.feature_dim,
|
||||
snip_edges=False,
|
||||
sample_frequency=params.sample_freq,
|
||||
)
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
features = fbank(wave)
|
||||
features = features.unsqueeze(0)
|
||||
|
||||
nnet_output, _, _ = model(features)
|
||||
|
Loading…
x
Reference in New Issue
Block a user