From 5e88f80b50eb1b5068bd5dc1728861aeb4abf87c Mon Sep 17 00:00:00 2001 From: Xinyuan Li Date: Tue, 23 Jan 2024 20:14:39 -0500 Subject: [PATCH] Remove tdnn architecture from fluent speech commands recipe --- .../tdnn/asr_datamodule.py | 292 --------- egs/fluent_speech_commands/tdnn/decode.py | 315 ---------- egs/fluent_speech_commands/tdnn/export.py | 118 ---- .../tdnn/export_onnx.py | 158 ----- .../tdnn/jit_pretrained.py | 199 ------ egs/fluent_speech_commands/tdnn/model.py | 81 --- .../tdnn/onnx_pretrained.py | 242 -------- egs/fluent_speech_commands/tdnn/pretrained.py | 221 ------- egs/fluent_speech_commands/tdnn/train.py | 581 ------------------ .../transducer/asr_datamodule.py | 293 ++++++++- 10 files changed, 292 insertions(+), 2208 deletions(-) delete mode 100755 egs/fluent_speech_commands/tdnn/asr_datamodule.py delete mode 100755 egs/fluent_speech_commands/tdnn/decode.py delete mode 100755 egs/fluent_speech_commands/tdnn/export.py delete mode 100755 egs/fluent_speech_commands/tdnn/export_onnx.py delete mode 100755 egs/fluent_speech_commands/tdnn/jit_pretrained.py delete mode 100755 egs/fluent_speech_commands/tdnn/model.py delete mode 100755 egs/fluent_speech_commands/tdnn/onnx_pretrained.py delete mode 100755 egs/fluent_speech_commands/tdnn/pretrained.py delete mode 100755 egs/fluent_speech_commands/tdnn/train.py mode change 120000 => 100755 egs/fluent_speech_commands/transducer/asr_datamodule.py diff --git a/egs/fluent_speech_commands/tdnn/asr_datamodule.py b/egs/fluent_speech_commands/tdnn/asr_datamodule.py deleted file mode 100755 index bffd52e4c..000000000 --- a/egs/fluent_speech_commands/tdnn/asr_datamodule.py +++ /dev/null @@ -1,292 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import argparse -import logging -from functools import lru_cache -from pathlib import Path -from typing import List - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.dataset.datamodule import DataModule -from icefall.utils import str2bool - - -class SluDataModule(DataModule): - """ - DataModule for k2 ASR experiments. - It assumes there is always one train dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). 
- - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - """ - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - super().add_arguments(parser) - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--feature-dir", - type=Path, - default=Path("data/fbanks"), - help="Path to directory with train/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=30.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=False, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=10, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - def train_dataloaders(self) -> DataLoader: - logging.info("About to get train cuts") - cuts_train = self.train_cuts() - - logging.info("About to create train dataset") - transforms = [] - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. 
- transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - FbankConfig(sampling_rate=8000, num_mel_bins=23) - ), - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=True, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=True, - ) - - return train_dl - - def valid_dataloaders(self) -> DataLoader: - logging.info("About to get valid cuts") - cuts_valid = self.valid_cuts() - - logging.debug("About to create valid dataset") - valid = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create valid dataloader") - valid_dl = DataLoader( - valid, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - persistent_workers=True, - ) - return valid_dl - - def test_dataloaders(self) -> DataLoader: - logging.info("About to get test cuts") - cuts_test = self.test_cuts() - - logging.debug("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts_test, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - persistent_workers=True, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest_lazy( - self.args.feature_dir / "slu_cuts_train.jsonl.gz" - ) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> List[CutSet]: - logging.info("About to get valid cuts") - cuts_valid = load_manifest_lazy( - self.args.feature_dir / "slu_cuts_valid.jsonl.gz" - ) - return cuts_valid - - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - 
logging.info("About to get test cuts") - cuts_test = load_manifest_lazy( - self.args.feature_dir / "slu_cuts_test.jsonl.gz" - ) - return cuts_test diff --git a/egs/fluent_speech_commands/tdnn/decode.py b/egs/fluent_speech_commands/tdnn/decode.py deleted file mode 100755 index a213c886a..000000000 --- a/egs/fluent_speech_commands/tdnn/decode.py +++ /dev/null @@ -1,315 +0,0 @@ -#!/usr/bin/env python3 - - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import SluDataModule -from model import Tdnn - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import get_lattice, one_best_decoding -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=13, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=1, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved. - pretrained.pt contains a dict {"model": model.state_dict()}, - which can be loaded by `icefall.checkpoint.load_checkpoint()`. - """, - ) - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "exp_dir": Path("tdnn/exp/"), - "lang_dir": Path("data/lm/frames"), - "feature_dim": 23, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - batch: dict, - word_table: k2.SymbolTable, -) -> List[List[int]]: - """Decode one batch and return the result in a list-of-list. - Each sub list contains the word IDs for an utterance in the batch. - - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding. - - params.method is "nbest", it uses nbest decoding. - - model: - The neural model. - HLG: - The decoding graph. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - (https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py) - word_table: - It is the word symbol table. - Returns: - Return the decoding result. `len(ans)` == batch size. 
- """ - device = HLG.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - nnet_output = model(feature) - # nnet_output is (N, T, C) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - hyps = get_texts(best_path) - hyps = [[word_table[i] for i in ids] for ids in hyps] - return hyps - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - word_table: k2.SymbolTable, -) -> List[Tuple[str, List[str], List[str]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. - word_table: - It is word symbol table. - Returns: - Return a tuple contains two elements (ref_text, hyp_text): - The first is the reference transcript, and the second is the - predicted result. - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = [] - for batch_idx, batch in enumerate(dl): - # texts = batch["supervisions"]["custom"]["frames"] - - texts = [' '.join(a.supervisions[0].custom["frames"]) for a in batch["supervisions"]["cut"]] - texts = [' ' + a.replace('change language', 'change_language') + ' ' for a in texts] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps = decode_one_batch( - params=params, - model=model, - HLG=HLG, - batch=batch, - word_table=word_table, - ) - - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results.extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - exp_dir: Path, - test_set_name: str, - results: List[Tuple[str, List[str], List[str]]], -) -> None: - """Save results to `exp_dir`. - Args: - exp_dir: - The output directory. This function create the following files inside - this directory: - - - recogs-{test_set_name}.text - - - errs-{test_set_name}.txt - - It contains the detailed WER. - test_set_name: - The name of the test set, which will be part of the result filename. - results: - A list of tuples, each of which contains (ref_words, hyp_words). - Returns: - Return None. - """ - recog_path = exp_dir / f"recogs-{test_set_name}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. 
- errs_filename = exp_dir / f"errs-{test_set_name}.txt" - with open(errs_filename, "w") as f: - write_error_stats(f, f"{test_set_name}", results) - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - -@torch.no_grad() -def main(): - parser = get_parser() - SluDataModule.add_arguments(parser) - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - params["env_info"] = get_env_info() - - setup_logger(f"{params.exp_dir}/log/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) - HLG = HLG.to(device) - assert HLG.requires_grad is False - - model = Tdnn( - num_features=params.feature_dim, - num_classes=max_token_id + 1, # +1 for the blank symbol - ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - if params.export: - logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") - return - - model.to(device) - model.eval() - - # we need cut ids to display recognition results. - args.return_cuts = True - slu = SluDataModule(args) - test_dl = slu.test_dataloaders() - results = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - word_table=lexicon.word_table, - ) - - save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/fluent_speech_commands/tdnn/export.py b/egs/fluent_speech_commands/tdnn/export.py deleted file mode 100755 index c40cf8cd1..000000000 --- a/egs/fluent_speech_commands/tdnn/export.py +++ /dev/null @@ -1,118 +0,0 @@ -#!/usr/bin/env python3 - -""" -This file is for exporting trained models to a checkpoint -or to a torchscript model. - -(1) Generate the checkpoint tdnn/exp/pretrained.pt - -./tdnn/export.py \ - --epoch 14 \ - --avg 2 - -See ./tdnn/pretrained.py for how to use the generated file. - -(2) Generate torchscript model tdnn/exp/cpu_jit.pt - -./tdnn/export.py \ - --epoch 14 \ - --avg 2 \ - --jit 1 - -See ./tdnn/jit_pretrained.py for how to use the generated file. -""" - -import argparse -import logging - -import torch -from model import Tdnn -from train import get_params - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=14, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=2, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. 
", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - - params = get_params() - params.update(vars(args)) - - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - - model = Tdnn( - num_features=params.feature_dim, - num_classes=max_token_id + 1, # +1 for the blank symbol - ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - model.to("cpu") - model.eval() - - if params.jit: - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/fluent_speech_commands/tdnn/export_onnx.py b/egs/fluent_speech_commands/tdnn/export_onnx.py deleted file mode 100755 index 2436ca81b..000000000 --- a/egs/fluent_speech_commands/tdnn/export_onnx.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/bin/env python3 - -""" -This file is for exporting trained models to onnx. - -Usage: - - ./tdnn/export_onnx.py \ - --epoch 14 \ - --avg 2 - -The above command generates the following two files: - - ./exp/model-epoch-14-avg-2.onnx - - ./exp/model-epoch-14-avg-2.int8.onnx - -See ./tdnn/onnx_pretrained.py for how to use them. -""" - -import argparse -import logging -from typing import Dict - -import onnx -import torch -from model import Tdnn -from onnxruntime.quantization import QuantType, quantize_dynamic -from train import get_params - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=14, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=2, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. 
- """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = str(value) - - onnx.save(model, filename) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - - params = get_params() - params.update(vars(args)) - - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - - model = Tdnn( - num_features=params.feature_dim, - num_classes=max_token_id + 1, # +1 for the blank symbol - ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - model.to("cpu") - model.eval() - - N = 1 - T = 100 - C = params.feature_dim - x = torch.rand(N, T, C) - - opset_version = 13 - onnx_filename = f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.onnx" - torch.onnx.export( - model, - x, - onnx_filename, - verbose=False, - opset_version=opset_version, - input_names=["x"], - output_names=["log_prob"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "log_prob": {0: "N", 1: "T"}, - }, - ) - - logging.info(f"Saved to {onnx_filename}") - meta_data = { - "model_type": "tdnn", - "version": "1", - "model_author": "k2-fsa", - "comment": "non-streaming tdnn for the yesno recipe", - "vocab_size": max_token_id + 1, - } - - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=onnx_filename, meta_data=meta_data) - - logging.info("Generate int8 quantization models") - onnx_filename_int8 = ( - f"{params.exp_dir}/model-epoch-{params.epoch}-avg-{params.avg}.int8.onnx" - ) - - quantize_dynamic( - model_input=onnx_filename, - model_output=onnx_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - logging.info(f"Saved to {onnx_filename_int8}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/fluent_speech_commands/tdnn/jit_pretrained.py b/egs/fluent_speech_commands/tdnn/jit_pretrained.py deleted file mode 100755 index 84390fca5..000000000 --- a/egs/fluent_speech_commands/tdnn/jit_pretrained.py +++ /dev/null @@ -1,199 +0,0 @@ -#!/usr/bin/env python3 - -""" -This file shows how to use a torchscript model for decoding. - -Usage: - - ./tdnn/jit_pretrained.py \ - --nn-model ./tdnn/exp/cpu_jit.pt \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - -Note that to generate ./tdnn/exp/cpu_jit.pt, -you can use ./export.py --jit 1 -""" - -import argparse -import logging -from typing import List -import math - - -import k2 -import kaldifeat -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import get_lattice, one_best_decoding -from icefall.utils import AttributeDict, get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model", - type=str, - required=True, - help="""Path to the torchscript model. 
- You can use ./tdnn/export.py --jit 1 - to obtain it - """, - ) - - parser.add_argument( - "--words-file", - type=str, - required=True, - help="Path to words.txt", - ) - - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. ", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 23, - "num_classes": 4, # [, N, SIL, Y] - "sample_rate": 8000, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - if sample_rate != expected_sample_rate: - wave = torchaudio.functional.resample( - wave, - orig_freq=sample_rate, - new_freq=expected_sample_rate, - ) - - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Loading torchscript model") - model = torch.jit.load(args.nn_model) - model.eval() - model.to(device) - - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - # Note: We don't use key padding mask for attention during decoding - nnet_output = model(features) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - - s = "\n" - for filename, hyp in 
zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/fluent_speech_commands/tdnn/model.py b/egs/fluent_speech_commands/tdnn/model.py deleted file mode 100755 index 52cff37e0..000000000 --- a/egs/fluent_speech_commands/tdnn/model.py +++ /dev/null @@ -1,81 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 Xiaomi Corp. (author: Fangjun Kuang) - - -import torch -import torch.nn as nn - - -class Tdnn(nn.Module): - def __init__(self, num_features: int, num_classes: int): - """ - Args: - num_features: - Model input dimension. - num_classes: - Model output dimension - """ - super().__init__() - - self.tdnn = nn.Sequential( - nn.Conv1d( - in_channels=num_features, - out_channels=32, - kernel_size=3, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=32, affine=False), - nn.Conv1d( - in_channels=32, - out_channels=32, - kernel_size=5, - dilation=2, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=32, affine=False), - nn.Conv1d( - in_channels=32, - out_channels=32, - kernel_size=5, - dilation=4, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=32, affine=False), - ) - self.output_linear = nn.Linear(in_features=32, out_features=num_classes) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - The input tensor with shape [N, T, C] - - Returns: - The output tensor has shape [N, T, C] - """ - x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] - x = self.tdnn(x) - x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] - x = self.output_linear(x) - x = nn.functional.log_softmax(x, dim=-1) - return x - - -def test_tdnn(): - num_features = 23 - num_classes = 4 - model = Tdnn(num_features=num_features, num_classes=num_classes) - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - N = 2 - T = 100 - C = num_features - x = torch.randn(N, T, C) - y = model(x) - print(x.shape) - print(y.shape) - - -if __name__ == "__main__": - test_tdnn() diff --git a/egs/fluent_speech_commands/tdnn/onnx_pretrained.py b/egs/fluent_speech_commands/tdnn/onnx_pretrained.py deleted file mode 100755 index b23a2a381..000000000 --- a/egs/fluent_speech_commands/tdnn/onnx_pretrained.py +++ /dev/null @@ -1,242 +0,0 @@ -#!/usr/bin/env python3 - -""" -This file shows how to use an ONNX model for decoding with onnxruntime. 
- -Usage: - -(1) Use a not quantized ONNX model, i.e., a float32 model - - ./tdnn/onnx_pretrained.py \ - --nn-model ./tdnn/exp/model-epoch-14-avg-2.onnx \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - -(2) Use a quantized ONNX model, i.e., an int8 model - - ./tdnn/onnx_pretrained.py \ - --nn-model ./tdnn/exp/model-epoch-14-avg-2.int8.onnx \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - -Note that to generate ./tdnn/exp/model-epoch-14-avg-2.onnx, -and ./tdnn/exp/model-epoch-14-avg-2.onnx, -you can use ./export_onnx.py --epoch 14 --avg 2 -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import onnxruntime as ort -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import get_lattice, one_best_decoding -from icefall.utils import AttributeDict, get_texts - - -class OnnxModel: - def __init__(self, nn_model: str): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 1 - - self.session_opts = session_opts - self.model = ort.InferenceSession( - nn_model, - sess_options=self.session_opts, - ) - - meta = self.model.get_modelmeta().custom_metadata_map - self.vocab_size = int(meta["vocab_size"]) - - def run( - self, - x: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - Returns: - Return a 3-D tensor log_prob of shape (N, T, C) - """ - out = self.model.run( - [ - self.model.get_outputs()[0].name, - ], - { - self.model.get_inputs()[0].name: x.numpy(), - }, - ) - return torch.from_numpy(out[0]) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--nn-model", - type=str, - required=True, - help="""Path to the torchscript model. - You can use ./tdnn/export.py --jit 1 - to obtain it - """, - ) - - parser.add_argument( - "--words-file", - type=str, - required=True, - help="Path to words.txt", - ) - - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. ", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. 
- """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - if sample_rate != expected_sample_rate: - wave = torchaudio.functional.resample( - wave, - orig_freq=sample_rate, - new_freq=expected_sample_rate, - ) - - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 23, - "sample_rate": 8000, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def main(): - parser = get_parser() - args = parser.parse_args() - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - logging.info(f"device: {device}") - - logging.info(f"Loading onnx model {params.nn_model}") - model = OnnxModel(params.nn_model) - - logging.info(f"Loading HLG from {args.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - # Note: We don't use key padding mask for attention during decoding - nnet_output = model.run(features) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/fluent_speech_commands/tdnn/pretrained.py b/egs/fluent_speech_commands/tdnn/pretrained.py deleted file mode 100755 index 987c49de6..000000000 --- a/egs/fluent_speech_commands/tdnn/pretrained.py +++ /dev/null @@ -1,221 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file shows how to use a checkpoint for decoding. - -Usage: - - ./tdnn/pretrained.py \ - --checkpoint ./tdnn/exp/pretrained.pt \ - --HLG ./data/lang_phone/HLG.pt \ - --words-file ./data/lang_phone/words.txt \ - download/waves_yesno/0_0_0_1_0_0_0_1.wav \ - download/waves_yesno/0_0_1_0_0_0_1_0.wav - -Note that to generate ./tdnn/exp/pretrained.pt, -you can use ./export.py -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from model import Tdnn -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import get_lattice, one_best_decoding -from icefall.utils import AttributeDict, get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint(). " - "You can use ./tdnn/export.py to obtain it.", - ) - - parser.add_argument( - "--words-file", - type=str, - required=True, - help="Path to words.txt", - ) - - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. ", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 23, - "num_classes": 4, # [, N, SIL, Y] - "sample_rate": 8000, - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. 
- """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - if sample_rate != expected_sample_rate: - wave = torchaudio.functional.resample( - wave, - orig_freq=sample_rate, - new_freq=expected_sample_rate, - ) - - # We use only the first channel - ans.append(wave[0].contiguous()) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - - model = Tdnn( - num_features=params.feature_dim, - num_classes=params.num_classes, - ) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) - model.to(device) - model.eval() - - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - # Note: We don't use key padding mask for attention during decoding - nnet_output = model(features) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/fluent_speech_commands/tdnn/train.py b/egs/fluent_speech_commands/tdnn/train.py deleted file mode 100755 index 4934d1b88..000000000 --- a/egs/fluent_speech_commands/tdnn/train.py +++ /dev/null @@ -1,581 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -from asr_datamodule import SluDataModule -from lhotse.utils import fix_random_seed -from model import Tdnn -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP 
-from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.graph_compiler import CtcTrainingGraphCompiler -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=100, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=14, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - tdnn/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - is saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" - - - lr: It specifies the initial learning rate - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - weight_decay: The weight_decay for the optimizer. - - - subsampling_factor: The subsampling factor for the model. - - - start_epoch: If it is not zero, load checkpoint `start_epoch-1` - and continue training from that checkpoint. - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. 
- - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - valid_interval: Run validation if batch_idx % valid_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - """ - params = AttributeDict( - { - "exp_dir": Path("tdnn/exp"), - "lang_dir": Path("data/lm/frames"), - "lr": 1e-3, - "feature_dim": 23, - "weight_decay": 1e-6, - "start_epoch": 0, - "num_epochs": 5, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 100, - "reset_interval": 20, - "valid_interval": 300, - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: CtcTrainingGraphCompiler, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Tdnn in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. 
- graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - with torch.set_grad_enabled(is_training): - nnet_output = model(feature) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervisions = batch["supervisions"] - - # texts = supervisions["custom"]["frames"] - - - texts = [' '.join(a.supervisions[0].custom["frames"]) for a in supervisions["cut"]] - texts = [' ' + a.replace('change language', 'change_language') + ' ' for a in texts] - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - decoding_graph = graph_compiler.compile(texts) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - ) - - loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. - """ - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: CtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. 
- world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - # summary stats. - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, - "train/valid_", - params.batch_idx_train, - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - params["env_info"] = get_env_info() - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_phone_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"device: {device}") - - graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) - - model = Tdnn( - num_features=params.feature_dim, - num_classes=max_phone_id + 1, # +1 for the blank symbol - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = optim.SGD( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, - ) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - - slu = SluDataModule(args) - train_dl = slu.train_dataloaders() - - # There are only 60 waves: 30 files are used for training - # and the remaining 30 files are used for testing. - # We use test data as validation. 
- valid_dl = slu.test_dataloaders() - - for epoch in range(params.start_epoch, params.num_epochs): - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=None, - rank=rank, - ) - - logging.info("Done!") - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - SluDataModule.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - main() diff --git a/egs/fluent_speech_commands/transducer/asr_datamodule.py b/egs/fluent_speech_commands/transducer/asr_datamodule.py deleted file mode 120000 index c9c8adb57..000000000 --- a/egs/fluent_speech_commands/transducer/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../tdnn/asr_datamodule.py \ No newline at end of file diff --git a/egs/fluent_speech_commands/transducer/asr_datamodule.py b/egs/fluent_speech_commands/transducer/asr_datamodule.py new file mode 100755 index 000000000..bffd52e4c --- /dev/null +++ b/egs/fluent_speech_commands/transducer/asr_datamodule.py @@ -0,0 +1,292 @@ +# Copyright 2021 Piotr Żelasko +# 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import List + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.dataset.datamodule import DataModule +from icefall.utils import str2bool + + +class SluDataModule(DataModule): + """ + DataModule for k2 ASR experiments. + It assumes there is always one train dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). 
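# The removed main() above follows the standard single-node DDP launch
# pattern: mp.spawn forks world_size worker processes and calls
# run(rank, *args) in each, with rank injected automatically as
# 0 .. world_size - 1 (rank 0 saves checkpoints and writes TensorBoard logs).
# Condensed sketch; `args` stands for the parsed command-line namespace.
import torch.multiprocessing as mp

def run(rank: int, world_size: int, args) -> None:
    ...  # setup_dist(rank, world_size, ...), build the model, train

def launch(args, world_size: int = 2) -> None:
    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)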
+ + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + """ + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + super().add_arguments(parser) + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--feature-dir", + type=Path, + default=Path("data/fbanks"), + help="Path to directory with train/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=30.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=False, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=10, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + def train_dataloaders(self) -> DataLoader: + logging.info("About to get train cuts") + cuts_train = self.train_cuts() + + logging.info("About to create train dataset") + transforms = [] + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. 
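# The options registered in add_arguments above are consumed the way the
# recipe's training scripts use them: register on an argparse parser, parse,
# and hand the namespace to the data module. The flag values below are
# illustrative, and the base DataModule is assumed to add no other required
# options.
import argparse

parser = argparse.ArgumentParser()
SluDataModule.add_arguments(parser)
args = parser.parse_args(
    [
        "--feature-dir", "data/fbanks",
        "--max-duration", "200",
        "--bucketing-sampler", "true",
        "--num-buckets", "30",
    ]
)
slu = SluDataModule(args)
train_dl = slu.train_dataloaders()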
+ transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + FbankConfig(sampling_rate=8000, num_mel_bins=23) + ), + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=True, + ) + + return train_dl + + def valid_dataloaders(self) -> DataLoader: + logging.info("About to get valid cuts") + cuts_valid = self.valid_cuts() + + logging.debug("About to create valid dataset") + valid = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create valid dataloader") + valid_dl = DataLoader( + valid, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + persistent_workers=True, + ) + return valid_dl + + def test_dataloaders(self) -> DataLoader: + logging.info("About to get test cuts") + cuts_test = self.test_cuts() + + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=23))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts_test, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + persistent_workers=True, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + cuts_train = load_manifest_lazy( + self.args.feature_dir / "slu_cuts_train.jsonl.gz" + ) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> List[CutSet]: + logging.info("About to get valid cuts") + cuts_valid = load_manifest_lazy( + self.args.feature_dir / "slu_cuts_valid.jsonl.gz" + ) + return cuts_valid + + + @lru_cache() + def test_cuts(self) -> List[CutSet]: + 
logging.info("About to get test cuts") + cuts_test = load_manifest_lazy( + self.args.feature_dir / "slu_cuts_test.jsonl.gz" + ) + return cuts_test
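Taken together, the dataloaders above hand batching over to the Lhotse samplers (hence `batch_size=None` on the `DataLoader`), and the `@lru_cache` cut getters expect the manifests `slu_cuts_{train,valid,test}.jsonl.gz` under `--feature-dir`. A short sketch of consuming one training batch, assuming `args` was parsed as in the `add_arguments` example earlier:

    slu = SluDataModule(args)
    train_dl = slu.train_dataloaders()

    for batch in train_dl:
        feats = batch["inputs"]       # (N, T, C) filterbank features
        sups = batch["supervisions"]  # "text", "sequence_idx", "num_frames", ...
        cuts = sups.get("cut")        # present when --return-cuts=true
        break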