Mirror of https://github.com/k2-fsa/icefall.git (last synced 2025-08-09 10:02:22 +00:00)
Use torchaudio to extract features.
This commit is contained in:
parent: 0fa4875a9a
commit: f731996abe
@ -10,11 +10,8 @@ You need to prepare 4 files:
|
|||||||
Supported formats are those supported by `torchaudio.load()`,
|
Supported formats are those supported by `torchaudio.load()`,
|
||||||
e.g., wav and flac.
|
e.g., wav and flac.
|
||||||
|
|
||||||
Also, you need to install `kaldifeat`. Please refer to
|
|
||||||
<https://github.com/csukuangfj/kaldifeat> for installation.
|
|
||||||
|
|
||||||
Once you have the above files ready and have `kaldifeat` installed,
|
Once you have the above files ready, you can run:
|
||||||
you can run:
|
|
||||||
|
|
||||||
```
|
```
|
||||||
./conformer_ctc/pretrained.py \
|
./conformer_ctc/pretrained.py \
|
||||||
|
@ -4,15 +4,11 @@ import argparse
|
|||||||
import logging
|
import logging
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import kaldifeat
|
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
|
|
||||||
from icefall.decode import (
|
from icefall.decode import get_lattice, one_best_decoding
|
||||||
get_lattice,
|
|
||||||
one_best_decoding,
|
|
||||||
)
|
|
||||||
from icefall.utils import AttributeDict, get_texts
|
from icefall.utils import AttributeDict, get_texts
|
||||||
|
|
||||||
|
|
||||||
@ -60,6 +56,7 @@ def get_params() -> AttributeDict:
|
|||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"nhead": 8,
|
"nhead": 8,
|
||||||
"num_classes": 5000,
|
"num_classes": 5000,
|
||||||
|
"sample_freq": 16000,
|
||||||
"attention_dim": 512,
|
"attention_dim": 512,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"num_decoder_layers": 6,
|
"num_decoder_layers": 6,
|
||||||
@ -112,19 +109,17 @@ def main():
|
|||||||
|
|
||||||
model.to(device)
|
model.to(device)
|
||||||
|
|
||||||
wave, samp_freq = torchaudio.load(params.sound_file)
|
wave, sample_freq = torchaudio.load(params.sound_file)
|
||||||
wave = wave.squeeze().to(device)
|
assert sample_freq == params.sample_freq
|
||||||
|
wave = wave.to(device)
|
||||||
|
|
||||||
opts = kaldifeat.FbankOptions()
|
features = torchaudio.compliance.kaldi.fbank(
|
||||||
opts.device = device
|
waveform=wave,
|
||||||
opts.frame_opts.dither = 0
|
num_mel_bins=params.feature_dim,
|
||||||
opts.frame_opts.snip_edges = False
|
snip_edges=False,
|
||||||
opts.frame_opts.samp_freq = samp_freq
|
sample_frequency=params.sample_freq,
|
||||||
opts.mel_opts.num_bins = 80
|
)
|
||||||
|
|
||||||
fbank = kaldifeat.Fbank(opts)
|
|
||||||
|
|
||||||
features = fbank(wave)
|
|
||||||
features = features.unsqueeze(0)
|
features = features.unsqueeze(0)
|
||||||
|
|
||||||
nnet_output, _, _ = model(features)
|
nnet_output, _, _ = model(features)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user