do some changes

This commit is contained in:
luomingshuang 2022-03-03 16:28:51 +08:00
parent 194a4e6864
commit 35812fd6de
7 changed files with 363 additions and 25 deletions

View File

@ -28,7 +28,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_stateless/exp \
--max-duration 200 \
--max-duration 200
```
The tensorboard training log can be found at
@ -64,6 +64,8 @@ avg=16
--exp-dir transducer_stateless/exp \
--bpe-model ./data/lang_bpe_500/bpe.model \
--max-duration 100 \
--decoding-method beam_search \
--decoding-method modified_beam_search \
--beam-size 4
```
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_tedlium3_transducer_stateless>

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2022 Xiaomi Crop. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -29,7 +29,7 @@ import os
from pathlib import Path
import torch
from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor
@ -83,7 +83,7 @@ def compute_fbank_tedlium():
# when an executor is specified, make more partitions
num_jobs=num_jobs if ex is None else 80,
executor=ex,
storage_type=LilcomHdf5Writer,
storage_type=ChunkedLilcomHdf5Writer,
)
cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
"""
Convert a transcript based on words to a list of BPE ids.
@ -28,12 +28,6 @@ def get_args():
parser.add_argument(
"--texts", type=List[str], help="The input transcripts list."
)
parser.add_argument(
"--unk-id",
type=int,
default=2,
help="The number id for the token '<unk>'.",
)
parser.add_argument(
"--bpe-model",
type=str,
@ -70,7 +64,7 @@ def convert_texts_into_ids(
else:
y_ids.extend(id_segments[i])
else:
y_ids = sp.encode([text], out_type=int)[0]
y_ids = sp.encode(text, out_type=int)
y.append(y_ids)
return y

View File

@ -59,9 +59,8 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
log "Stage -1: Download LM"
# We assume that you have installed the git-lfs, if not, you could install it
# using: `sudo apt-get install git-lfs && git-lfs install`
[ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
mkdir -p $dl_dir/lm
git clone https://huggingface.co/luomingshuang/tedlium3_lm $dl_dir/lm
cd $dl_dir/lm && git lfs pull
# If you want to download Tedlium 4 gram language models
# using the follow commands:

View File

@ -175,13 +175,12 @@ class TedLiumAsrDataModule:
def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(
self.args.manifest_dir / "cuts_musan.json.gz"
)
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
cuts_musan = load_manifest(
self.args.manifest_dir / "cuts_musan.json.gz"
)
transforms.append(
CutMix(
cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True

View File

@ -1 +0,0 @@
../../../librispeech/ASR/transducer_stateless/pretrained.py

View File

@ -0,0 +1,343 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2022 Xiaomi Crop. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
/path/to/bar.wav \
(2) beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
(3) modified beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
You can also use `./transducer_stateless/exp/epoch-xx.pt`.
Note: ./transducer_stateless/exp/pretrained.pt is generated by
./transducer_stateless/export.py
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torch.nn as nn
import torchaudio
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--bpe-model",
type=str,
help="""Path to bpe.model.
Used only when method is ctc-decoding.
""",
)
parser.add_argument(
"--method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search ",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=3,
help="""Maximum number of symbols per frame. Used only when
--method is greedy_search.
""",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"sample_rate": 16000,
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"attention_dim": 512,
"nhead": 8,
"dim_feedforward": 2048,
"num_encoder_layers": 12,
"vgg_frontend": False,
"env_info": get_env_info(),
}
)
return params
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = Conformer(
num_features=params.feature_dim,
output_dim=params.encoder_out_dim,
subsampling_factor=params.subsampling_factor,
d_model=params.attention_dim,
nhead=params.nhead,
dim_feedforward=params.dim_feedforward,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
)
return encoder
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
blank_id=params.blank_id,
unk_id=params.unk_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
model = Transducer(
encoder=encoder,
decoder=decoder,
joiner=joiner,
)
return model
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
with torch.no_grad():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
num_waves = encoder_out.size(0)
hyps = []
msg = f"Using {params.method}"
if params.method == "beam_search":
msg += f" with beam size {params.beam_size}"
logging.info(msg)
for i in range(num_waves):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.method == "beam_search":
hyp = beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
elif params.method == "modified_beam_search":
hyp = modified_beam_search(
model=model, encoder_out=encoder_out_i, beam=params.beam_size
)
else:
raise ValueError(f"Unsupported method: {params.method}")
hyps.append(sp.decode(hyp).split())
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -26,7 +26,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_stateless/exp \
--max-duration 200 \
--max-duration 200
"""
@ -394,9 +394,9 @@ def compute_loss(
feature = feature.to(device)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)[: feature.size(0)]
feature_lens = supervisions["num_frames"].to(device)
texts = batch["supervisions"]["text"][: feature.size(0)]
texts = batch["supervisions"]["text"]
unk_id = params.unk_id
y = convert_texts_into_ids(texts, unk_id, sp=sp)
@ -625,7 +625,9 @@ def run(rank, world_size, args):
train_cuts = tedlium.train_cuts()
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
# Keep only utterances with duration between 1 second and max seconds
# Here, we set max as 20.0.
# If you want to use a big max-duration, you can set it as 17.0.
return 1.0 <= c.duration <= 20.0
num_in_total = len(train_cuts)