diff --git a/egs/himia/wuw/ctc_tdnn/graph.py b/egs/himia/wuw/ctc_tdnn/graph.py
new file mode 100644
index 000000000..184e01ed1
--- /dev/null
+++ b/egs/himia/wuw/ctc_tdnn/graph.py
@@ -0,0 +1,49 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (Author: Weiji Zhuang,
+#                                      Liyong Guo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+
+
+def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
+    """
+    A decoding graph that starts with blank/unknown, followed by the wakeup word.
+
+    Args:
+      wakeup_word_tokens: A sequence of token ids corresponding to the wakeup word.
+        It should not contain 0 or 1.
+        We assume 0 is for blank and 1 is for unknown.
+    """
+    assert 0 not in wakeup_word_tokens
+    assert 1 not in wakeup_word_tokens
+    assert len(wakeup_word_tokens) >= 2
+    keyword_ilabel_start = wakeup_word_tokens[0]
+    fst_graph = ""
+    for non_wake_word_token in range(keyword_ilabel_start):
+        fst_graph += f"0 0 {non_wake_word_token} 0\n"
+    cur_state = 1
+    for token_idx in range(len(wakeup_word_tokens) - 1):
+        token = wakeup_word_tokens[token_idx]
+        fst_graph += f"{cur_state - 1} {cur_state} {token} 0\n"
+        fst_graph += f"{cur_state} {cur_state} {token} 0\n"
+        cur_state += 1
+
+    token = wakeup_word_tokens[-1]
+    fst_graph += f"{cur_state - 1} {cur_state} {token} 1\n"
+    fst_graph += f"{cur_state} {cur_state} {token} 0\n"
+    fst_graph += f"{cur_state}\n"
+    return fst_graph
diff --git a/egs/himia/wuw/ctc_tdnn/inference.py b/egs/himia/wuw/ctc_tdnn/inference.py
new file mode 100755
index 000000000..eae9c5333
--- /dev/null
+++ b/egs/himia/wuw/ctc_tdnn/inference.py
@@ -0,0 +1,203 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corporation (Author: Liyong Guo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+from pathlib import Path
+
+import torch
+from lhotse.features.io import NumpyHdf5Writer
+
+from icefall.checkpoint import average_checkpoints, load_checkpoint
+from icefall.env import get_env_info
+from icefall.utils import (
+    AttributeDict,
+    setup_logger,
+)
+
+from asr_datamodule import HiMiaWuwDataModule
+from tdnn import Tdnn
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=10,
+        help="It specifies the checkpoint to use for decoding. "
+ "Note: Epoch counts from 1.", + ) + parser.add_argument( + "--avg", + type=int, + default=1, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="ctc_tdnn/exp", + help="The experiment dir", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "env_info": get_env_info(), + "feature_dim": 80, + "number_class": 9, + } + ) + return params + + +def inference_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: torch.nn.Module, + test_set: str, +): + """Compute and save model output of each utterance. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + test_set: + Name of test set. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + writer = NumpyHdf5Writer(f"{params.out_dir}/{test_set}") + for batch_idx, batch in enumerate(dl): + device = params.device + feature = batch["inputs"] + assert feature.ndim == 3 + supervisions = batch["supervisions"] + start_frames = supervisions["start_frame"] + end_frames = start_frames + supervisions["num_frames"] + + feature = feature.to(device) + # model_output is log_softmax(logit) with shape [N, T, C] + model_output = model(feature) + + for i in range(feature.size(0)): + assert start_frames[i] == 0 + cut = batch["supervisions"]["cut"][i] + cur_target = model_output[i][start_frames[i] : end_frames[i]] + writer.store_array(key=cut.id, value=cur_target.cpu().numpy()) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + +@torch.no_grad() +def main(): + parser = get_parser() + HiMiaWuwDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + out_dir = f"{params.exp_dir}/post/epoch_{params.epoch}-avg_{params.avg}/" + Path(out_dir).mkdir(parents=True, exist_ok=True) + params.out_dir = out_dir + setup_logger(f"{out_dir}/log-decode") + logging.info("Decoding started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = Tdnn(params.feature_dim, params.number_class) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=True) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=True + ) + + model.to(device) + model.eval() + params.device = device + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + himia = HiMiaWuwDataModule(args) + + aishell_test_cuts = himia.aishell_test_cuts() + test_cuts = himia.test_cuts() + cw_test_cuts = himia.cw_test_cuts() + + aishell_test_dl = himia.test_dataloaders(aishell_test_cuts) + test_dl = himia.test_dataloaders(test_cuts) + cw_test_dl = himia.test_dataloaders(cw_test_cuts) + + test_sets = ["aishell_test", "test", "cw_test"] + 
    test_dls = [aishell_test_dl, test_dl, cw_test_dl]
+
+    for test_set, test_dl in zip(test_sets, test_dls):
+        inference_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            test_set=test_set,
+        )
+
+    logging.info("Done!")
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+    main()
diff --git a/egs/himia/wuw/ctc_tdnn/tokenizer.py b/egs/himia/wuw/ctc_tdnn/tokenizer.py
new file mode 100644
index 000000000..bb988da6d
--- /dev/null
+++ b/egs/himia/wuw/ctc_tdnn/tokenizer.py
@@ -0,0 +1,94 @@
+# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import itertools
+import torch
+
+from typing import List, Tuple
+
+
+class WakeupWordTokenizer(object):
+    def __init__(
+        self,
+        wakeup_word: str = "",
+        wakeup_word_tokens: List[int] = None,
+    ) -> None:
+        """
+        Args:
+          wakeup_word: The text of positive samples.
+            A sample is treated as a negative sample unless its text
+            is exactly the same as wakeup_word.
+          wakeup_word_tokens: A list of int representing the token ids of
+            wakeup_word.
+            For example: the pronunciation of "你好米雅" is
+            "n i h ao m i y a".
+            Suppose we are using the following lexicon:
+              blk 0
+              unk 1
+              n 2
+              i 3
+              h 4
+              ao 5
+              m 6
+              y 7
+              a 8
+            Then wakeup_word_tokens for "你好米雅" is:
+             n  i  h  ao m  i  y  a
+            [2, 3, 4, 5, 6, 3, 7, 8]
+        """
+        super().__init__()
+        assert wakeup_word is not None
+        assert wakeup_word_tokens is not None
+        assert (
+            0 not in wakeup_word_tokens
+        ), f"0 is kept for blank. Please remove 0 from {wakeup_word_tokens}"
+        assert 1 not in wakeup_word_tokens, (
+            f"1 is kept for unknown and negative samples. "
+            f"Please remove 1 from {wakeup_word_tokens}"
+        )
+        self.wakeup_word = wakeup_word
+        self.wakeup_word_tokens = wakeup_word_tokens
+        self.positive_number_tokens = len(wakeup_word_tokens)
+        self.negative_word_tokens = [1]
+        self.negative_number_tokens = 1
+
+    def texts_to_token_ids(
+        self, texts: List[str]
+    ) -> Tuple[torch.Tensor, torch.Tensor, int]:
+        """Convert a list of texts to a flattened sequence of token ids.
+
+        Args:
+          texts:
+            It is a list of strings.
+        Returns:
+          Return a tuple (target, target_lengths, number_positive_samples).
+          If an element of texts equals `wakeup_word`, the token ids of the
+          wakeup word are appended to `target`; otherwise the single unknown
+          token id 1 is appended. `target_lengths` gives the number of tokens
+          of each utterance.
+
+          Number of positive samples is also returned to track its proportion.
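+
+        Example (illustrative only; it assumes the default lexicon shown in
+        the class docstring, where "你好米雅" maps to [2, 3, 4, 5, 6, 3, 7, 8],
+        and an arbitrary negative text):
+
+          >>> tokenizer = WakeupWordTokenizer(
+          ...     wakeup_word="你好米雅",
+          ...     wakeup_word_tokens=[2, 3, 4, 5, 6, 3, 7, 8],
+          ... )
+          >>> target, target_lengths, num_pos = tokenizer.texts_to_token_ids(
+          ...     ["你好米雅", "其他文本"]
+          ... )
+          >>> target
+          tensor([2, 3, 4, 5, 6, 3, 7, 8, 1])
+          >>> target_lengths
+          tensor([8, 1])
+          >>> num_pos
+          1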
+ """ + batch_token_ids = [] + target_lengths = [] + number_positive_samples = 0 + for utt_text in texts: + if utt_text == self.wakeup_word: + batch_token_ids.append(self.wakeup_word_tokens) + target_lengths.append(self.positive_number_tokens) + number_positive_samples += 1 + else: + batch_token_ids.append(self.negative_word_tokens) + target_lengths.append(self.negative_number_tokens) + + target = torch.tensor(list(itertools.chain.from_iterable(batch_token_ids))) + target_lengths = torch.tensor(target_lengths) + return target, target_lengths, number_positive_samples diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py new file mode 100755 index 000000000..95ad6c324 --- /dev/null +++ b/egs/himia/wuw/ctc_tdnn/train.py @@ -0,0 +1,678 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + export CUDA_VISIBLE_DEVICES="0,1,2,3" + ./ctc_tdnn/train.py \ + --exp-dir ./tdnn/exp \ + --world-size 4 \ + --max-duration 200 \ + --num-epochs 20 +""" + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import HiMiaWuwDataModule +from tdnn import Tdnn + +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter + +from tokenizer import WakeupWordTokenizer +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + ctc_tdnn/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="ctc_tdnn/exp", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=0.001, + help="The lr_factor for optimizer", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - number_class: Numer of classes. Each token will have a token id + from [0, num_class). + In this recipe, 0 is usually kept for blank, + and 1 is usually kept for negative words. + - wakeup_word: Text of wakeup word, i.e. positive samples. + - wakeup_word_tokens: A sequence of token ids corresponding wakeup_word. + - weight_decay: The weight_decay for the optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 5, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for model + "feature_dim": 80, + "number_class": 9, + # parameters for tokenizer + "wakeup_word": "你好米雅", + "wakeup_word_tokens": [2, 3, 4, 5, 6, 3, 7, 8], + # parameters for Optimizer + "weight_decay": 1e-6, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. 
+ """ + if params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + tokenizer: WakeupWordTokenizer, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + tokenizer: + For positive samples, map their texts to corresponding token index sequence. + While for negative samples, map their texts to unknown no matter what they are. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ """ + device = model.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + N, T, C = feature.shape + feature = feature.to(device) + + supervisions = batch["supervisions"] + texts = supervisions["text"] + with torch.set_grad_enabled(is_training): + # model_output is log_softmax(logit) with shape [N, T, C] + model_output = model(feature) + + assert torch.all(supervisions["start_frame"] == 0) + num_frames = supervisions["num_frames"].to(device) + + target, target_lengths, number_positive_samples = tokenizer.texts_to_token_ids( + texts + ) # noqa E501 + target = target.to(device) + target_lengths = target_lengths.to(device) + ctc_loss = nn.CTCLoss(reduction="sum") + # [N, T, C] --> [T, N, C] + model_output = model_output.transpose(0, 1) + loss = ctc_loss(model_output, target, num_frames, target_lengths) + loss /= num_frames.sum() + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = num_frames.sum().item() + + info["loss"] = loss.detach().cpu().item() * info["frames"] + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = supervisions["num_frames"].sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ) + + info["number_positive_cuts_ratio"] = (number_positive_samples / N) * info["frames"] + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + tokenizer: WakeupWordTokenizer, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + tokenizer=tokenizer, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + tokenizer: WakeupWordTokenizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + tokenizer: + For positive samples, map their texts to corresponding token index sequence. + While for negative samples, map their texts to unknown no matter what they are. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. 
+ """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + tokenizer=tokenizer, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + tokenizer = WakeupWordTokenizer( + wakeup_word=params.wakeup_word, + wakeup_word_tokens=params.wakeup_word_tokens, + ) + + logging.info("About to create model") + + model = Tdnn(params.feature_dim, params.number_class) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + model.device = device + + optimizer = torch.optim.Adam( + model.parameters(), + lr=params.lr_factor, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + himia = HiMiaWuwDataModule(args) + + train_cuts = himia.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 0.5 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + train_dl = himia.train_dataloaders(train_cuts) + + valid_cuts = himia.dev_cuts() + valid_dl = himia.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + tokenizer=tokenizer, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + fix_random_seed(params.seed + epoch) + train_dl.sampler.set_epoch(epoch) + + # TODO: Support lr scheduler + cur_lr = params.lr_factor + if tb_writer is not None: + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + tokenizer=tokenizer, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + tokenizer: WakeupWordTokenizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + tokenizer=tokenizer, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + HiMiaWuwDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/himia/wuw/prepare.sh b/egs/himia/wuw/prepare.sh index bb4f0f36c..a47a20682 100755 --- a/egs/himia/wuw/prepare.sh +++ b/egs/himia/wuw/prepare.sh @@ -2,7 +2,7 @@ set -eou pipefail -stage=6 +stage=0 stop_stage=6 # HI_MIA and aishell dataset are used in this experiment. diff --git a/egs/himia/wuw/run_ctc_tdnn.sh b/egs/himia/wuw/run_ctc_tdnn.sh new file mode 100644 index 000000000..6556eab93 --- /dev/null +++ b/egs/himia/wuw/run_ctc_tdnn.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash + +set -eou pipefail + +# You need to execute ./prepare.sh to prepare datasets. 
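+#
+# This script contains three stages:
+#   stage 0: model training
+#   stage 1: dump posteriors (model outputs) of the test sets
+#   stage 2: FST-based decoding and AUC computation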
+stage=1
+stop_stage=2
+
+epoch=10
+avg=1
+exp_dir=./ctc_tdnn/exp/
+epoch_avg=epoch_${epoch}-avg_${avg}
+post_dir=${exp_dir}/post/${epoch_avg}
+
+. shared/parse_options.sh || exit 1
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Model training"
+  python ./ctc_tdnn/train.py \
+    --num-epochs $epoch
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Get posterior of test sets"
+  python ctc_tdnn/inference.py \
+    --avg $avg \
+    --epoch $epoch \
+    --exp-dir ${exp_dir}
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Decode and compute area under curve (AUC)"
+  for test_set in test aishell_test cw_test; do
+    python ctc_tdnn/decode.py \
+      --decoding-graph ./data/LG.int \
+      --post-h5 ${post_dir}/${test_set}.h5 \
+      --score-file ${post_dir}/fst_${test_set}_pos_h5.txt
+  done
+  python ./local/auc.py \
+    --legend himia_cw \
+    --positive-score-file ${post_dir}/fst_test_pos_h5.txt \
+    --negative-score-file ${post_dir}/fst_cw_test_pos_h5.txt
+
+  python ./local/auc.py \
+    --legend himia_aishell \
+    --positive-score-file ${post_dir}/fst_test_pos_h5.txt \
+    --negative-score-file ${post_dir}/fst_aishell_test_pos_h5.txt
+fi