diff --git a/egs/tedlium3/ASR/RESULTS.md b/egs/tedlium3/ASR/RESULTS.md
index d5b24a376..407b77fa8 100644
--- a/egs/tedlium3/ASR/RESULTS.md
+++ b/egs/tedlium3/ASR/RESULTS.md
@@ -28,7 +28,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --num-epochs 30 \
   --start-epoch 0 \
   --exp-dir transducer_stateless/exp \
-  --max-duration 200 \
+  --max-duration 200
 ```

 The tensorboard training log can be found at
@@ -64,6 +64,8 @@ avg=16
   --exp-dir transducer_stateless/exp \
   --bpe-model ./data/lang_bpe_500/bpe.model \
   --max-duration 100 \
-  --decoding-method beam_search \
+  --decoding-method modified_beam_search \
   --beam-size 4
 ```
+
+A pre-trained model and decoding logs can be found at
diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
index 3f87fabfb..c1d68a07b 100644
--- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
+++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
-# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
-#                              Mingshuang Luo)
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#           2022 Xiaomi Corp. (authors: Mingshuang Luo)
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -29,7 +29,7 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer
+from lhotse import ChunkedLilcomHdf5Writer, CutSet, Fbank, FbankConfig
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor
@@ -83,7 +83,7 @@ def compute_fbank_tedlium():
             # when an executor is specified, make more partitions
             num_jobs=num_jobs if ex is None else 80,
             executor=ex,
-            storage_type=LilcomHdf5Writer,
+            storage_type=ChunkedLilcomHdf5Writer,
         )
         cut_set.to_json(output_dir / f"cuts_{partition}.json.gz")
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
index e6a3cc783..bbd1b838c 100644
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
+++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py
@@ -1,6 +1,6 @@
 #!/usr/bin/env python3

-# Copyright 2021 Xiaomi Corporation (Author: Mingshuang Luo)
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)

 """
 Convert a transcript based on words to a list of BPE ids.
@@ -28,12 +28,6 @@ def get_args():
     parser.add_argument(
         "--texts", type=List[str], help="The input transcripts list."
     )
-    parser.add_argument(
-        "--unk-id",
-        type=int,
-        default=2,
-        help="The number id for the token '<unk>'.",
-    )
     parser.add_argument(
         "--bpe-model",
         type=str,
@@ -70,7 +64,7 @@ def convert_texts_into_ids(
             else:
                 y_ids.extend(id_segments[i])
         else:
-            y_ids = sp.encode([text], out_type=int)[0]
+            y_ids = sp.encode(text, out_type=int)

         y.append(y_ids)
     return y
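The `sp.encode` simplification above works because SentencePiece's Python `encode` accepts either a single string or a list of strings: a list input returns a list of ID lists (hence the old `[0]` indexing), while a string input returns a flat list of IDs directly. A minimal sketch of the equivalence; the model path is the one this recipe produces and is assumed to exist:

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # path assumed from this recipe

text = "HELLO WORLD"

# List input returns a list of ID lists, so the old code had to take [0].
old_style = sp.encode([text], out_type=int)[0]

# String input returns a flat list of IDs directly.
new_style = sp.encode(text, out_type=int)

assert old_style == new_style
```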
diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh
index 9a643139f..76fc4ab94 100644
--- a/egs/tedlium3/ASR/prepare.sh
+++ b/egs/tedlium3/ASR/prepare.sh
@@ -59,9 +59,8 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "Stage -1: Download LM"
   # We assume that you have installed git-lfs; if not, you can install it
   # using: `sudo apt-get install git-lfs && git-lfs install`
-  [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
+  mkdir -p $dl_dir/lm
   git clone https://huggingface.co/luomingshuang/tedlium3_lm $dl_dir/lm
-  cd $dl_dir/lm && git lfs pull

   # If you want to download the Tedlium 4-gram language models,
   # use the following commands:
diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
index d8ebead6b..7b35166a2 100644
--- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
+++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py
@@ -175,13 +175,12 @@ class TedLiumAsrDataModule:
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
         logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "cuts_musan.json.gz"
-        )
-
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "cuts_musan.json.gz"
+            )
             transforms.append(
                 CutMix(
                     cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
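The net effect of the asr_datamodule.py hunk is that the MUSAN manifest is only read when noise augmentation is actually requested, avoiding unnecessary I/O. A self-contained sketch of the resulting control flow; the helper name is illustrative, while `load_manifest` and `CutMix` are the lhotse APIs used in the hunk:

```python
from pathlib import Path

from lhotse import load_manifest
from lhotse.dataset import CutMix


def build_transforms(enable_musan: bool, manifest_dir: Path) -> list:
    """Build the cut transforms, loading the MUSAN cuts only when needed."""
    transforms = []
    if enable_musan:
        cuts_musan = load_manifest(manifest_dir / "cuts_musan.json.gz")
        transforms.append(
            CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
        )
    return transforms
```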
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
deleted file mode 120000
index 4143bb116..000000000
--- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/transducer_stateless/pretrained.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
new file mode 100644
index 000000000..77bdd0ca6
--- /dev/null
+++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py
@@ -0,0 +1,343 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#           2022 Xiaomi Corp. (authors: Mingshuang Luo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+(1) greedy search
+./transducer_stateless/pretrained.py \
+  --checkpoint ./transducer_stateless/exp/pretrained.pt \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --method greedy_search \
+  --max-sym-per-frame 1 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+(2) beam search
+./transducer_stateless/pretrained.py \
+  --checkpoint ./transducer_stateless/exp/pretrained.pt \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --method beam_search \
+  --beam-size 4 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+(3) modified beam search
+./transducer_stateless/pretrained.py \
+  --checkpoint ./transducer_stateless/exp/pretrained.pt \
+  --bpe-model ./data/lang_bpe_500/bpe.model \
+  --method modified_beam_search \
+  --beam-size 4 \
+  /path/to/foo.wav \
+  /path/to/bar.wav
+
+You can also use `./transducer_stateless/exp/epoch-xx.pt`.
+
+Note: ./transducer_stateless/exp/pretrained.pt is generated by
+./transducer_stateless/export.py
+"""
+
+
+import argparse
+import logging
+import math
+from typing import List
+
+import kaldifeat
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+import torchaudio
+from beam_search import beam_search, greedy_search, modified_beam_search
+from conformer import Conformer
+from decoder import Decoder
+from joiner import Joiner
+from model import Transducer
+from torch.nn.utils.rnn import pad_sequence
+
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--checkpoint",
+        type=str,
+        required=True,
+        help="Path to the checkpoint. "
+        "The checkpoint is assumed to be saved by "
+        "icefall.checkpoint.save_checkpoint().",
+    )
+
+    parser.add_argument(
+        "--bpe-model",
+        type=str,
+        help="""Path to bpe.model.""",
+    )
+
+    parser.add_argument(
+        "--method",
+        type=str,
+        default="greedy_search",
+        help="""Possible values are:
+          - greedy_search
+          - beam_search
+          - modified_beam_search
+        """,
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=4,
+        help="Used only when --method is beam_search or modified_beam_search.",
+    )
+
+    parser.add_argument(
+        "--context-size",
+        type=int,
+        default=2,
+        help="The context size in the decoder. 1 means bigram; "
+        "2 means tri-gram",
+    )
+    parser.add_argument(
+        "--max-sym-per-frame",
+        type=int,
+        default=3,
+        help="""Maximum number of symbols per frame. Used only when
+        --method is greedy_search.
+ """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "sample_rate": 16000, + # parameters for conformer + "feature_dim": 80, + "encoder_out_dim": 512, + "subsampling_factor": 4, + "attention_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + "vgg_frontend": False, + "env_info": get_env_info(), + } + ) + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Conformer( + num_features=params.feature_dim, + output_dim=params.encoder_out_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.attention_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + vgg_frontend=params.vgg_frontend, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + embedding_dim=params.encoder_out_dim, + blank_id=params.blank_id, + unk_id=params.unk_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + input_dim=params.encoder_out_dim, + output_dim=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + ) + return model + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + with torch.no_grad(): + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) + elif params.method == "modified_beam_search": + hyp = modified_beam_search( + model=model, encoder_out=encoder_out_i, beam=params.beam_size + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py index 7d1fa3eda..b3dba4409 100755 --- a/egs/tedlium3/ASR/transducer_stateless/train.py +++ b/egs/tedlium3/ASR/transducer_stateless/train.py @@ -26,7 +26,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --num-epochs 30 \ --start-epoch 0 \ --exp-dir transducer_stateless/exp \ - --max-duration 200 \ + --max-duration 200 """ @@ -394,9 +394,9 @@ def compute_loss( feature = feature.to(device) supervisions = batch["supervisions"] - feature_lens = 
supervisions["num_frames"].to(device)[: feature.size(0)] + feature_lens = supervisions["num_frames"].to(device) - texts = batch["supervisions"]["text"][: feature.size(0)] + texts = batch["supervisions"]["text"] unk_id = params.unk_id y = convert_texts_into_ids(texts, unk_id, sp=sp) @@ -625,7 +625,9 @@ def run(rank, world_size, args): train_cuts = tedlium.train_cuts() def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds + # Keep only utterances with duration between 1 second and max seconds + # Here, we set max as 20.0. + # If you want to use a big max-duration, you can set it as 17.0. return 1.0 <= c.duration <= 20.0 num_in_total = len(train_cuts)