mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 14:14:19 +00:00
do some changes
This commit is contained in:
parent
194a4e6864
commit
35812fd6de
@ -28,7 +28,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir transducer_stateless/exp \
|
||||
--max-duration 200 \
|
||||
--max-duration 200
|
||||
```
|
||||
|
||||
The tensorboard training log can be found at
|
||||
@ -64,6 +64,8 @@ avg=16
|
||||
--exp-dir transducer_stateless/exp \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--max-duration 100 \
|
||||
--decoding-method beam_search \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
```
|
||||
|
||||
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_tedlium3_transducer_stateless>
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||
# Mingshuang Luo)
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# 2022 Xiaomi Crop. (authors: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
@ -29,7 +29,7 @@ import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
|
||||
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
@ -83,7 +83,7 @@ def compute_fbank_tedlium():
|
||||
# when an executor is specified, make more partitions
|
||||
num_jobs=num_jobs if ex is None else 80,
|
||||
executor=ex,
|
||||
storage_type=LilcomHdf5Writer,
|
||||
storage_type=ChunkedLilcomHdf5Writer,
|
||||
)
|
||||
cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
|
||||
|
||||
|
@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
"""
|
||||
Convert a transcript based on words to a list of BPE ids.
|
||||
|
||||
@ -28,12 +28,6 @@ def get_args():
|
||||
parser.add_argument(
|
||||
"--texts", type=List[str], help="The input transcripts list."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--unk-id",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number id for the token '<unk>'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
@ -70,7 +64,7 @@ def convert_texts_into_ids(
|
||||
else:
|
||||
y_ids.extend(id_segments[i])
|
||||
else:
|
||||
y_ids = sp.encode([text], out_type=int)[0]
|
||||
y_ids = sp.encode(text, out_type=int)
|
||||
y.append(y_ids)
|
||||
|
||||
return y
|
||||
|
@ -59,9 +59,8 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
log "Stage -1: Download LM"
|
||||
# We assume that you have installed the git-lfs, if not, you could install it
|
||||
# using: `sudo apt-get install git-lfs && git-lfs install`
|
||||
[ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
|
||||
mkdir -p $dl_dir/lm
|
||||
git clone https://huggingface.co/luomingshuang/tedlium3_lm $dl_dir/lm
|
||||
cd $dl_dir/lm && git lfs pull
|
||||
|
||||
# If you want to download Tedlium 4 gram language models
|
||||
# using the follow commands:
|
||||
|
@ -175,13 +175,12 @@ class TedLiumAsrDataModule:
|
||||
|
||||
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "cuts_musan.json.gz"
|
||||
)
|
||||
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
cuts_musan = load_manifest(
|
||||
self.args.manifest_dir / "cuts_musan.json.gz"
|
||||
)
|
||||
transforms.append(
|
||||
CutMix(
|
||||
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
|
||||
|
@ -1 +0,0 @@
|
||||
../../../librispeech/ASR/transducer_stateless/pretrained.py
|
343
egs/tedlium3/ASR/transducer_stateless/pretrained.py
Normal file
343
egs/tedlium3/ASR/transducer_stateless/pretrained.py
Normal file
@ -0,0 +1,343 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
# 2022 Xiaomi Crop. (authors: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
|
||||
(1) greedy search
|
||||
./transducer_stateless/pretrained.py \
|
||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method greedy_search \
|
||||
--max-sym-per-frame 1 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
|
||||
(2) beam search
|
||||
./transducer_stateless/pretrained.py \
|
||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
|
||||
(3) modified beam search
|
||||
./transducer_stateless/pretrained.py \
|
||||
--checkpoint ./transducer_stateless/exp/pretrained.pt \
|
||||
--bpe-model ./data/lang_bpe_500/bpe.model \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav \
|
||||
|
||||
You can also use `./transducer_stateless/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./transducer_stateless/exp/pretrained.pt is generated by
|
||||
./transducer_stateless/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchaudio
|
||||
from beam_search import beam_search, greedy_search, modified_beam_search
|
||||
from conformer import Conformer
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from model import Transducer
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from icefall.env import get_env_info
|
||||
from icefall.utils import AttributeDict
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.
|
||||
Used only when method is ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Used only when --method is beam_search and modified_beam_search ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; "
|
||||
"2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=3,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"sample_rate": 16000,
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"encoder_out_dim": 512,
|
||||
"subsampling_factor": 4,
|
||||
"attention_dim": 512,
|
||||
"nhead": 8,
|
||||
"dim_feedforward": 2048,
|
||||
"num_encoder_layers": 12,
|
||||
"vgg_frontend": False,
|
||||
"env_info": get_env_info(),
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
output_dim=params.encoder_out_dim,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
d_model=params.attention_dim,
|
||||
nhead=params.nhead,
|
||||
dim_feedforward=params.dim_feedforward,
|
||||
num_encoder_layers=params.num_encoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
)
|
||||
return encoder
|
||||
|
||||
|
||||
def get_decoder_model(params: AttributeDict) -> nn.Module:
|
||||
decoder = Decoder(
|
||||
vocab_size=params.vocab_size,
|
||||
embedding_dim=params.encoder_out_dim,
|
||||
blank_id=params.blank_id,
|
||||
unk_id=params.unk_id,
|
||||
context_size=params.context_size,
|
||||
)
|
||||
return decoder
|
||||
|
||||
|
||||
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||
joiner = Joiner(
|
||||
input_dim=params.encoder_out_dim,
|
||||
output_dim=params.vocab_size,
|
||||
)
|
||||
return joiner
|
||||
|
||||
|
||||
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||
encoder = get_encoder_model(params)
|
||||
decoder = get_decoder_model(params)
|
||||
joiner = get_joiner_model(params)
|
||||
|
||||
model = Transducer(
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
joiner=joiner,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. "
|
||||
f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features, batch_first=True, padding_value=math.log(1e-10)
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features, x_lens=feature_lengths
|
||||
)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp = modified_beam_search(
|
||||
model=model, encoder_out=encoder_out_i, beam=params.beam_size
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = (
|
||||
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
)
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
@ -26,7 +26,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--exp-dir transducer_stateless/exp \
|
||||
--max-duration 200 \
|
||||
--max-duration 200
|
||||
"""
|
||||
|
||||
|
||||
@ -394,9 +394,9 @@ def compute_loss(
|
||||
feature = feature.to(device)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)[: feature.size(0)]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
texts = batch["supervisions"]["text"][: feature.size(0)]
|
||||
texts = batch["supervisions"]["text"]
|
||||
|
||||
unk_id = params.unk_id
|
||||
y = convert_texts_into_ids(texts, unk_id, sp=sp)
|
||||
@ -625,7 +625,9 @@ def run(rank, world_size, args):
|
||||
train_cuts = tedlium.train_cuts()
|
||||
|
||||
def remove_short_and_long_utt(c: Cut):
|
||||
# Keep only utterances with duration between 1 second and 20 seconds
|
||||
# Keep only utterances with duration between 1 second and max seconds
|
||||
# Here, we set max as 20.0.
|
||||
# If you want to use a big max-duration, you can set it as 17.0.
|
||||
return 1.0 <= c.duration <= 20.0
|
||||
|
||||
num_in_total = len(train_cuts)
|
||||
|
Loading…
x
Reference in New Issue
Block a user