do some changes

luomingshuang 2022-03-03 16:28:51 +08:00
parent 194a4e6864
commit 35812fd6de
7 changed files with 363 additions and 25 deletions

View File

@@ -28,7 +28,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --num-epochs 30 \
   --start-epoch 0 \
   --exp-dir transducer_stateless/exp \
-  --max-duration 200 \
+  --max-duration 200
 ```
 The tensorboard training log can be found at
@@ -64,6 +64,8 @@ avg=16
   --exp-dir transducer_stateless/exp \
   --bpe-model ./data/lang_bpe_500/bpe.model \
   --max-duration 100 \
-  --decoding-method beam_search \
+  --decoding-method modified_beam_search \
   --beam-size 4
 ```
+
+A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_tedlium3_transducer_stateless>

View File

@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
-#                              Mingshuang Luo)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#           2022 Xiaomi Corp. (authors: Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -29,7 +29,7 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
+from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor
@@ -83,7 +83,7 @@ def compute_fbank_tedlium():
                 # when an executor is specified, make more partitions
                 num_jobs=num_jobs if ex is None else 80,
                 executor=ex,
-                storage_type=LilcomHdf5Writer,
+                storage_type=ChunkedLilcomHdf5Writer,
             )
             cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
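For context, here is a minimal sketch of how the chunked writer plugs into lhotse's feature-extraction API. The manifest path, output path, and `num_mel_bins` below are illustrative assumptions, not values taken from this commit:

```python
from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig

# Hypothetical input manifest; the recipe builds its own CutSet per partition.
cuts = CutSet.from_json("data/manifests/cuts_dev.json.gz")
cuts = cuts.compute_and_store_features(
    extractor=Fbank(FbankConfig(num_mel_bins=80)),
    storage_path="data/fbank/feats_dev",
    num_jobs=4,
    # Chunked storage lets a reader fetch one cut's features without
    # decompressing a whole utterance's archive entry.
    storage_type=ChunkedLilcomHdf5Writer,
)
```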

View File

@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)

 """
 Convert a transcript based on words to a list of BPE ids.
@@ -28,12 +28,6 @@ def get_args():
     parser.add_argument(
         "--texts", type=List[str], help="The input transcripts list."
     )
-    parser.add_argument(
-        "--unk-id",
-        type=int,
-        default=2,
-        help="The number id for the token '<unk>'.",
-    )
     parser.add_argument(
         "--bpe-model",
         type=str,
@ -70,7 +64,7 @@ def convert_texts_into_ids(
else: else:
y_ids.extend(id_segments[i]) y_ids.extend(id_segments[i])
else: else:
y_ids = sp.encode([text], out_type=int)[0] y_ids = sp.encode(text, out_type=int)
y.append(y_ids) y.append(y_ids)
return y return y
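The change drops a needless wrap/unwrap: sentencepiece encodes a plain string directly to a list of ids, whereas a list of strings yields a list of lists. A quick sketch, assuming a trained BPE model at the recipe's usual path:

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # path assumed from the recipe

text = "HELLO WORLD"
ids_old = sp.encode([text], out_type=int)[0]  # 1-element batch, then unwrap
ids_new = sp.encode(text, out_type=int)       # encode the string directly
assert ids_old == ids_new
```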

View File

@@ -59,9 +59,8 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "Stage -1: Download LM"
   # We assume that you have installed git-lfs; if not, you can install it
   # using: `sudo apt-get install git-lfs && git-lfs install`
-  [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
+  mkdir -p $dl_dir/lm
   git clone https://huggingface.co/luomingshuang/tedlium3_lm $dl_dir/lm
-  cd $dl_dir/lm && git lfs pull

   # If you want to download the Tedlium 4-gram language models,
   # use the following commands:

View File

@@ -175,13 +175,12 @@ class TedLiumAsrDataModule:
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "cuts_musan.json.gz"
-        )

         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "cuts_musan.json.gz"
+            )
             transforms.append(
                 CutMix(
                     cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
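Moving the `load_manifest` call inside the `if` means `cuts_musan.json.gz` is only read when `--enable-musan` is on. A self-contained sketch of the resulting control flow; the flag and manifest directory are stand-ins for the datamodule's arguments:

```python
from pathlib import Path

from lhotse import load_manifest
from lhotse.dataset import CutMix, K2SpeechRecognitionDataset

enable_musan = True                 # stand-in for self.args.enable_musan
manifest_dir = Path("data/fbank")   # stand-in for self.args.manifest_dir

transforms = []
if enable_musan:
    # The MUSAN manifest is only touched when augmentation is enabled.
    cuts_musan = load_manifest(manifest_dir / "cuts_musan.json.gz")
    transforms.append(
        CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
    )

# The transforms list is then handed to the training dataset:
dataset = K2SpeechRecognitionDataset(cut_transforms=transforms)
```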

View File

@@ -1 +0,0 @@
-../../../librispeech/ASR/transducer_stateless/pretrained.py

View File

@@ -0,0 +1,343 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#           2022 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
/path/to/bar.wav \
(2) beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
(3) modified beam search
./transducer_stateless/pretrained.py \
--checkpoint ./transducer_stateless/exp/pretrained.pt \
--bpe-model ./data/lang_bpe_500/bpe.model \
--method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav \
You can also use `./transducer_stateless/exp/epoch-xx.pt`.
Note: ./transducer_stateless/exp/pretrained.pt is generated by
./transducer_stateless/export.py
"""
import argparse
import logging
import math
from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torch.nn as nn
import torchaudio
from beam_search import beam_search, greedy_search, modified_beam_search
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from model import Transducer
from torch.nn.utils.rnn import pad_sequence
from icefall.env import get_env_info
from icefall.utils import AttributeDict

def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        help="""Path to bpe.model.""",
    )

    parser.add_argument(
        "--method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
          - greedy_search
          - beam_search
          - modified_beam_search
        """,
    )

    parser.add_argument(
        "sound_files",
        type=str,
        nargs="+",
        help="The input sound file(s) to transcribe. "
        "Supported formats are those supported by torchaudio.load(). "
        "For example, wav and flac are supported. "
        "The sample rate has to be 16kHz.",
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="Used only when --method is beam_search or modified_beam_search.",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means trigram",
    )

    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=3,
        help="""Maximum number of symbols per frame. Used only when
        --method is greedy_search.
        """,
    )

    return parser

def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "sample_rate": 16000,
            # parameters for conformer
            "feature_dim": 80,
            "encoder_out_dim": 512,
            "subsampling_factor": 4,
            "attention_dim": 512,
            "nhead": 8,
            "dim_feedforward": 2048,
            "num_encoder_layers": 12,
            "vgg_frontend": False,
            "env_info": get_env_info(),
        }
    )
    return params

def get_encoder_model(params: AttributeDict) -> nn.Module:
    encoder = Conformer(
        num_features=params.feature_dim,
        output_dim=params.encoder_out_dim,
        subsampling_factor=params.subsampling_factor,
        d_model=params.attention_dim,
        nhead=params.nhead,
        dim_feedforward=params.dim_feedforward,
        num_encoder_layers=params.num_encoder_layers,
        vgg_frontend=params.vgg_frontend,
    )
    return encoder


def get_decoder_model(params: AttributeDict) -> nn.Module:
    decoder = Decoder(
        vocab_size=params.vocab_size,
        embedding_dim=params.encoder_out_dim,
        blank_id=params.blank_id,
        unk_id=params.unk_id,
        context_size=params.context_size,
    )
    return decoder


def get_joiner_model(params: AttributeDict) -> nn.Module:
    joiner = Joiner(
        input_dim=params.encoder_out_dim,
        output_dim=params.vocab_size,
    )
    return joiner


def get_transducer_model(params: AttributeDict) -> nn.Module:
    encoder = get_encoder_model(params)
    decoder = get_decoder_model(params)
    joiner = get_joiner_model(params)

    model = Transducer(
        encoder=encoder,
        decoder=decoder,
        joiner=joiner,
    )
    return model

def read_sound_files(
    filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
    """Read a list of sound files into a list of 1-D float32 torch tensors.

    Args:
      filenames:
        A list of sound filenames.
      expected_sample_rate:
        The expected sample rate of the sound files.
    Returns:
      Return a list of 1-D float32 torch tensors.
    """
    ans = []
    for f in filenames:
        wave, sample_rate = torchaudio.load(f)
        assert sample_rate == expected_sample_rate, (
            f"expected sample rate: {expected_sample_rate}. "
            f"Given: {sample_rate}"
        )
        # We use only the first channel
        ans.append(wave[0])
    return ans

@torch.no_grad()
def main():
    parser = get_parser()
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(f"{params}")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    logging.info("Creating model")
    model = get_transducer_model(params)

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()
    model.device = device

    logging.info("Constructing Fbank computer")
    opts = kaldifeat.FbankOptions()
    opts.device = device
    opts.frame_opts.dither = 0
    opts.frame_opts.snip_edges = False
    opts.frame_opts.samp_freq = params.sample_rate
    opts.mel_opts.num_bins = params.feature_dim

    fbank = kaldifeat.Fbank(opts)

    logging.info(f"Reading sound files: {params.sound_files}")
    waves = read_sound_files(
        filenames=params.sound_files, expected_sample_rate=params.sample_rate
    )
    waves = [w.to(device) for w in waves]

    logging.info("Decoding started")
    features = fbank(waves)
    feature_lengths = [f.size(0) for f in features]

    features = pad_sequence(
        features, batch_first=True, padding_value=math.log(1e-10)
    )

    feature_lengths = torch.tensor(feature_lengths, device=device)

    with torch.no_grad():
        encoder_out, encoder_out_lens = model.encoder(
            x=features, x_lens=feature_lengths
        )

    num_waves = encoder_out.size(0)
    hyps = []
    msg = f"Using {params.method}"
    if params.method == "beam_search":
        msg += f" with beam size {params.beam_size}"
    logging.info(msg)

    for i in range(num_waves):
        # fmt: off
        encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
        # fmt: on
        if params.method == "greedy_search":
            hyp = greedy_search(
                model=model,
                encoder_out=encoder_out_i,
                max_sym_per_frame=params.max_sym_per_frame,
            )
        elif params.method == "beam_search":
            hyp = beam_search(
                model=model, encoder_out=encoder_out_i, beam=params.beam_size
            )
        elif params.method == "modified_beam_search":
            hyp = modified_beam_search(
                model=model, encoder_out=encoder_out_i, beam=params.beam_size
            )
        else:
            raise ValueError(f"Unsupported method: {params.method}")

        hyps.append(sp.decode(hyp).split())

    s = "\n"
    for filename, hyp in zip(params.sound_files, hyps):
        words = " ".join(hyp)
        s += f"{filename}:\n{words}\n\n"
    logging.info(s)

    logging.info("Decoding Done")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)
    main()

View File

@@ -26,7 +26,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --num-epochs 30 \
   --start-epoch 0 \
   --exp-dir transducer_stateless/exp \
-  --max-duration 200 \
+  --max-duration 200
 """
@@ -394,9 +394,9 @@ def compute_loss(
     feature = feature.to(device)

     supervisions = batch["supervisions"]
-    feature_lens = supervisions["num_frames"].to(device)[: feature.size(0)]
+    feature_lens = supervisions["num_frames"].to(device)

-    texts = batch["supervisions"]["text"][: feature.size(0)]
+    texts = batch["supervisions"]["text"]
     unk_id = params.unk_id

     y = convert_texts_into_ids(texts, unk_id, sp=sp)
@@ -625,7 +625,9 @@ def run(rank, world_size, args):
     train_cuts = tedlium.train_cuts()

     def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
+        # Keep only utterances with duration between 1 second and max seconds.
+        # Here we set max to 20.0.
+        # If you use a larger --max-duration, consider lowering this (e.g., to 17.0).
         return 1.0 <= c.duration <= 20.0

     num_in_total = len(train_cuts)
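For reference, a predicate like this is presumably applied via `CutSet.filter` just after this hunk; a sketch along those lines, using the names visible in the hunk (the logging line is illustrative, not from the diff):

```python
train_cuts = train_cuts.filter(remove_short_and_long_utt)
num_left = len(train_cuts)  # len() works for eagerly-materialized CutSets
logging.info(f"Removed {num_in_total - num_left} short/long utterances")
```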