From 80b83a2a36d566f320f7c1ddad694d5a59c6f7f1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 14 Jan 2022 22:09:53 +0000 Subject: [PATCH 01/12] Initial commit for fisher english + swbd data prep --- egs/fisher_swbd/ASR/conformer_ctc/README.md | 75 + egs/fisher_swbd/ASR/conformer_ctc/__init__.py | 0 egs/fisher_swbd/ASR/conformer_ctc/ali.py | 399 +++++ .../ASR/conformer_ctc/asr_datamodule.py | 1 + .../ASR/conformer_ctc/conformer.py | 930 +++++++++++ egs/fisher_swbd/ASR/conformer_ctc/decode.py | 701 +++++++++ egs/fisher_swbd/ASR/conformer_ctc/export.py | 165 ++ .../ASR/conformer_ctc/label_smoothing.py | 98 ++ .../ASR/conformer_ctc/pretrained.py | 435 ++++++ .../ASR/conformer_ctc/subsampling.py | 161 ++ .../ASR/conformer_ctc/test_label_smoothing.py | 52 + .../ASR/conformer_ctc/test_subsampling.py | 48 + .../ASR/conformer_ctc/test_transformer.py | 104 ++ egs/fisher_swbd/ASR/conformer_ctc/train.py | 741 +++++++++ .../ASR/conformer_ctc/transformer.py | 953 ++++++++++++ egs/fisher_swbd/ASR/local/__init__.py | 0 egs/fisher_swbd/ASR/local/compile_hlg.py | 159 ++ .../ASR/local/compute_fbank_librispeech.py | 100 ++ .../ASR/local/compute_fbank_musan.py | 97 ++ .../convert_transcript_words_to_tokens.py | 107 ++ .../ASR/local/display_manifest_statistics.py | 215 +++ egs/fisher_swbd/ASR/local/download_lm.py | 97 ++ .../ASR/local/generate_unique_lexicon.py | 100 ++ .../normalize_and_filter_supervisions.py | 169 ++ egs/fisher_swbd/ASR/local/prepare_lang.py | 413 +++++ egs/fisher_swbd/ASR/local/prepare_lang_bpe.py | 254 +++ .../ASR/local/prepare_lang_g2pen.py | 482 ++++++ .../ASR/local/test_prepare_lang.py | 106 ++ egs/fisher_swbd/ASR/local/train_bpe_model.py | 98 ++ egs/fisher_swbd/ASR/prepare.sh | 244 +++ egs/fisher_swbd/ASR/shared | 1 + .../ASR/streaming_conformer_ctc/README.md | 90 ++ .../streaming_conformer_ctc/asr_datamodule.py | 1 + .../ASR/streaming_conformer_ctc/conformer.py | 1355 +++++++++++++++++ .../label_smoothing.py | 1 + .../streaming_decode.py | 521 +++++++ .../streaming_conformer_ctc/subsampling.py | 1 + .../ASR/streaming_conformer_ctc/train.py | 745 +++++++++ .../streaming_conformer_ctc/transformer.py | 966 ++++++++++++ egs/fisher_swbd/ASR/tdnn_lstm_ctc/README.md | 4 + egs/fisher_swbd/ASR/tdnn_lstm_ctc/__init__.py | 0 .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 385 +++++ egs/fisher_swbd/ASR/tdnn_lstm_ctc/decode.py | 505 ++++++ egs/fisher_swbd/ASR/tdnn_lstm_ctc/model.py | 103 ++ .../ASR/tdnn_lstm_ctc/pretrained.py | 277 ++++ egs/fisher_swbd/ASR/tdnn_lstm_ctc/train.py | 603 ++++++++ 46 files changed, 13062 insertions(+) create mode 100644 egs/fisher_swbd/ASR/conformer_ctc/README.md create mode 100644 egs/fisher_swbd/ASR/conformer_ctc/__init__.py create mode 100755 egs/fisher_swbd/ASR/conformer_ctc/ali.py create mode 120000 egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py create mode 100644 egs/fisher_swbd/ASR/conformer_ctc/conformer.py create mode 100755 egs/fisher_swbd/ASR/conformer_ctc/decode.py create mode 100755 egs/fisher_swbd/ASR/conformer_ctc/export.py create mode 100644 egs/fisher_swbd/ASR/conformer_ctc/label_smoothing.py create mode 100755 egs/fisher_swbd/ASR/conformer_ctc/pretrained.py create mode 100644 egs/fisher_swbd/ASR/conformer_ctc/subsampling.py create mode 100755 egs/fisher_swbd/ASR/conformer_ctc/test_label_smoothing.py create mode 100755 egs/fisher_swbd/ASR/conformer_ctc/test_subsampling.py create mode 100644 egs/fisher_swbd/ASR/conformer_ctc/test_transformer.py create mode 100755 egs/fisher_swbd/ASR/conformer_ctc/train.py create mode 100644 egs/fisher_swbd/ASR/conformer_ctc/transformer.py create mode 100644 egs/fisher_swbd/ASR/local/__init__.py create mode 100755 egs/fisher_swbd/ASR/local/compile_hlg.py create mode 100755 egs/fisher_swbd/ASR/local/compute_fbank_librispeech.py create mode 100755 egs/fisher_swbd/ASR/local/compute_fbank_musan.py create mode 100755 egs/fisher_swbd/ASR/local/convert_transcript_words_to_tokens.py create mode 100755 egs/fisher_swbd/ASR/local/display_manifest_statistics.py create mode 100755 egs/fisher_swbd/ASR/local/download_lm.py create mode 100755 egs/fisher_swbd/ASR/local/generate_unique_lexicon.py create mode 100644 egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py create mode 100755 egs/fisher_swbd/ASR/local/prepare_lang.py create mode 100755 egs/fisher_swbd/ASR/local/prepare_lang_bpe.py create mode 100755 egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py create mode 100755 egs/fisher_swbd/ASR/local/test_prepare_lang.py create mode 100755 egs/fisher_swbd/ASR/local/train_bpe_model.py create mode 100755 egs/fisher_swbd/ASR/prepare.sh create mode 120000 egs/fisher_swbd/ASR/shared create mode 100644 egs/fisher_swbd/ASR/streaming_conformer_ctc/README.md create mode 120000 egs/fisher_swbd/ASR/streaming_conformer_ctc/asr_datamodule.py create mode 100644 egs/fisher_swbd/ASR/streaming_conformer_ctc/conformer.py create mode 120000 egs/fisher_swbd/ASR/streaming_conformer_ctc/label_smoothing.py create mode 100755 egs/fisher_swbd/ASR/streaming_conformer_ctc/streaming_decode.py create mode 120000 egs/fisher_swbd/ASR/streaming_conformer_ctc/subsampling.py create mode 100755 egs/fisher_swbd/ASR/streaming_conformer_ctc/train.py create mode 100644 egs/fisher_swbd/ASR/streaming_conformer_ctc/transformer.py create mode 100644 egs/fisher_swbd/ASR/tdnn_lstm_ctc/README.md create mode 100644 egs/fisher_swbd/ASR/tdnn_lstm_ctc/__init__.py create mode 100644 egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py create mode 100755 egs/fisher_swbd/ASR/tdnn_lstm_ctc/decode.py create mode 100644 egs/fisher_swbd/ASR/tdnn_lstm_ctc/model.py create mode 100755 egs/fisher_swbd/ASR/tdnn_lstm_ctc/pretrained.py create mode 100755 egs/fisher_swbd/ASR/tdnn_lstm_ctc/train.py diff --git a/egs/fisher_swbd/ASR/conformer_ctc/README.md b/egs/fisher_swbd/ASR/conformer_ctc/README.md new file mode 100644 index 000000000..37ace4204 --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/README.md @@ -0,0 +1,75 @@ +## Introduction + +Please visit + +for how to run this recipe. + +## How to compute framewise alignment information + +### Step 1: Train a model + +Please use `conformer_ctc/train.py` to train a model. +See +for how to do it. + +### Step 2: Compute framewise alignment + +Run + +``` +# Choose a checkpoint and determine the number of checkpoints to average +epoch=30 +avg=15 +./conformer_ctc/ali.py \ + --epoch $epoch \ + --avg $avg \ + --max-duration 500 \ + --bucketing-sampler 0 \ + --full-libri 1 \ + --exp-dir conformer_ctc/exp \ + --lang-dir data/lang_bpe_500 \ + --ali-dir data/ali_500 +``` +and you will get four files inside the folder `data/ali_500`: + +``` +$ ls -lh data/ali_500 +total 546M +-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:06 test_clean.pt +-rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:07 test_other.pt +-rw-r--r-- 1 kuangfangjun root 542M Sep 28 11:36 train-960.pt +-rw-r--r-- 1 kuangfangjun root 2.1M Sep 28 11:38 valid.pt +``` + +**Note**: It can take more than 3 hours to compute the alignment +for the training dataset, which contains 960 * 3 = 2880 hours of data. + +**Caution**: The model parameters in `conformer_ctc/ali.py` have to match those +in `conformer_ctc/train.py`. + +**Caution**: You have to set the parameter `preserve_id` to `True` for `CutMix`. +Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`. + +### Step 3: Check your extracted alignments + +There is a file `test_ali.py` in `icefall/test` that can be used to test your +alignments. It uses pre-computed alignments to modify a randomly generated +`nnet_output` and it checks that we can decode the correct transcripts +from the resulting `nnet_output`. + +You should get something like the following if you run that script: + +``` +$ ./test/test_ali.py +['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO DWELL THAT THIS PASSAGE THAT LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"] +['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO GAMEWELL THAT THIS PASSAGE WAY LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"] +``` + +### Step 4: Use your alignments in training + +Please refer to `conformer_mmi/train.py` for usage. Some useful +functions are: + +- `load_alignments()`, it loads alignment saved by `conformer_ctc/ali.py` +- `convert_alignments_to_tensor()`, it converts alignments to PyTorch tensors +- `lookup_alignments()`, it returns the alignments of utterances by giving the cut ID of the utterances. diff --git a/egs/fisher_swbd/ASR/conformer_ctc/__init__.py b/egs/fisher_swbd/ASR/conformer_ctc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/fisher_swbd/ASR/conformer_ctc/ali.py b/egs/fisher_swbd/ASR/conformer_ctc/ali.py new file mode 100755 index 000000000..42fa2308e --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/ali.py @@ -0,0 +1,399 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + ./conformer_ctc/ali.py \ + --exp-dir ./conformer_ctc/exp \ + --lang-dir ./data/lang_bpe_500 \ + --epoch 20 \ + --avg 10 \ + --max-duration 300 \ + --dataset train-clean-100 \ + --out-dir data/ali +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import numpy as np +import torch +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse import CutSet +from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.decode import one_best_decoding +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + encode_supervisions, + get_alignments, + setup_logger, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=34, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--out-dir", + type=str, + required=True, + help="""Output directory. + It contains 3 generated files: + + - labels_xxx.h5 + - aux_labels_xxx.h5 + - cuts_xxx.json.gz + + where xxx is the value of `--dataset`. For instance, if + `--dataset` is `train-clean-100`, it will contain 3 files: + + - `labels_train-clean-100.h5` + - `aux_labels_train-clean-100.h5` + - `cuts_train-clean-100.json.gz` + + Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise + alignment. The difference is that labels_xxx.h5 contains repeats. + """, + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset to compute alignments for. + Possible values are: + - test-clean. + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "lm_dir": Path("data/lm"), + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "subsampling_factor": 4, + # Set it to 0 since attention decoder + # is not used for computing alignments + "num_decoder_layers": 0, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "output_beam": 10, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def compute_alignments( + model: torch.nn.Module, + dl: torch.utils.data.DataLoader, + labels_writer: FeaturesWriter, + aux_labels_writer: FeaturesWriter, + params: AttributeDict, + graph_compiler: BpeCtcTrainingGraphCompiler, +) -> CutSet: + """Compute the framewise alignments of a dataset. + + Args: + model: + The neural network model. + dl: + Dataloader containing the dataset. + params: + Parameters for computing alignments. + graph_compiler: + It converts token IDs to decoding graphs. + Returns: + Return a CutSet. Each cut has two custom fields: labels_alignment + and aux_labels_alignment, containing framewise alignments information. + Both are of type `lhotse.array.TemporalArray`. The difference between + the two alignments is that `labels_alignment` contain repeats. + """ + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + num_cuts = 0 + + device = graph_compiler.device + cuts = [] + for batch_idx, batch in enumerate(dl): + feature = batch["inputs"] + + # at entry, feature is [N, T, C] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cut_list = supervisions["cut"] + + for cut in cut_list: + assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}" + + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is [N, T, C] + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + # we need also to sort cut_ids as encode_supervisions() + # reorders "texts". + # In general, new2old is an identity map since lhotse sorts the returned + # cuts by duration in descending order + new2old = supervision_segments[:, 0].tolist() + + cut_list = [cut_list[i] for i in new2old] + + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + lattice = k2.intersect_dense( + decoding_graph, + dense_fsa_vec, + params.output_beam, + ) + + best_path = one_best_decoding( + lattice=lattice, + use_double_scores=params.use_double_scores, + ) + + labels_ali = get_alignments(best_path, kind="labels") + aux_labels_ali = get_alignments(best_path, kind="aux_labels") + assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) + for cut, labels, aux_labels in zip( + cut_list, labels_ali, aux_labels_ali + ): + cut.labels_alignment = labels_writer.store_array( + key=cut.id, + value=np.asarray(labels, dtype=np.int32), + # frame shift is 0.01s, subsampling_factor is 4 + frame_shift=0.04, + temporal_dim=0, + start=0, + ) + cut.aux_labels_alignment = aux_labels_writer.store_array( + key=cut.id, + value=np.asarray(aux_labels, dtype=np.int32), + # frame shift is 0.01s, subsampling_factor is 4 + frame_shift=0.04, + temporal_dim=0, + start=0, + ) + + cuts += cut_list + + num_cuts += len(cut_list) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + + return CutSet.from_cuts(cuts) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + args.enable_spec_aug = False + args.enable_musan = False + args.return_cuts = True + args.concatenate_cuts = False + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-ali") + + logging.info(f"Computing alignments for {params.dataset} - started") + logging.info(params) + + out_dir = Path(params.out_dir) + out_dir.mkdir(exist_ok=True) + + out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" + out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" + out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz" + + for f in ( + out_labels_ali_filename, + out_aux_labels_ali_filename, + out_manifest_filename, + ): + if f.exists(): + logging.info(f"{f} exists - skipping") + return + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + model.to(device) + + if params.avg == 1: + load_checkpoint( + f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False + ) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.eval() + + librispeech = LibriSpeechAsrDataModule(args) + if params.dataset == "test-clean": + test_clean_cuts = librispeech.test_clean_cuts() + dl = librispeech.test_dataloaders(test_clean_cuts) + elif params.dataset == "test-other": + test_other_cuts = librispeech.test_other_cuts() + dl = librispeech.test_dataloaders(test_other_cuts) + elif params.dataset == "train-clean-100": + train_clean_100_cuts = librispeech.train_clean_100_cuts() + dl = librispeech.train_dataloaders(train_clean_100_cuts) + elif params.dataset == "train-clean-360": + train_clean_360_cuts = librispeech.train_clean_360_cuts() + dl = librispeech.train_dataloaders(train_clean_360_cuts) + elif params.dataset == "train-other-500": + train_other_500_cuts = librispeech.train_other_500_cuts() + dl = librispeech.train_dataloaders(train_other_500_cuts) + elif params.dataset == "dev-clean": + dev_clean_cuts = librispeech.dev_clean_cuts() + dl = librispeech.valid_dataloaders(dev_clean_cuts) + else: + assert params.dataset == "dev-other", f"{params.dataset}" + dev_other_cuts = librispeech.dev_other_cuts() + dl = librispeech.valid_dataloaders(dev_other_cuts) + + logging.info(f"Processing {params.dataset}") + with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer: + with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer: + cut_set = compute_alignments( + model=model, + dl=dl, + labels_writer=labels_writer, + aux_labels_writer=aux_labels_writer, + params=params, + graph_compiler=graph_compiler, + ) + + cut_set.to_file(out_manifest_filename) + + logging.info( + f"For dataset {params.dataset}, its alignments with repeats are " + f"saved to {out_labels_ali_filename}, the alignments without repeats " + f"are saved to {out_aux_labels_ali_filename}, and the cut manifest " + f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}" + ) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py new file mode 120000 index 000000000..fa1b8cca3 --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py @@ -0,0 +1 @@ +../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/fisher_swbd/ASR/conformer_ctc/conformer.py b/egs/fisher_swbd/ASR/conformer_ctc/conformer.py new file mode 100644 index 000000000..871712a46 --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/conformer.py @@ -0,0 +1,930 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import Tensor, nn +from transformer import Supervisions, Transformer, encoder_padding_mask + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + num_classes (int): Number of output classes + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + num_decoder_layers (int): number of decoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + normalize_before (bool): whether to use layer_norm before the first block. + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + vgg_frontend: bool = False, + use_feat_batchnorm: Union[float, bool] = 0.1, + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + num_classes=num_classes, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dropout=dropout, + normalize_before=normalize_before, + vgg_frontend=vgg_frontend, + use_feat_batchnorm=use_feat_batchnorm, + ) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + use_conv_batchnorm = True + if isinstance(use_feat_batchnorm, float): + use_conv_batchnorm = False + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + cnn_module_kernel, + normalize_before, + use_conv_batchnorm, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = nn.LayerNorm(d_model) + else: + # Note: TorchScript detects that self.after_norm could be used inside forward() + # and throws an error without this change. + self.after_norm = identity + + def run_encoder( + self, x: Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + mask = encoder_padding_mask(x.size(0), supervisions) + if mask is not None: + mask = mask.to(x.device) + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + + if self.normalize_before: + x = self.after_norm(x) + + return x, mask + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + normalize_before: whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + use_conv_batchnorm: bool = False, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm + ) + + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + + self.ff_scale = 0.5 + + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block + + self.dropout = nn.Dropout(dropout) + + self.normalize_before = normalize_before + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) + + # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) + + # convolution module + residual = src + if self.normalize_before: + src = self.norm_conv(src) + src = residual + self.dropout(self.conv_module(src)) + if not self.normalize_before: + src = self.norm_conv(src) + + # feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) + + if self.normalize_before: + src = self.norm_final(src) + + return src + + +class ConformerEncoder(nn.TransformerEncoder): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__( + self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + ) -> None: + super(ConformerEncoder, self).__init__( + encoder_layer=encoder_layer, num_layers=num_layers, norm=norm + ) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for mod in self.layers: + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.weight, + self.in_proj.bias, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: Tensor) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self.pos_bias_u).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self.pos_bias_v).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + use_batchnorm: bool = False, + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + self.use_batchnorm = use_batchnorm + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + if self.use_batchnorm: + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: Tensor) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + if self.use_batchnorm: + x = self.norm(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def identity(x): + return x diff --git a/egs/fisher_swbd/ASR/conformer_ctc/decode.py b/egs/fisher_swbd/ASR/conformer_ctc/decode.py new file mode 100755 index 000000000..177e33a6e --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/decode.py @@ -0,0 +1,701 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=77, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=55, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. + best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + # TODO: pass `lattice` instead of `rescored_lattice` to + # `rescore_with_attention_decoder` + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + if params.method == "attention-decoder": + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/conformer_ctc/export.py b/egs/fisher_swbd/ASR/conformer_ctc/export.py new file mode 100755 index 000000000..28c28df01 --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/export.py @@ -0,0 +1,165 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from conformer import Conformer + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=34, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""It contains language related input files such as "lexicon.txt" + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + } + ) + return params + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/fisher_swbd/ASR/conformer_ctc/label_smoothing.py b/egs/fisher_swbd/ASR/conformer_ctc/label_smoothing.py new file mode 100644 index 000000000..cdc85ce9a --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1,98 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +class LabelSmoothingLoss(torch.nn.Module): + """ + Implement the LabelSmoothingLoss proposed in the following paper + https://arxiv.org/pdf/1512.00567.pdf + (Rethinking the Inception Architecture for Computer Vision) + + """ + + def __init__( + self, + ignore_index: int = -1, + label_smoothing: float = 0.1, + reduction: str = "sum", + ) -> None: + """ + Args: + ignore_index: + ignored class id + label_smoothing: + smoothing rate (0.0 means the conventional cross entropy loss) + reduction: + It has the same meaning as the reduction in + `torch.nn.CrossEntropyLoss`. It can be one of the following three + values: (1) "none": No reduction will be applied. (2) "mean": the + mean of the output is taken. (3) "sum": the output will be summed. + """ + super().__init__() + assert 0.0 <= label_smoothing < 1.0 + self.ignore_index = ignore_index + self.label_smoothing = label_smoothing + self.reduction = reduction + + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """ + Compute loss between x and target. + + Args: + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.ignore_index of + dimension (batch_size, input_length). + + Returns: + A scalar tensor containing the loss without normalization. + """ + assert x.ndim == 3 + assert target.ndim == 2 + assert x.shape[:2] == target.shape + num_classes = x.size(-1) + x = x.reshape(-1, num_classes) + # Now x is of shape (N*T, C) + + # We don't want to change target in-place below, + # so we make a copy of it here + target = target.clone().reshape(-1) + + ignored = target == self.ignore_index + target[ignored] = 0 + + true_dist = torch.nn.functional.one_hot( + target, num_classes=num_classes + ).to(x) + + true_dist = ( + true_dist * (1 - self.label_smoothing) + + self.label_smoothing / num_classes + ) + # Set the value of ignored indexes to 0 + true_dist[ignored] = 0 + + loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) + if self.reduction == "sum": + return loss.sum() + elif self.reduction == "mean": + return loss.sum() / (~ignored).sum() + else: + return loss.sum(dim=-1) diff --git a/egs/fisher_swbd/ASR/conformer_ctc/pretrained.py b/egs/fisher_swbd/ASR/conformer_ctc/pretrained.py new file mode 100755 index 000000000..28724e1eb --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/pretrained.py @@ -0,0 +1,435 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from conformer import Conformer +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_whole_lattice, +) +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + (3) attention-decoder - Extract n paths from the rescored + lattice and use the transformer attention decoder for + rescoring. + We call it HLG decoding + n-gram LM rescoring + attention + decoder rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or attention-decoder. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and attention-decoder. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--attention-decoder-scale", + type=float, + default=1.2, + help=""" + Used only when method is attention-decoder. + It specifies the scale for attention decoder scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is attention-decoder. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--sos-id", + type=int, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the SOS token. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--eos-id", + type=int, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the EOS token. + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "sample_rate": 16000, + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + if args.method != "attention-decoder": + # to save memory as the attention decoder + # will not be used + params.num_decoder_layers = 0 + + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=params.num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + # Note: We don't use key padding mask for attention during decoding + with torch.no_grad(): + nnet_output, memory, memory_key_padding_mask = model(features) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info("Use HLG + LM rescoring + attention decoder rescoring") + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + ) + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + nbest_scale=params.nbest_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/fisher_swbd/ASR/conformer_ctc/subsampling.py b/egs/fisher_swbd/ASR/conformer_ctc/subsampling.py new file mode 100644 index 000000000..542fb0364 --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/subsampling.py @@ -0,0 +1,161 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__(self, idim: int, odim: int) -> None: + """ + Args: + idim: + Input dim. The input shape is (N, T, idim). + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) + """ + assert idim >= 7 + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + ) + self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + return x + + +class VggSubsampling(nn.Module): + """Trying to follow the setup described in the following paper: + https://arxiv.org/pdf/1910.09799.pdf + + This paper is not 100% explicit so I am guessing to some extent, + and trying to compare with other VGG implementations. + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 + """ + + def __init__(self, idim: int, odim: int) -> None: + """Construct a VggSubsampling object. + + This uses 2 VGG blocks with 2 Conv2d layers each, + subsampling its input by a factor of 4 in the time dimensions. + + Args: + idim: + Input dim. The input shape is (N, T, idim). + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, odim) + """ + super().__init__() + + cur_channels = 1 + layers = [] + block_dims = [32, 64] + + # The decision to use padding=1 for the 1st convolution, then padding=0 + # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by + # a back-compatibility concern so that the number of frames at the + # output would be equal to: + # (((T-1)//2)-1)//2. + # We can consider changing this by using padding=1 on the + # 2nd convolution, so the num-frames at the output would be T//4. + for block_dim in block_dims: + layers.append( + torch.nn.Conv2d( + in_channels=cur_channels, + out_channels=block_dim, + kernel_size=3, + padding=1, + stride=1, + ) + ) + layers.append(torch.nn.ReLU()) + layers.append( + torch.nn.Conv2d( + in_channels=block_dim, + out_channels=block_dim, + kernel_size=3, + padding=0, + stride=1, + ) + ) + layers.append( + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) + ) + cur_channels = block_dim + + self.layers = nn.Sequential(*layers) + + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + x = x.unsqueeze(1) + x = self.layers(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + return x diff --git a/egs/fisher_swbd/ASR/conformer_ctc/test_label_smoothing.py b/egs/fisher_swbd/ASR/conformer_ctc/test_label_smoothing.py new file mode 100755 index 000000000..5d4438fd1 --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/test_label_smoothing.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distutils.version import LooseVersion + +import torch +from label_smoothing import LabelSmoothingLoss + +torch_ver = LooseVersion(torch.__version__) + + +def test_with_torch_label_smoothing_loss(): + if torch_ver < LooseVersion("1.10.0"): + print(f"Current torch version: {torch_ver}") + print("Please use torch >= 1.10 to run this test - skipping") + return + torch.manual_seed(20211105) + x = torch.rand(20, 30, 5000) + tgt = torch.randint(low=-1, high=x.size(-1), size=x.shape[:2]) + for reduction in ["none", "sum", "mean"]: + custom_loss_func = LabelSmoothingLoss( + ignore_index=-1, label_smoothing=0.1, reduction=reduction + ) + custom_loss = custom_loss_func(x, tgt) + + torch_loss_func = torch.nn.CrossEntropyLoss( + ignore_index=-1, reduction=reduction, label_smoothing=0.1 + ) + torch_loss = torch_loss_func(x.reshape(-1, x.size(-1)), tgt.reshape(-1)) + assert torch.allclose(custom_loss, torch_loss) + + +def main(): + test_with_torch_label_smoothing_loss() + + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/conformer_ctc/test_subsampling.py b/egs/fisher_swbd/ASR/conformer_ctc/test_subsampling.py new file mode 100755 index 000000000..81fa234dd --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/test_subsampling.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from subsampling import Conv2dSubsampling, VggSubsampling + + +def test_conv2d_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = Conv2dSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim + + +def test_vgg_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = VggSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim diff --git a/egs/fisher_swbd/ASR/conformer_ctc/test_transformer.py b/egs/fisher_swbd/ASR/conformer_ctc/test_transformer.py new file mode 100644 index 000000000..667057c51 --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/test_transformer.py @@ -0,0 +1,104 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +from torch.nn.utils.rnn import pad_sequence +from transformer import ( + Transformer, + add_eos, + add_sos, + decoder_padding_mask, + encoder_padding_mask, + generate_square_subsequent_mask, +) + + +def test_encoder_padding_mask(): + supervisions = { + "sequence_idx": torch.tensor([0, 1, 2]), + "start_frame": torch.tensor([0, 0, 0]), + "num_frames": torch.tensor([18, 7, 13]), + } + + max_len = ((18 - 1) // 2 - 1) // 2 + mask = encoder_padding_mask(max_len, supervisions) + expected_mask = torch.tensor( + [ + [False, False, False], # ((18 - 1)//2 - 1)//2 = 3, + [False, True, True], # ((7 - 1)//2 - 1)//2 = 1, + [False, False, True], # ((13 - 1)//2 - 1)//2 = 2, + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_transformer(): + num_features = 40 + num_classes = 87 + model = Transformer(num_features=num_features, num_classes=num_classes) + + N = 31 + + for T in range(7, 30): + x = torch.rand(N, T, num_features) + y, _, _ = model(x) + assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) + + +def test_generate_square_subsequent_mask(): + s = 5 + mask = generate_square_subsequent_mask(s) + inf = float("inf") + expected_mask = torch.tensor( + [ + [0.0, -inf, -inf, -inf, -inf], + [0.0, 0.0, -inf, -inf, -inf], + [0.0, 0.0, 0.0, -inf, -inf], + [0.0, 0.0, 0.0, 0.0, -inf], + [0.0, 0.0, 0.0, 0.0, 0.0], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_decoder_padding_mask(): + x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])] + y = pad_sequence(x, batch_first=True, padding_value=-1) + mask = decoder_padding_mask(y, ignore_id=-1) + expected_mask = torch.tensor( + [ + [False, False, True], + [False, True, True], + [False, False, False], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_add_sos(): + x = [[1, 2], [3], [2, 5, 8]] + y = add_sos(x, sos_id=0) + expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]] + assert y == expected_y + + +def test_add_eos(): + x = [[1, 2], [3], [2, 5, 8]] + y = add_eos(x, eos_id=0) + expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]] + assert y == expected_y diff --git a/egs/fisher_swbd/ASR/conformer_ctc/train.py b/egs/fisher_swbd/ASR/conformer_ctc/train.py new file mode 100755 index 000000000..c1fa814c0 --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/train.py @@ -0,0 +1,741 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=78, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Normalization for the input features, can be a + boolean indicating whether to do batch + normalization, or a float which means just scaling + the input features with this float value. + If given a float value, we will remove batchnorm + layer in `ConvolutionModule` as well. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "weight_decay": 1e-6, + "warm_step": 80000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + token_ids = graph_compiler.texts_to_ids(texts) + + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + if params.att_rate != 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + att_loss = torch.tensor([0]) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(42) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/conformer_ctc/transformer.py b/egs/fisher_swbd/ASR/conformer_ctc/transformer.py new file mode 100644 index 000000000..00ca027a7 --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/transformer.py @@ -0,0 +1,953 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from label_smoothing import LabelSmoothingLoss +from subsampling import Conv2dSubsampling, VggSubsampling +from torch.nn.utils.rnn import pad_sequence + +# Note: TorchScript requires Dict/List/etc. to be fully typed. +Supervisions = Dict[str, torch.Tensor] + + +class Transformer(nn.Module): + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + normalize_before: bool = True, + vgg_frontend: bool = False, + use_feat_batchnorm: Union[float, bool] = 0.1, + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + num_classes: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + num_decoder_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + normalize_before: + If True, use pre-layer norm; False to use post-layer norm. + vgg_frontend: + True to use vgg style frontend for subsampling. + use_feat_batchnorm: + True to use batchnorm for the input layer. + Float value to scale the input layer. + False to do nothing. + """ + super().__init__() + self.use_feat_batchnorm = use_feat_batchnorm + assert isinstance(use_feat_batchnorm, (float, bool)) + if isinstance(use_feat_batchnorm, bool) and use_feat_batchnorm: + self.feat_batchnorm = nn.BatchNorm1d(num_features) + + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_classes) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + encoder_norm = nn.LayerNorm(d_model) + else: + encoder_norm = None + + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + ) + + if num_decoder_layers > 0: + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol + + self.decoder_embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + decoder_norm = nn.LayerNorm(d_model) + else: + decoder_norm = None + + self.decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + norm=decoder_norm, + ) + + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) + + self.decoder_criterion = LabelSmoothingLoss() + else: + self.decoder_criterion = None + + def forward( + self, x: torch.Tensor, supervision: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, T, C). + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is (N, T, C) + - Encoder output with shape (T, N, C). It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is (N, T). + It is None if `supervision` is None. + """ + if ( + isinstance(self.use_feat_batchnorm, bool) + and self.use_feat_batchnorm + ): + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) + x = self.feat_batchnorm(x) + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) + if isinstance(self.use_feat_batchnorm, float): + x *= self.use_feat_batchnorm + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) + x = self.ctc_output(encoder_memory) + return x, encoder_memory, memory_key_padding_mask + + def run_encoder( + self, x: torch.Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run the transformer encoder. + + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute the encoder padding mask, which is used as memory key + padding mask for the decoder. + Returns: + Return a tuple with two tensors: + - The encoder output, with shape (T, N, C) + - encoder padding mask, with shape (N, T). + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. + """ + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(x.size(0), supervisions) + mask = mask.to(x.device) if mask is not None else None + x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) + + return x, mask + + def ctc_output(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + The output tensor from the transformer encoder. + Its shape is (T, N, C) + + Returns: + Return a tensor that can be used for CTC decoding. + Its shape is (N, T, C) + """ + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) + return x + + @torch.jit.export + def decoder_forward( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id + + Returns: + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. + """ + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) + + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + + decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) + + return decoder_loss + + @torch.jit.export + def decoder_nll( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[torch.Tensor], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + Returns: + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). + """ + # The common part between this function and decoder_forward could be + # extracted as a separate function. + if isinstance(token_ids[0], torch.Tensor): + # This branch is executed by torchscript in C++. + # See https://github.com/k2-fsa/k2/pull/870 + # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 + token_ids = [tolist(t) for t in token_ids] + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = src + if self.normalize_before: + src = self.norm1(src) + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout1(src2) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src2) + if not self.normalize_before: + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerDecoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerDecoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: + the sequence to the decoder layer (required). + memory: + the sequence from the last layer of the encoder (required). + tgt_mask: + the mask for the tgt sequence (optional). + memory_mask: + the mask for the memory sequence (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch (optional). + memory_key_padding_mask: + the mask for the memory keys per batch (optional). + + Shape: + tgt: (T, N, E). + memory: (S, N, E). + tgt_mask: (T, T). + memory_mask: (T, S). + tgt_key_padding_mask: (N, T). + memory_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt2 = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = residual + self.dropout1(tgt2) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + tgt2 = self.src_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = residual + self.dropout2(tgt2) + if not self.normalize_before: + tgt = self.norm2(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = residual + self.dropout3(tgt2) + if not self.normalize_before: + tgt = self.norm3(tgt) + return tgt + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape (N, T, C). + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape (1, T, d_model), where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is (N, T, C) + + Returns: + Return a tensor of shape (N, T, C) + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + +def encoder_padding_mask( + max_len: int, supervisions: Optional[Supervisions] = None +) -> Optional[torch.Tensor]: + """Make mask tensor containing indexes of padded part. + + TODO:: + This function **assumes** that the model uses + a subsampling factor of 4. We should remove that + assumption later. + + Args: + max_len: + Maximum length of input features. + CAUTION: It is the length after subsampling. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Tensor: Mask tensor of dimension (batch_size, input_length), + True denote the masked indices. + """ + if supervisions is None: + return None + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"], + supervisions["num_frames"], + ), + 1, + ).to(torch.int32) + + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] + for idx in range(supervision_segments.size(0)): + # Note: TorchScript doesn't allow to unpack tensors as tuples + sequence_idx = supervision_segments[idx, 0].item() + start_frame = supervision_segments[idx, 1].item() + num_frames = supervision_segments[idx, 2].item() + lengths[sequence_idx] = start_frame + num_frames + + lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] + bs = int(len(lengths)) + seq_range = torch.arange(0, max_len, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) + # Note: TorchScript doesn't implement Tensor.new() + seq_length_expand = torch.tensor( + lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype + ).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + return mask + + +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with True, + Unmasked positions are filled with False. + + Args: + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad + + Returns: + Tensor: + a bool tensor of the same shape as the input tensor. + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + + Args: + sz: mask size + + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + return [[sos_id] + utt for utt in token_ids] + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. + """ + return [utt + [eos_id] for utt in token_ids] + + +def tolist(t: torch.Tensor) -> List[int]: + """Used by jit""" + return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/fisher_swbd/ASR/local/__init__.py b/egs/fisher_swbd/ASR/local/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/fisher_swbd/ASR/local/compile_hlg.py b/egs/fisher_swbd/ASR/local/compile_hlg.py new file mode 100755 index 000000000..9a35750e0 --- /dev/null +++ b/egs/fisher_swbd/ASR/local/compile_hlg.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input lang_dir and generates HLG from + + - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt + - L, the lexicon, built from lang_dir/L_disambig.pt + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from data/lm/G_3_gram.fst.txt + +The generated HLG is saved in $lang_dir/HLG.pt +""" +import argparse +import logging +from pathlib import Path + +import k2 +import torch + +from icefall.lexicon import Lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + return parser.parse_args() + + +def compile_HLG(lang_dir: str) -> k2.Fsa: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + + Return: + An FSA representing HLG. + """ + lexicon = Lexicon(lang_dir) + max_token_id = max(lexicon.tokens) + logging.info(f"Building ctc_topo. max_token_id: {max_token_id}") + H = k2.ctc_topo(max_token_id) + L = k2.Fsa.from_dict(torch.load(f"{lang_dir}/L_disambig.pt")) + + if Path("data/lm/G_3_gram.pt").is_file(): + logging.info("Loading pre-compiled G_3_gram") + d = torch.load("data/lm/G_3_gram.pt") + G = k2.Fsa.from_dict(d) + else: + logging.info("Loading G_3_gram.fst.txt") + with open("data/lm/G_3_gram.fst.txt") as f: + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + torch.save(G.as_dict(), "data/lm/G_3_gram.pt") + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + + L = k2.arc_sort(L) + G = k2.arc_sort(G) + + logging.info("Intersecting L and G") + LG = k2.compose(L, G) + logging.info(f"LG shape: {LG.shape}") + + logging.info("Connecting LG") + LG = k2.connect(LG) + logging.info(f"LG shape after k2.connect: {LG.shape}") + + logging.info(type(LG.aux_labels)) + logging.info("Determinizing LG") + + LG = k2.determinize(LG) + logging.info(type(LG.aux_labels)) + + logging.info("Connecting LG after k2.determinize") + LG = k2.connect(LG) + + logging.info("Removing disambiguation symbols on LG") + + LG.labels[LG.labels >= first_token_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set LG.properties to None + LG.__dict__["_properties"] = None + + assert isinstance(LG.aux_labels, k2.RaggedTensor) + LG.aux_labels.values[LG.aux_labels.values >= first_word_disambig_id] = 0 + + LG = k2.remove_epsilon(LG) + logging.info(f"LG shape after k2.remove_epsilon: {LG.shape}") + + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + logging.info("Composing H and LG") + # CAUTION: The name of the inner_labels is fixed + # to `tokens`. If you want to change it, please + # also change other places in icefall that are using + # it. + HLG = k2.compose(H, LG, inner_labels="tokens") + + logging.info("Connecting LG") + HLG = k2.connect(HLG) + + logging.info("Arc sorting LG") + HLG = k2.arc_sort(HLG) + logging.info(f"HLG.shape: {HLG.shape}") + + return HLG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + if (lang_dir / "HLG.pt").is_file(): + logging.info(f"{lang_dir}/HLG.pt already exists - skipping") + return + + logging.info(f"Processing {lang_dir}") + + HLG = compile_HLG(lang_dir) + logging.info(f"Saving HLG.pt to {lang_dir}") + torch.save(HLG.as_dict(), f"{lang_dir}/HLG.pt") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/fisher_swbd/ASR/local/compute_fbank_librispeech.py b/egs/fisher_swbd/ASR/local/compute_fbank_librispeech.py new file mode 100755 index 000000000..b26034eb2 --- /dev/null +++ b/egs/fisher_swbd/ASR/local/compute_fbank_librispeech.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the LibriSpeech dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_librispeech(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + + dataset_parts = ( + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ) + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, output_dir=src_dir + ) + assert manifests is not None + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + if (output_dir / f"cuts_{partition}.json.gz").is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if "train" in partition: + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomHdf5Writer, + ) + cut_set.to_json(output_dir / f"cuts_{partition}.json.gz") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_librispeech() diff --git a/egs/fisher_swbd/ASR/local/compute_fbank_musan.py b/egs/fisher_swbd/ASR/local/compute_fbank_musan.py new file mode 100755 index 000000000..d44524e70 --- /dev/null +++ b/egs/fisher_swbd/ASR/local/compute_fbank_musan.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the musan dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_musan(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + + dataset_parts = ( + "music", + "speech", + "noise", + ) + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, output_dir=src_dir + ) + assert manifests is not None + + musan_cuts_path = output_dir / "cuts_musan.json.gz" + + if musan_cuts_path.is_file(): + logging.info(f"{musan_cuts_path} already exists - skipping") + return + + logging.info("Extracting features for Musan") + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + # create chunks of Musan with duration 5 - 10 seconds + musan_cuts = ( + CutSet.from_manifests( + recordings=combine( + part["recordings"] for part in manifests.values() + ) + ) + .cut_into_windows(10.0) + .filter(lambda c: c.duration > 5) + .compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/feats_musan", + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomHdf5Writer, + ) + ) + musan_cuts.to_json(musan_cuts_path) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_fbank_musan() diff --git a/egs/fisher_swbd/ASR/local/convert_transcript_words_to_tokens.py b/egs/fisher_swbd/ASR/local/convert_transcript_words_to_tokens.py new file mode 100755 index 000000000..133499c8b --- /dev/null +++ b/egs/fisher_swbd/ASR/local/convert_transcript_words_to_tokens.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +""" +Convert a transcript file containing words to a corpus file containing tokens +for LM training with the help of a lexicon. + +If the lexicon contains phones, the resulting LM will be a phone LM; If the +lexicon contains word pieces, the resulting LM will be a word piece LM. + +If a word has multiple pronunciations, the one that appears first in the lexicon +is kept; others are removed. + +If the input transcript is: + + hello zoo world hello + world zoo + foo zoo world hellO + +and if the lexicon is + + SPN + hello h e l l o 2 + hello h e l l o + world w o r l d + zoo z o o + +Then the output is + + h e l l o 2 z o o w o r l d h e l l o 2 + w o r l d z o o + SPN z o o w o r l d SPN +""" + +import argparse +from pathlib import Path +from typing import Dict, List + +from generate_unique_lexicon import filter_multiple_pronunications + +from icefall.lexicon import read_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transcript", + type=str, + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", + ) + parser.add_argument("--lexicon", type=str, help="The input lexicon file.") + parser.add_argument( + "--oov", type=str, default="", help="The OOV word." + ) + + return parser.parse_args() + + +def process_line( + lexicon: Dict[str, List[str]], line: str, oov_token: str +) -> None: + """ + Args: + lexicon: + A dict containing pronunciations. Its keys are words and values + are pronunciations (i.e., tokens). + line: + A line of transcript consisting of space(s) separated words. + oov_token: + The pronunciation of the oov word if a word in `line` is not present + in the lexicon. + Returns: + Return None. + """ + s = "" + words = line.strip().split() + for i, w in enumerate(words): + tokens = lexicon.get(w, oov_token) + s += " ".join(tokens) + s += " " + print(s.strip()) + + +def main(): + args = get_args() + assert Path(args.lexicon).is_file() + assert Path(args.transcript).is_file() + assert len(args.oov) > 0 + + # Only the first pronunciation of a word is kept + lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) + + lexicon = dict(lexicon) + + assert args.oov in lexicon + + oov_token = lexicon[args.oov] + + with open(args.transcript) as f: + for line in f: + process_line(lexicon=lexicon, line=line, oov_token=oov_token) + + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/local/display_manifest_statistics.py b/egs/fisher_swbd/ASR/local/display_manifest_statistics.py new file mode 100755 index 000000000..15bd206fa --- /dev/null +++ b/egs/fisher_swbd/ASR/local/display_manifest_statistics.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in transducer/train.py +for usage. +""" + + +from lhotse import load_manifest + + +def main(): + path = "./data/fbank/cuts_train-clean-100.json.gz" + path = "./data/fbank/cuts_train-clean-360.json.gz" + path = "./data/fbank/cuts_train-other-500.json.gz" + path = "./data/fbank/cuts_dev-clean.json.gz" + path = "./data/fbank/cuts_dev-other.json.gz" + path = "./data/fbank/cuts_test-clean.json.gz" + path = "./data/fbank/cuts_test-other.json.gz" + + cuts = load_manifest(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +## train-clean-100 +Cuts count: 85617 +Total duration (hours): 303.8 +Speech duration (hours): 303.8 (100.0%) +*** +Duration statistics (seconds): +mean 12.8 +std 3.8 +min 1.3 +0.1% 1.9 +0.5% 2.2 +1% 2.5 +5% 4.2 +10% 6.4 +25% 11.4 +50% 13.8 +75% 15.3 +90% 16.7 +95% 17.3 +99% 18.1 +99.5% 18.4 +99.9% 18.8 +max 27.2 + +## train-clean-360 +Cuts count: 312042 +Total duration (hours): 1098.2 +Speech duration (hours): 1098.2 (100.0%) +*** +Duration statistics (seconds): +mean 12.7 +std 3.8 +min 1.0 +0.1% 1.8 +0.5% 2.2 +1% 2.5 +5% 4.2 +10% 6.2 +25% 11.2 +50% 13.7 +75% 15.3 +90% 16.6 +95% 17.3 +99% 18.1 +99.5% 18.4 +99.9% 18.8 +max 33.0 + +## train-other 500 +Cuts count: 446064 +Total duration (hours): 1500.6 +Speech duration (hours): 1500.6 (100.0%) +*** +Duration statistics (seconds): +mean 12.1 +std 4.2 +min 0.8 +0.1% 1.7 +0.5% 2.1 +1% 2.3 +5% 3.5 +10% 5.0 +25% 9.8 +50% 13.4 +75% 15.1 +90% 16.5 +95% 17.2 +99% 18.1 +99.5% 18.4 +99.9% 18.9 +max 31.0 + +## dev-clean +Cuts count: 2703 +Total duration (hours): 5.4 +Speech duration (hours): 5.4 (100.0%) +*** +Duration statistics (seconds): +mean 7.2 +std 4.7 +min 1.4 +0.1% 1.6 +0.5% 1.8 +1% 1.9 +5% 2.4 +10% 2.7 +25% 3.8 +50% 5.9 +75% 9.3 +90% 13.3 +95% 16.4 +99% 23.8 +99.5% 28.5 +99.9% 32.3 +max 32.6 + +## dev-other +Cuts count: 2864 +Total duration (hours): 5.1 +Speech duration (hours): 5.1 (100.0%) +*** +Duration statistics (seconds): +mean 6.4 +std 4.3 +min 1.1 +0.1% 1.3 +0.5% 1.7 +1% 1.8 +5% 2.2 +10% 2.6 +25% 3.5 +50% 5.3 +75% 7.9 +90% 12.0 +95% 15.0 +99% 22.2 +99.5% 27.1 +99.9% 32.4 +max 35.2 + +## test-clean +Cuts count: 2620 +Total duration (hours): 5.4 +Speech duration (hours): 5.4 (100.0%) +*** +Duration statistics (seconds): +mean 7.4 +std 5.2 +min 1.3 +0.1% 1.6 +0.5% 1.8 +1% 2.0 +5% 2.3 +10% 2.7 +25% 3.7 +50% 5.8 +75% 9.6 +90% 14.6 +95% 17.8 +99% 25.5 +99.5% 28.4 +99.9% 32.8 +max 35.0 + +## test-other +Cuts count: 2939 +Total duration (hours): 5.3 +Speech duration (hours): 5.3 (100.0%) +*** +Duration statistics (seconds): +mean 6.5 +std 4.4 +min 1.2 +0.1% 1.5 +0.5% 1.8 +1% 1.9 +5% 2.3 +10% 2.6 +25% 3.4 +50% 5.2 +75% 8.2 +90% 12.6 +95% 15.8 +99% 21.4 +99.5% 23.8 +99.9% 33.5 +max 34.5 +""" diff --git a/egs/fisher_swbd/ASR/local/download_lm.py b/egs/fisher_swbd/ASR/local/download_lm.py new file mode 100755 index 000000000..94d23afed --- /dev/null +++ b/egs/fisher_swbd/ASR/local/download_lm.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file downloads the following LibriSpeech LM files: + + - 3-gram.pruned.1e-7.arpa.gz + - 4-gram.arpa.gz + - librispeech-vocab.txt + - librispeech-lexicon.txt + +from http://www.openslr.org/resources/11 +and save them in the user provided directory. + +Files are not re-downloaded if they already exist. + +Usage: + ./local/download_lm.py --out-dir ./download/lm +""" + +import argparse +import gzip +import logging +import os +import shutil +from pathlib import Path + +from lhotse.utils import urlretrieve_progress +from tqdm.auto import tqdm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--out-dir", type=str, help="Output directory.") + + args = parser.parse_args() + return args + + +def main(out_dir: str): + url = "http://www.openslr.org/resources/11" + out_dir = Path(out_dir) + + files_to_download = ( + "3-gram.pruned.1e-7.arpa.gz", + "4-gram.arpa.gz", + "librispeech-vocab.txt", + "librispeech-lexicon.txt", + ) + + for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"): + filename = out_dir / f + if filename.is_file() is False: + urlretrieve_progress( + f"{url}/{f}", + filename=filename, + desc=f"Downloading {filename}", + ) + else: + logging.info(f"{filename} already exists - skipping") + + if ".gz" in str(filename): + unzipped = Path(os.path.splitext(filename)[0]) + if unzipped.is_file() is False: + with gzip.open(filename, "rb") as f_in: + with open(unzipped, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + else: + logging.info(f"{unzipped} already exist - skipping") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + logging.info(f"out_dir: {args.out_dir}") + + main(out_dir=args.out_dir) diff --git a/egs/fisher_swbd/ASR/local/generate_unique_lexicon.py b/egs/fisher_swbd/ASR/local/generate_unique_lexicon.py new file mode 100755 index 000000000..566c0743d --- /dev/null +++ b/egs/fisher_swbd/ASR/local/generate_unique_lexicon.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input a lexicon.txt and output a new lexicon, +in which each word has a unique pronunciation. + +The way to do this is to keep only the first pronunciation of a word +in lexicon.txt. +""" + + +import argparse +import logging +from pathlib import Path +from typing import List, Tuple + +from icefall.lexicon import read_lexicon, write_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + This file will generate a new file uniq_lexicon.txt + in it. + """, + ) + + return parser.parse_args() + + +def filter_multiple_pronunications( + lexicon: List[Tuple[str, List[str]]] +) -> List[Tuple[str, List[str]]]: + """Remove multiple pronunciations of words from a lexicon. + + If a word has more than one pronunciation in the lexicon, only + the first one is kept, while other pronunciations are removed + from the lexicon. + + Args: + lexicon: + The input lexicon, containing a list of (word, [p1, p2, ..., pn]), + where "p1, p2, ..., pn" are the pronunciations of the "word". + Returns: + Return a new lexicon where each word has a unique pronunciation. + """ + seen = set() + ans = [] + + for word, tokens in lexicon: + if word in seen: + continue + seen.add(word) + ans.append((word, tokens)) + return ans + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + lexicon_filename = lang_dir / "lexicon.txt" + + in_lexicon = read_lexicon(lexicon_filename) + + out_lexicon = filter_multiple_pronunications(in_lexicon) + + write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon) + + logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}") + logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py new file mode 100644 index 000000000..be9715578 --- /dev/null +++ b/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python3 + +import argparse +import re +from typing import Tuple + +from tqdm import tqdm + +from lhotse import SupervisionSet, SupervisionSegment +from lhotse.serialization import load_manifest_lazy_or_eager + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("input_sups") + parser.add_argument("output_sups") + return parser.parse_args() + + +# Note: the functions "normalize" and "keep" implement the logic similar to +# Kaldi's data prep scripts for Fisher: +# https://github.com/kaldi-asr/kaldi/blob/master/egs/fisher_swbd/s5/local/fisher_data_prep.sh +# and for SWBD: +# https://github.com/kaldi-asr/kaldi/blob/master/egs/fisher_swbd/s5/local/swbd1_data_prep.sh + + +class Normalizer: + def __init__(self) -> None: + # tuples of (pattern, replacement) + # note: Kaldi replaces sighs, coughs, etc with [noise]. + # We don't do that here. + # We also uppercase the text as the first operation. + self.replace_regexps: Tuple[re.Pattern, str] = [ + # SWBD: + # [LAUGHTER-STORY] -> STORY + (re.compile(r"\[LAUGHTER-(.*?)\]"), r"\1"), + # [WEA[SONABLE]-/REASONABLE] + (re.compile(r"\[\S+/(\S+)\]"), r"\1"), + # -[ADV]AN[TAGE]- -> AN + (re.compile(r"-?\[.*?\](\w+)\[.*?\]-?"), r"\1-"), + # ABSOLUTE[LY]- -> ABSOLUTE- + (re.compile(r"(\w+)\[.*?\]-?"), r"\1-"), + # [AN]Y- -> Y- + # -[AN]Y- -> Y- + (re.compile(r"-?\[.*?\](\w+)-?"), r"\1-"), + # special tokens + (re.compile(r"\[LAUGH.*?\]"), r"[LAUGHTER]"), + (re.compile(r"\[SIGH.*?\]"), r"[SIGH]"), + (re.compile(r"\[COUGH.*?\]"), r"[COUGH]"), + (re.compile(r"\[MN.*?\]"), r"[VOCALIZED-NOISE]"), + (re.compile(r"\[BREATH.*?\]"), r"[BREATH]"), + (re.compile(r"\[LIPSMACK.*?\]"), r"[LIPSMACK]"), + (re.compile(r"\[SNEEZE.*?\]"), r"[SNEEZE]"), + # abbreviations + (re.compile(r"(\w)\.(\w)\.(\w)",), r"\1 \2 \3"), + (re.compile(r"(\w)\.(\w)",), r"\1 \2"), + (re.compile(r"\._",), r" "), + (re.compile(r"_(\w)",), r"\1"), + (re.compile(r"(\w)\.s",), r"\1's"), + # words between apostrophes + (re.compile(r"'(\S*?)'"), r"\1"), + # dangling dashes (2 passes) + (re.compile(r"\s-\s"), r" "), + (re.compile(r"\s-\s"), r" "), + # special symbol with trailing dash + (re.compile(r"(\[\w+\])-"), r"\1"), + ] + + # unwanted symbols in the transcripts + self.remove_regexp = re.compile( + r"|".join([ + # special symbols + r"\[\[SKIP.*\]\]", + r"\[SKIP.*\]", + r"\[PAUSE.*\]", + r"\[SILENCE\]", + r"", + r"", + # remaining punctuation + r"\.", + r",", + r"\?", + r"{", + r"}", + r"~", + r"_\d", + ]) + ) + + self.whitespace_regexp = re.compile(r"\s+") + + def normalize(self, text: str) -> str: + text = text.upper() + + # first replace + for pattern, sub in self.replace_regexps: + text = pattern.sub(sub, text) + + # then remove + text = self.remove_regexp.sub("", text) + + # then clean up whitespace + text = self.whitespace_regexp.sub(" ", text).strip() + + return text + + +def keep(sup: SupervisionSegment) -> bool: + if "((" in sup.text: + return False + + if " yes", + "[laugh] oh this is [laught] this is great [silence] yes", + "i don't kn- - know a.b.c's", + "'absolutely yes", + "absolutely' yes", + "'absolutely' yes", + "'absolutely' yes 'aight", + "ABSOLUTE[LY]", + "ABSOLUTE[LY]-", + "[AN]Y", + "[AN]Y-", + "[ADV]AN[TAGE]", + "[ADV]AN[TAGE]-", + "-[ADV]AN[TAGE]", + "-[ADV]AN[TAGE]-", + "[WEA[SONABLE]-/REASONABLE]", + ]: + print(text) + print(normalizer.normalize(text)) + print() + +if __name__ == "__main__": + # test() + main() diff --git a/egs/fisher_swbd/ASR/local/prepare_lang.py b/egs/fisher_swbd/ASR/local/prepare_lang.py new file mode 100755 index 000000000..d913756a1 --- /dev/null +++ b/egs/fisher_swbd/ASR/local/prepare_lang.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input a lexicon file "data/lang_phone/lexicon.txt" +consisting of words and tokens (i.e., phones) and does the following: + +1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt + +2. Generate tokens.txt, the token table mapping a token to a unique integer. + +3. Generate words.txt, the word table mapping a word to a unique integer. + +4. Generate L.pt, in k2 format. It can be loaded by + + d = torch.load("L.pt") + lexicon = k2.Fsa.from_dict(d) + +5. Generate L_disambig.pt, in k2 format. +""" +import argparse +import math +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import k2 +import torch + +from icefall.lexicon import read_lexicon, write_lexicon +from icefall.utils import str2bool + +Lexicon = List[Tuple[str, List[str]]] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file lexicon.txt. + Generated files by this script are saved into this directory. + """, + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + """, + ) + + return parser.parse_args() + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_tokens(lexicon: Lexicon) -> List[str]: + """Get tokens from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique tokens. + """ + ans = set() + for _, tokens in lexicon: + ans.update(tokens) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def get_words(lexicon: Lexicon) -> List[str]: + """Get words from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique words. + """ + ans = set() + for word, _ in lexicon: + ans.add(word) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: + """It adds pseudo-token disambiguation symbols #1, #2 and so on + at the ends of tokens to ensure that all pronunciations are different, + and that none is a prefix of another. + + See also add_lex_disambig.pl from kaldi. + + Args: + lexicon: + It is returned by :func:`read_lexicon`. + Returns: + Return a tuple with two elements: + + - The output lexicon with disambiguation symbols + - The ID of the max disambiguation symbol that appears + in the lexicon + """ + + # (1) Work out the count of each token-sequence in the + # lexicon. + count = defaultdict(int) + for _, tokens in lexicon: + count[" ".join(tokens)] += 1 + + # (2) For each left sub-sequence of each token-sequence, note down + # that it exists (for identifying prefixes of longer strings). + issubseq = defaultdict(int) + for _, tokens in lexicon: + tokens = tokens.copy() + tokens.pop() + while tokens: + issubseq[" ".join(tokens)] = 1 + tokens.pop() + + # (3) For each entry in the lexicon: + # if the token sequence is unique and is not a + # prefix of another word, no disambig symbol. + # Else output #1, or #2, #3, ... if the same token-seq + # has already been assigned a disambig symbol. + ans = [] + + # We start with #1 since #0 has its own purpose + first_allowed_disambig = 1 + max_disambig = first_allowed_disambig - 1 + last_used_disambig_symbol_of = defaultdict(int) + + for word, tokens in lexicon: + tokenseq = " ".join(tokens) + assert tokenseq != "" + if issubseq[tokenseq] == 0 and count[tokenseq] == 1: + ans.append((word, tokens)) + continue + + cur_disambig = last_used_disambig_symbol_of[tokenseq] + if cur_disambig == 0: + cur_disambig = first_allowed_disambig + else: + cur_disambig += 1 + + if cur_disambig > max_disambig: + max_disambig = cur_disambig + last_used_disambig_symbol_of[tokenseq] = cur_disambig + tokenseq += f" #{cur_disambig}" + ans.append((word, tokenseq.split())) + return ans, max_disambig + + +def generate_id_map(symbols: List[str]) -> Dict[str, int]: + """Generate ID maps, i.e., map a symbol to a unique ID. + + Args: + symbols: + A list of unique symbols. + Returns: + A dict containing the mapping between symbols and IDs. + """ + return {sym: i for i, sym in enumerate(symbols)} + + +def add_self_loops( + arcs: List[List[Any]], disambig_token: int, disambig_word: int +) -> List[List[Any]]: + """Adds self-loops to states of an FST to propagate disambiguation symbols + through it. They are added on each state with non-epsilon output symbols + on at least one arc out of the state. + + See also fstaddselfloops.pl from Kaldi. One difference is that + Kaldi uses OpenFst style FSTs and it has multiple final states. + This function uses k2 style FSTs and it does not need to add self-loops + to the final state. + + The input label of a self-loop is `disambig_token`, while the output + label is `disambig_word`. + + Args: + arcs: + A list-of-list. The sublist contains + `[src_state, dest_state, label, aux_label, score]` + disambig_token: + It is the token ID of the symbol `#0`. + disambig_word: + It is the word ID of the symbol `#0`. + + Return: + Return new `arcs` containing self-loops. + """ + states_needs_self_loops = set() + for arc in arcs: + src, dst, ilabel, olabel, score = arc + if olabel != 0: + states_needs_self_loops.add(src) + + ans = [] + for s in states_needs_self_loops: + ans.append([s, s, disambig_token, disambig_word, 0]) + + return arcs + ans + + +def lexicon_to_fst( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + sil_token: str = "SIL", + sil_prob: float = 0.5, + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format) with optional silence at + the beginning and end of each word. + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + sil_token: + The silence token. + sil_prob: + The probability for adding a silence at the beginning and end + of the word. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + assert sil_prob > 0.0 and sil_prob < 1.0 + # CAUTION: we use score, i.e, negative cost. + sil_score = math.log(sil_prob) + no_sil_score = math.log(1.0 - sil_prob) + + start_state = 0 + loop_state = 1 # words enter and leave from here + sil_state = 2 # words terminate here when followed by silence; this state + # has a silence transition to loop_state. + next_state = 3 # the next un-allocated state, will be incremented as we go. + arcs = [] + + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + sil_token = token2id[sil_token] + + arcs.append([start_state, loop_state, eps, eps, no_sil_score]) + arcs.append([start_state, sil_state, eps, eps, sil_score]) + arcs.append([sil_state, loop_state, sil_token, eps, 0]) + + for word, tokens in lexicon: + assert len(tokens) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + tokens = [token2id[i] for i in tokens] + + for i in range(len(tokens) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, tokens[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last token of this word + # It has two out-going arcs, one to the loop state, + # the other one to the sil_state. + i = len(tokens) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) + arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + lexicon_filename = lang_dir / "lexicon.txt" + sil_token = "SIL" + sil_prob = 0.5 + + lexicon = read_lexicon(lexicon_filename) + tokens = get_tokens(lexicon) + words = get_words(lexicon) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in tokens + tokens.append(f"#{i}") + + assert "" not in tokens + tokens = [""] + tokens + + assert "" not in words + assert "#0" not in words + assert "" not in words + assert "" not in words + + words = [""] + words + ["#0", "", ""] + + token2id = generate_id_map(tokens) + word2id = generate_id_map(words) + + write_mapping(lang_dir / "tokens.txt", token2id) + write_mapping(lang_dir / "words.txt", word2id) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst( + lexicon, + token2id=token2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + ) + + L_disambig = lexicon_to_fst( + lexicon_disambig, + token2id=token2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py b/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py new file mode 100755 index 000000000..cf32f308d --- /dev/null +++ b/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py @@ -0,0 +1,254 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +""" + +This script takes as input `lang_dir`, which should contain:: + + - lang_dir/bpe.model, + - lang_dir/words.txt + +and generates the following files in the directory `lang_dir`: + + - lexicon.txt + - lexicon_disambig.txt + - L.pt + - L_disambig.pt + - tokens.txt +""" + +import argparse +from pathlib import Path +from typing import Dict, List, Tuple + +import k2 +import sentencepiece as spm +import torch +from prepare_lang import ( + Lexicon, + add_disambig_symbols, + add_self_loops, + write_lexicon, + write_mapping, +) + +from icefall.utils import str2bool + + +def lexicon_to_fst_no_sil( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format). + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + loop_state = 0 # words enter and leave from here + next_state = 1 # the next un-allocated state, will be incremented as we go + + arcs = [] + + # The blank symbol is defined in local/train_bpe_model.py + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + for word, pieces in lexicon: + assert len(pieces) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + pieces = [token2id[i] for i in pieces] + + for i in range(len(pieces) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, pieces[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last piece of this word + i = len(pieces) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, pieces[i], w, 0]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def generate_lexicon( + model_file: str, words: List[str] +) -> Tuple[Lexicon, Dict[str, int]]: + """Generate a lexicon from a BPE model. + + Args: + model_file: + Path to a sentencepiece model. + words: + A list of strings representing words. + Returns: + Return a tuple with two elements: + - A dict whose keys are words and values are the corresponding + word pieces. + - A dict representing the token symbol, mapping from tokens to IDs. + """ + sp = spm.SentencePieceProcessor() + sp.load(str(model_file)) + + words_pieces: List[List[str]] = sp.encode(words, out_type=str) + + lexicon = [] + for word, pieces in zip(words, words_pieces): + lexicon.append((word, pieces)) + + # The OOV word is + lexicon.append(("", [sp.id_to_piece(sp.unk_id())])) + + token2id: Dict[str, int] = dict() + for i in range(sp.vocab_size()): + token2id[sp.id_to_piece(i)] = i + + return lexicon, token2id + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the bpe.model and words.txt + """, + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + + See "test/test_bpe_lexicon.py" for usage. + """, + ) + + return parser.parse_args() + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + model_file = lang_dir / "bpe.model" + + word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") + + words = word_sym_table.symbols + + excluded = ["", "!SIL", "", "", "#0", "", ""] + for w in excluded: + if w in words: + words.remove(w) + + lexicon, token_sym_table = generate_lexicon(model_file, words) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + next_token_id = max(token_sym_table.values()) + 1 + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in token_sym_table + token_sym_table[disambig] = next_token_id + next_token_id += 1 + + word_sym_table.add("#0") + word_sym_table.add("") + word_sym_table.add("") + + write_mapping(lang_dir / "tokens.txt", token_sym_table) + + write_lexicon(lang_dir / "lexicon.txt", lexicon) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst_no_sil( + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, + ) + + L_disambig = lexicon_to_fst_no_sil( + lexicon_disambig, + token2id=token_sym_table, + word2id=word_sym_table, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py b/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py new file mode 100755 index 000000000..c771dac4b --- /dev/null +++ b/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py @@ -0,0 +1,482 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input a lexicon file "data/lang_phone/lexicon.txt" +consisting of words and tokens (i.e., phones) and does the following: + +1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt + +2. Generate tokens.txt, the token table mapping a token to a unique integer. + +3. Generate words.txt, the word table mapping a word to a unique integer. + +4. Generate L.pt, in k2 format. It can be loaded by + + d = torch.load("L.pt") + lexicon = k2.Fsa.from_dict(d) + +5. Generate L_disambig.pt, in k2 format. +""" +import argparse +import math +from collections import defaultdict +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import k2 +import torch +from g2p_en import G2p +from tqdm import tqdm + +from icefall.lexicon import read_lexicon, write_lexicon +from icefall.utils import str2bool + +Lexicon = List[Tuple[str, List[str]]] + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain a file words.txt. + Generated files by this script are saved into this directory. + """, + ) + + parser.add_argument( + "--debug", + type=str2bool, + default=False, + help="""True for debugging, which will generate + a visualization of the lexicon FST. + + Caution: If your lexicon contains hundreds of thousands + of lines, please set it to False! + """, + ) + + return parser.parse_args() + + +def get_g2p_sym2int(): + + # These symbols are removed from from g2p_en's vocabulary + excluded_symbols = [ + "", + "", + "", + "", + ] + + symbols = [p for p in sorted(G2p().phonemes) if p not in excluded_symbols] + # reserve 0 and 1 for blank and sos/eos/pad tokens + # symbols start at index 2 + sym2int = { + "": 0, + "SIL": 1, + "UNK": 2, + "LAUGHTER": 3, + "SIGH": 4, + "COUGH": 5, + "VOCALIZED-NOISE": 6, + "BREATH": 7, + "LIPSMACK": 8, + "SNEEZE": 9, + "NOISE": 10, + **{sym: idx for idx, sym in enumerate(symbols, start=11)}, + } + return sym2int + + + +def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: + """Write a symbol to ID mapping to a file. + + Note: + No need to implement `read_mapping` as it can be done + through :func:`k2.SymbolTable.from_file`. + + Args: + filename: + Filename to save the mapping. + sym2id: + A dict mapping symbols to IDs. + Returns: + Return None. + """ + with open(filename, "w", encoding="utf-8") as f: + for sym, i in sym2id.items(): + f.write(f"{sym} {i}\n") + + +def get_tokens(lexicon: Lexicon) -> List[str]: + """Get tokens from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique tokens. + """ + ans = set() + for _, tokens in lexicon: + ans.update(tokens) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def get_words(lexicon: Lexicon) -> List[str]: + """Get words from a lexicon. + + Args: + lexicon: + It is the return value of :func:`read_lexicon`. + Returns: + Return a list of unique words. + """ + ans = set() + for word, _ in lexicon: + ans.add(word) + sorted_ans = sorted(list(ans)) + return sorted_ans + + +def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: + """It adds pseudo-token disambiguation symbols #1, #2 and so on + at the ends of tokens to ensure that all pronunciations are different, + and that none is a prefix of another. + + See also add_lex_disambig.pl from kaldi. + + Args: + lexicon: + It is returned by :func:`read_lexicon`. + Returns: + Return a tuple with two elements: + + - The output lexicon with disambiguation symbols + - The ID of the max disambiguation symbol that appears + in the lexicon + """ + + # (1) Work out the count of each token-sequence in the + # lexicon. + count = defaultdict(int) + for _, tokens in lexicon: + count[" ".join(tokens)] += 1 + + # (2) For each left sub-sequence of each token-sequence, note down + # that it exists (for identifying prefixes of longer strings). + issubseq = defaultdict(int) + for _, tokens in lexicon: + tokens = tokens.copy() + tokens.pop() + while tokens: + issubseq[" ".join(tokens)] = 1 + tokens.pop() + + # (3) For each entry in the lexicon: + # if the token sequence is unique and is not a + # prefix of another word, no disambig symbol. + # Else output #1, or #2, #3, ... if the same token-seq + # has already been assigned a disambig symbol. + ans = [] + + # We start with #1 since #0 has its own purpose + first_allowed_disambig = 1 + max_disambig = first_allowed_disambig - 1 + last_used_disambig_symbol_of = defaultdict(int) + + for word, tokens in lexicon: + tokenseq = " ".join(tokens) + assert tokenseq != "" + if issubseq[tokenseq] == 0 and count[tokenseq] == 1: + ans.append((word, tokens)) + continue + + cur_disambig = last_used_disambig_symbol_of[tokenseq] + if cur_disambig == 0: + cur_disambig = first_allowed_disambig + else: + cur_disambig += 1 + + if cur_disambig > max_disambig: + max_disambig = cur_disambig + last_used_disambig_symbol_of[tokenseq] = cur_disambig + tokenseq += f" #{cur_disambig}" + ans.append((word, tokenseq.split())) + return ans, max_disambig + + +def generate_id_map(symbols: List[str]) -> Dict[str, int]: + """Generate ID maps, i.e., map a symbol to a unique ID. + + Args: + symbols: + A list of unique symbols. + Returns: + A dict containing the mapping between symbols and IDs. + """ + return {sym: i for i, sym in enumerate(symbols)} + + +def add_self_loops( + arcs: List[List[Any]], disambig_token: int, disambig_word: int +) -> List[List[Any]]: + """Adds self-loops to states of an FST to propagate disambiguation symbols + through it. They are added on each state with non-epsilon output symbols + on at least one arc out of the state. + + See also fstaddselfloops.pl from Kaldi. One difference is that + Kaldi uses OpenFst style FSTs and it has multiple final states. + This function uses k2 style FSTs and it does not need to add self-loops + to the final state. + + The input label of a self-loop is `disambig_token`, while the output + label is `disambig_word`. + + Args: + arcs: + A list-of-list. The sublist contains + `[src_state, dest_state, label, aux_label, score]` + disambig_token: + It is the token ID of the symbol `#0`. + disambig_word: + It is the word ID of the symbol `#0`. + + Return: + Return new `arcs` containing self-loops. + """ + states_needs_self_loops = set() + for arc in arcs: + src, dst, ilabel, olabel, score = arc + if olabel != 0: + states_needs_self_loops.add(src) + + ans = [] + for s in states_needs_self_loops: + ans.append([s, s, disambig_token, disambig_word, 0]) + + return arcs + ans + + +def lexicon_to_fst( + lexicon: Lexicon, + token2id: Dict[str, int], + word2id: Dict[str, int], + sil_token: str = "SIL", + sil_prob: float = 0.5, + need_self_loops: bool = False, +) -> k2.Fsa: + """Convert a lexicon to an FST (in k2 format) with optional silence at + the beginning and end of each word. + + Args: + lexicon: + The input lexicon. See also :func:`read_lexicon` + token2id: + A dict mapping tokens to IDs. + word2id: + A dict mapping words to IDs. + sil_token: + The silence token. + sil_prob: + The probability for adding a silence at the beginning and end + of the word. + need_self_loops: + If True, add self-loop to states with non-epsilon output symbols + on at least one arc out of the state. The input label for this + self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. + Returns: + Return an instance of `k2.Fsa` representing the given lexicon. + """ + assert sil_prob > 0.0 and sil_prob < 1.0 + # CAUTION: we use score, i.e, negative cost. + sil_score = math.log(sil_prob) + no_sil_score = math.log(1.0 - sil_prob) + + start_state = 0 + loop_state = 1 # words enter and leave from here + sil_state = 2 # words terminate here when followed by silence; this state + # has a silence transition to loop_state. + next_state = 3 # the next un-allocated state, will be incremented as we go. + arcs = [] + + assert token2id[""] == 0 + assert word2id[""] == 0 + + eps = 0 + + sil_token = token2id[sil_token] + + arcs.append([start_state, loop_state, eps, eps, no_sil_score]) + arcs.append([start_state, sil_state, eps, eps, sil_score]) + arcs.append([sil_state, loop_state, sil_token, eps, 0]) + + for word, tokens in lexicon: + assert len(tokens) > 0, f"{word} has no pronunciations" + cur_state = loop_state + + word = word2id[word] + tokens = [token2id[i] for i in tokens] + + for i in range(len(tokens) - 1): + w = word if i == 0 else eps + arcs.append([cur_state, next_state, tokens[i], w, 0]) + + cur_state = next_state + next_state += 1 + + # now for the last token of this word + # It has two out-going arcs, one to the loop state, + # the other one to the sil_state. + i = len(tokens) - 1 + w = word if i == 0 else eps + arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) + arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) + + if need_self_loops: + disambig_token = token2id["#0"] + disambig_word = word2id["#0"] + arcs = add_self_loops( + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, + ) + + final_state = next_state + arcs.append([loop_state, final_state, -1, -1, 0]) + arcs.append([final_state]) + + arcs = sorted(arcs, key=lambda arc: arc[0]) + arcs = [[str(i) for i in arc] for arc in arcs] + arcs = [" ".join(arc) for arc in arcs] + arcs = "\n".join(arcs) + + fsa = k2.Fsa.from_str(arcs, acceptor=False) + return fsa + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + vocab_filename = lang_dir / "words.txt" + lexicon_filename = lang_dir / "lexicon.txt" + sil_token = "SIL" + sil_prob = 0.5 + special_symbols = ["[UNK]", "[BREATH]", "[COUGH]", "[LAUGHTER]", "[LIPSMACK]", "[NOISE]", "[SIGH]", "[SNEEZE]", "[VOCALIZED-NOISE]"] + + g2p = G2p() + token2id = get_g2p_sym2int() + + vocab = sorted( + [ + l.split()[0] + for l in vocab_filename.read_text().splitlines() + if l.strip() and not l.startswith(("!", "[", "<", "#")) + ] + ) + print("First ten words from the vocabulary:") + print(vocab[:10]) + + if not lexicon_filename.is_file(): + lexicon = [ + ("!SIL", [sil_token]), + ] + for symbol in special_symbols: + lexicon.append((symbol, [symbol[1:-1]])) + lexicon += [ + ( + word, + [ + phn for phn in g2p(word) + if phn not in ("'", " ", "-", ",") # g2p_en has these symbols as phones + ], + ) + for word in tqdm(vocab, desc="Processing vocab with G2P") + ] + lexicon = [entry for entry in lexicon if entry[1]] # filter empty prons + print(lexicon[:10]) + + write_lexicon(lexicon_filename, lexicon) + else: + lexicon = read_lexicon(lexicon_filename) + + tokens = get_tokens(lexicon) + + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + + for i in range(max_disambig + 1): + disambig = f"#{i}" + assert disambig not in tokens + tokens.append(disambig) + token2id[disambig] = max(token2id.values()) + 1 + + print("Tokens in the lexicon:") + print(tokens) + + # sort by ID + token2id = dict(sorted(token2id.items(), key=lambda tpl: tpl[1])) + print(token2id) + word2id = {"": 0} + word2id.update({ + word: int(id_) for id_, (word, pron) in enumerate(lexicon, start=1) + }) + for symbol in ["", "", "#0"]: + word2id[symbol] = len(word2id) + + write_mapping(lang_dir / "tokens.txt", token2id) + write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) + + L = lexicon_to_fst( + lexicon, + token2id=token2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + ) + + L_disambig = lexicon_to_fst( + lexicon_disambig, + token2id=token2id, + word2id=word2id, + sil_token=sil_token, + sil_prob=sil_prob, + need_self_loops=True, + ) + torch.save(L.as_dict(), lang_dir / "L.pt") + torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") + + if args.debug: + labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") + aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") + + L.labels_sym = labels_sym + L.aux_labels_sym = aux_labels_sym + L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") + + L_disambig.labels_sym = labels_sym + L_disambig.aux_labels_sym = aux_labels_sym + L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") + + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/local/test_prepare_lang.py b/egs/fisher_swbd/ASR/local/test_prepare_lang.py new file mode 100755 index 000000000..d4cf62bba --- /dev/null +++ b/egs/fisher_swbd/ASR/local/test_prepare_lang.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) + +import os +import tempfile + +import k2 +from prepare_lang import ( + add_disambig_symbols, + generate_id_map, + get_phones, + get_words, + lexicon_to_fst, + read_lexicon, + write_lexicon, + write_mapping, +) + + +def generate_lexicon_file() -> str: + fd, filename = tempfile.mkstemp() + os.close(fd) + s = """ + !SIL SIL + SPN + SPN + f f + a a + foo f o o + bar b a r + bark b a r k + food f o o d + food2 f o o d + fo f o + """.strip() + with open(filename, "w") as f: + f.write(s) + return filename + + +def test_read_lexicon(filename: str): + lexicon = read_lexicon(filename) + phones = get_phones(lexicon) + words = get_words(lexicon) + print(lexicon) + print(phones) + print(words) + lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) + print(lexicon_disambig) + print("max disambig:", f"#{max_disambig}") + + phones = ["", "SIL", "SPN"] + phones + for i in range(max_disambig + 1): + phones.append(f"#{i}") + words = [""] + words + + phone2id = generate_id_map(phones) + word2id = generate_id_map(words) + + print(phone2id) + print(word2id) + + write_mapping("phones.txt", phone2id) + write_mapping("words.txt", word2id) + + write_lexicon("a.txt", lexicon) + write_lexicon("a_disambig.txt", lexicon_disambig) + + fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id) + fsa.labels_sym = k2.SymbolTable.from_file("phones.txt") + fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") + fsa.draw("L.pdf", title="L") + + fsa_disambig = lexicon_to_fst( + lexicon_disambig, phone2id=phone2id, word2id=word2id + ) + fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") + fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") + fsa_disambig.draw("L_disambig.pdf", title="L_disambig") + + +def main(): + filename = generate_lexicon_file() + test_read_lexicon(filename) + os.remove(filename) + + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/local/train_bpe_model.py b/egs/fisher_swbd/ASR/local/train_bpe_model.py new file mode 100755 index 000000000..bc5812810 --- /dev/null +++ b/egs/fisher_swbd/ASR/local/train_bpe_model.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +# You can install sentencepiece via: +# +# pip install sentencepiece +# +# Due to an issue reported in +# https://github.com/google/sentencepiece/pull/642#issuecomment-857972030 +# +# Please install a version >=0.1.96 + +import argparse +import shutil +from pathlib import Path + +import sentencepiece as spm + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + It should contain the training corpus: transcript_words.txt. + The generated bpe.model is saved to this directory. + """, + ) + + parser.add_argument( + "--transcript", + type=str, + help="Training transcript.", + ) + + parser.add_argument( + "--vocab-size", + type=int, + help="Vocabulary size for BPE training", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + vocab_size = args.vocab_size + lang_dir = Path(args.lang_dir) + + model_type = "unigram" + + model_prefix = f"{lang_dir}/{model_type}_{vocab_size}" + train_text = args.transcript + character_coverage = 1.0 + input_sentence_size = 100000000 + + user_defined_symbols = ["", ""] + unk_id = len(user_defined_symbols) + # Note: unk_id is fixed to 2. + # If you change it, you should also change other + # places that are using it. + + model_file = Path(model_prefix + ".model") + if not model_file.is_file(): + spm.SentencePieceTrainer.train( + input=train_text, + vocab_size=vocab_size, + model_type=model_type, + model_prefix=model_prefix, + input_sentence_size=input_sentence_size, + character_coverage=character_coverage, + user_defined_symbols=user_defined_symbols, + unk_id=unk_id, + bos_id=-1, + eos_id=-1, + ) + + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") + + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/prepare.sh b/egs/fisher_swbd/ASR/prepare.sh new file mode 100755 index 000000000..bf61b1be9 --- /dev/null +++ b/egs/fisher_swbd/ASR/prepare.sh @@ -0,0 +1,244 @@ +#!/usr/bin/env bash + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/LibriSpeech +# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it. +# You can download them from https://www.openslr.org/12 +# +# - $dl_dir/lm +# This directory contains the following files downloaded from +# http://www.openslr.org/resources/11 +# +# - 3-gram.pruned.1e-7.arpa.gz +# - 3-gram.pruned.1e-7.arpa +# - 4-gram.arpa.gz +# - 4-gram.arpa +# - librispeech-vocab.txt +# - librispeech-lexicon.txt +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech +dl_dir=$PWD/download +mkdir -p $dl_dir + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + 500 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download LM" + #[ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm + #./local/download_lm.py --out-dir=$dl_dir/lm +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/fisher and /path/to/swbd, + # you can create a symlink + # + # ln -sfv /path/to/fisher $dl_dir/fisher + # + + # TODO: remove + LDC_ROOT=/fsx/resources/LDC + for pkg in LDC2004S13 LDC2004T19 LDC2005S13 LDC2005T19 LDC97S62; do + ln -sfv $LDC_ROOT/$pkg $dl_dir/ + done + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare Fisher manifests" + # We assume that you have downloaded the LibriSpeech corpus + # to $dl_dir/LibriSpeech + mkdir -p data/manifests/fisher + lhotse prepare fisher-english $dl_dir data/manifests/fisher +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare SWBD manifests" + # We assume that you have downloaded the LibriSpeech corpus + # to $dl_dir/LibriSpeech + mkdir -p data/manifests/swbd + lhotse prepare switchboard --omit-silence $dl_dir/LDC97S62 data/manifests/swbd +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + mkdir -p data/manifests + lhotse prepare musan $dl_dir/musan data/manifests +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Combine Fisher + SBWD manifests" + + set -x + + lhotse combine \ + data/manifests/fisher/recordings.jsonl.gz \ + data/manifests/swbd/swbd_recordings.jsonl \ + data/manifests/fisher-swbd_recordings.jsonl.gz + + lhotse combine \ + data/manifests/fisher/supervisions.jsonl.gz \ + data/manifests/swbd/swbd_supervisions.jsonl \ + data/manifests/fisher-swbd_supervisions.jsonl.gz + + python local/normalize_and_filter_supervisions.py \ + data/manifests/fisher-swbd_supervisions.jsonl.gz \ + data/manifests/fisher-swbd_supervisions_norm.jsonl.gz \ + + lhotse cut simple \ + -r data/manifests/fisher-swbd_recordings.jsonl.gz \ + -s data/manifests/fisher-swbd_supervisions_norm.jsonl.gz \ + data/manifests/fisher-swbd_cuts_unshuf.jsonl.gz + + gunzip -c data/manifests/fisher-swbd_cuts_unshuf.jsonl.gz \ + | shuf \ + | gzip -c \ + > data/manifests/fisher-swbd_cuts.jsonl.gz + + set +x +fi + +# TODO: optional stage 5, compute features + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Dump transcripts for LM training" + mkdir -p data/lm + gunzip -c data/manifests/fisher-swbd_supervisions_norm.jsonl.gz \ + | jq '.text' \ + | sed 's:"::g' \ + > data/lm/transcript_words.txt +fi + +#if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then +# log "Stage 3: Compute fbank for librispeech" +# mkdir -p data/fbank +# ./local/compute_fbank_librispeech.py +#fi +# +#if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then +# log "Stage 4: Compute fbank for musan" +# mkdir -p data/fbank +# ./local/compute_fbank_musan.py +#fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare lexicon using g2p_en" + lang_dir=data/lang_phone + mkdir -p $lang_dir + + # Add special words to words.txt + echo " 0" > $lang_dir/words.txt + echo "!SIL 1" >> $lang_dir/words.txt + echo " 2" >> $lang_dir/words.txt + + # Add regular words to words.txt + gunzip -c data/manifests/fisher-swbd_supervisions_norm.jsonl.gz \ + | jq '.text' \ + | sed 's:"::g' \ + | sed 's: :\n:g' \ + | sort \ + | uniq \ + | awk '{print $0,NR+2}' \ + >> $lang_dir/words.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + pip install g2p_en + ./local/prepare_lang_g2pen.py --lang-dir $lang_dir + fi +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript data/lm/transcript_words.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + fi + done +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Train LM" + lm_dir=data/lm + + if [ ! -f $lm_dir/G.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text $lm_dir/transcript_words.txt \ + -lm $lm_dir/G.arpa + fi + + if [ ! -f $lm_dir/G_3_gram.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + $lm_dir/G.arpa > $lm_dir/G_3_gram.fst.txt + fi +fi + +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + done +fi diff --git a/egs/fisher_swbd/ASR/shared b/egs/fisher_swbd/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/fisher_swbd/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/README.md b/egs/fisher_swbd/ASR/streaming_conformer_ctc/README.md new file mode 100644 index 000000000..01be7090b --- /dev/null +++ b/egs/fisher_swbd/ASR/streaming_conformer_ctc/README.md @@ -0,0 +1,90 @@ +## Train and Decode +Commands of data preparation/train/decode steps are almost the same with +../conformer_ctc experiment except some options. + +Please read the code and understand following new added options before running this experiment: + + For data preparation: + + Nothing new. + + For streaming_conformer_ctc/train.py: + + --dynamic-chunk-training + --short-chunk-proportion + + For streaming_conformer_ctc/streaming_decode.py: + + --chunk-size + --tailing-num-frames + --simulate-streaming + +## Performence and a trained model. + +The latest results with this streaming code is shown in following table: + +chunk size | wer on test-clean | wer on test-other +-- | -- | -- +full | 3.53 | 8.52 +40(1.96s) | 3.78 | 9.38 +32(1.28s) | 3.82 | 9.44 +24(0.96s) | 3.95 | 9.76 +16(0.64s) | 4.06 | 9.98 +8(0.32s) | 4.30 | 10.55 +4(0.16s) | 5.88 | 12.01 + + +A trained model is also provided. +By run +``` +git clone https://huggingface.co/GuoLiyong/streaming_conformer + +# You may want to manually check md5sum values of downloaded files +# 8e633bc1de37f5ae57a2694ceee32a93 trained_streaming_conformer.pt +# 4c0aeefe26c784ec64873cc9b95420f1 L.pt +# d1f91d81005fb8ce4d65953a4a984ee7 Linv.pt +# e1c1902feb7b9fc69cd8d26e663c2608 bpe.model +# 8617e67159b0ff9118baa54f04db24cc tokens.txt +# 72b075ab5e851005cd854e666c82c3bb words.txt +``` + +If there is any different md5sum values, please run +``` +cd streaming_models +git lfs pull +``` +And check md5sum values again. + +Finally, following files will be downloaded: +
+streaming_models/  
+|-- lang_bpe  
+|   |-- L.pt  
+|   |-- Linv.pt  
+|   |-- bpe.model
+|   |-- tokens.txt
+|   `-- words.txt
+`-- trained_streaming_conformer.pt
+
+ + +And run commands you will get the same results of previous table: +``` +trained_models=/path/to/downloaded/streaming_models/ +for chunk_size in 4 8 16 24 36 40 -1; do + ./streaming_conformer_ctc/streaming_decode.py \ + --chunk-size=${chunk_size} \ + --trained-dir=${trained_models} +done +``` +Results of following command is indentical to previous one, +but model consumes features chunk_by_chunk, i.e. a streaming way. +``` +trained_models=/path/to/downloaded/streaming_models/ +for chunk_size in 4 8 16 24 36 40 -1; do + ./streaming_conformer_ctc/streaming_decode.py \ + --simulate-streaming=True \ + --chunk-size=${chunk_size} \ + --trained-dir=${trained_models} +done +``` diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/asr_datamodule.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/asr_datamodule.py new file mode 120000 index 000000000..a73848de9 --- /dev/null +++ b/egs/fisher_swbd/ASR/streaming_conformer_ctc/asr_datamodule.py @@ -0,0 +1 @@ +../conformer_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/conformer.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/conformer.py new file mode 100644 index 000000000..33f8e2e15 --- /dev/null +++ b/egs/fisher_swbd/ASR/streaming_conformer_ctc/conformer.py @@ -0,0 +1,1355 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +import warnings +from typing import Optional, Tuple + +import torch +from torch import Tensor, nn +from transformer import Supervisions, Transformer, encoder_padding_mask + + +# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py#L42 +def subsequent_chunk_mask( + size: int, + chunk_size: int, + num_left_chunks: int = -1, + device: torch.device = torch.device("cpu"), +) -> torch.Tensor: + """Create mask for subsequent steps (size, size) with chunk size, + this is for streaming encoder + Args: + size (int): size of mask + chunk_size (int): size of chunk + num_left_chunks (int): number of left chunks + <0: use full chunk + >=0: use num_left_chunks + device (torch.device): "cpu" or "cuda" or torch.Tensor.device + Returns: + torch.Tensor: mask + Examples: + >>> subsequent_chunk_mask(4, 2) + [[1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]] + """ + ret = torch.zeros(size, size, device=device, dtype=torch.bool) + for i in range(size): + if num_left_chunks < 0: + start = 0 + else: + start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) + ending = min((i // chunk_size + 1) * chunk_size, size) + ret[i, start:ending] = True + return ret + + +class Conformer(Transformer): + """ + Args: + num_features (int): Number of input features + num_classes (int): Number of output classes + subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) + d_model (int): attention dimension + nhead (int): number of head + dim_feedforward (int): feedforward dimention + num_encoder_layers (int): number of encoder layers + num_decoder_layers (int): number of decoder layers + dropout (float): dropout rate + cnn_module_kernel (int): Kernel size of convolution module + normalize_before (bool): whether to use layer_norm before the first block. + vgg_frontend (bool): whether to use vgg frontend. + """ + + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + vgg_frontend: bool = False, + use_feat_batchnorm: bool = False, + causal: bool = False, + ) -> None: + super(Conformer, self).__init__( + num_features=num_features, + num_classes=num_classes, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dropout=dropout, + normalize_before=normalize_before, + vgg_frontend=vgg_frontend, + use_feat_batchnorm=use_feat_batchnorm, + ) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = ConformerEncoderLayer( + d_model, + nhead, + dim_feedforward, + dropout, + cnn_module_kernel, + normalize_before, + causal, + ) + self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) + self.normalize_before = normalize_before + if self.normalize_before: + self.after_norm = nn.LayerNorm(d_model) + else: + # Note: TorchScript detects that self.after_norm could be used inside forward() + # and throws an error without this change. + self.after_norm = identity + + def run_encoder( + self, + x: Tensor, + supervisions: Optional[Supervisions] = None, + dynamic_chunk_training: bool = False, + short_chunk_proportion: float = 0.5, + chunk_size: int = -1, + simulate_streaming: bool = False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + dynamic_chunk_training: + For training only. + IF True, train with dynamic right context for some batches + sampled with a distribution + if False, train with full right context all the time. + short_chunk_proportion: + For training only. + Proportion of samples that will be trained with dynamic chunk. + chunk_size: + For eval only. + right context when evaluating test utts. + -1 means all right context. + simulate_streaming=False, + For eval only. + If true, the feature will be feeded into the model chunk by chunk. + If false, the whole utts if feeded into the model together i.e. the + model only foward once. + + + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + if self.encoder.training: + return self.train_run_encoder( + x, supervisions, dynamic_chunk_training, short_chunk_proportion + ) + else: + return self.eval_run_encoder( + x, supervisions, chunk_size, simulate_streaming + ) + + def train_run_encoder( + self, + x: Tensor, + supervisions: Optional[Supervisions] = None, + dynamic_chunk_training: bool = False, + short_chunk_threshold: float = 0.5, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + dynamic_chunk_training: + IF True, train with dynamic right context for some batches + sampled with a distribution + if False, train with full right context all the time. + short_chunk_proportion: + Proportion of samples that will be trained with dynamic chunk. + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + src_key_padding_mask = encoder_padding_mask(x.size(0), supervisions) + if src_key_padding_mask is not None: + src_key_padding_mask = src_key_padding_mask.to(x.device) + + if dynamic_chunk_training: + max_len = x.size(0) + chunk_size = torch.randint(1, max_len, (1,)).item() + if chunk_size > (max_len * short_chunk_threshold): + chunk_size = max_len + else: + chunk_size = chunk_size % 25 + 1 + mask = ~subsequent_chunk_mask( + size=x.size(0), chunk_size=chunk_size, device=x.device + ) + x = self.encoder( + x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask + ) # (T, B, F) + else: + x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) + + if self.normalize_before: + x = self.after_norm(x) + + return x, src_key_padding_mask + + def eval_run_encoder( + self, + feature: Tensor, + supervisions: Optional[Supervisions] = None, + chunk_size: int = -1, + simulate_streaming=False, + ) -> Tuple[Tensor, Optional[Tensor]]: + """ + Args: + feature: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + + Returns: + Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). + Tensor: Mask tensor of dimension (batch_size, input_length) + """ + # feature.shape: N T C + num_frames = feature.size(1) + + # As temporarily in icefall only subsampling_rate == 4 is supported, + # following parameters are hard-coded here. + # Change it accordingly if other subsamling_rate are supported. + embed_left_context = 7 + embed_conv_right_context = 3 + subsampling_rate = 4 + stride = chunk_size * subsampling_rate + decoding_window = embed_conv_right_context + stride + + # This is also only compatible to sumsampling_rate == 4 + length_after_subsampling = ((feature.size(1) - 1) // 2 - 1) // 2 + src_key_padding_mask = encoder_padding_mask( + length_after_subsampling, supervisions + ) + if src_key_padding_mask is not None: + src_key_padding_mask = src_key_padding_mask.to(feature.device) + + if chunk_size < 0: + # non-streaming decoding + x = self.encoder_embed(feature) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + x = self.encoder( + x, pos_emb, src_key_padding_mask=src_key_padding_mask + ) # (T, B, F) + else: + if simulate_streaming: + # simulate chunk_by_chunk streaming decoding + # Results of this branch should be identical to following + # "else" branch. + # But this branch is a little slower + # as the feature is feeded chunk by chunk + + # store the result of chunk_by_chunk decoding + encoder_output = [] + + # caches + pos_emb_positive = [] + pos_emb_negative = [] + pos_emb_central = None + encoder_cache = [None for i in range(len(self.encoder.layers))] + conv_cache = [None for i in range(len(self.encoder.layers))] + + # start chunk_by_chunk decoding + offset = 0 + for cur in range( + 0, num_frames - embed_left_context + 1, stride + ): + end = min(cur + decoding_window, num_frames) + cur_feature = feature[:, cur:end, :] + cur_feature = self.encoder_embed(cur_feature) + cur_embed, cur_pos_emb = self.encoder_pos( + cur_feature, offset + ) + cur_embed = cur_embed.permute( + 1, 0, 2 + ) # (B, T, F) -> (T, B, F) + + cur_T = cur_feature.size(1) + if cur == 0: + # for first chunk extract the central pos embedding + pos_emb_central = cur_pos_emb[ + 0, (chunk_size - 1), : + ].view(1, 1, -1) + cur_T -= 1 + pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0)) + pos_emb_negative.append(cur_pos_emb[0, -cur_T:]) + assert pos_emb_positive[-1].size(0) == cur_T + + pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze( + 0 + ) + pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze( + 0 + ) + cur_pos_emb = torch.cat( + [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg], + dim=1, + ) + + x = self.encoder.chunk_forward( + cur_embed, + cur_pos_emb, + src_key_padding_mask=src_key_padding_mask[ + :, : offset + cur_embed.size(0) + ], + encoder_cache=encoder_cache, + conv_cache=conv_cache, + offset=offset, + ) # (T, B, F) + encoder_output.append(x) + offset += cur_embed.size(0) + + x = torch.cat(encoder_output, dim=0) + else: + # NOT simulate chunk_by_chunk decoding + # Results of this branch should be identical to previous + # simulate chunk_by_chunk decoding branch. + # But this branch is faster. + x = self.encoder_embed(feature) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + mask = ~subsequent_chunk_mask( + size=x.size(0), chunk_size=chunk_size, device=x.device + ) + x = self.encoder( + x, + pos_emb, + mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) # (T, B, F) + + if self.normalize_before: + x = self.after_norm(x) + + return x, src_key_padding_mask + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Args: + d_model: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + dim_feedforward: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + cnn_module_kernel (int): Kernel size of convolution module. + normalize_before: whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + normalize_before: bool = True, + causal: bool = False, + ) -> None: + super(ConformerEncoderLayer, self).__init__() + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) + + self.feed_forward = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.feed_forward_macaron = nn.Sequential( + nn.Linear(d_model, dim_feedforward), + Swish(), + nn.Dropout(dropout), + nn.Linear(dim_feedforward, d_model), + ) + + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) + + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module + self.norm_ff = nn.LayerNorm(d_model) # for the FNN module + self.norm_mha = nn.LayerNorm(d_model) # for the MHA module + + self.ff_scale = 0.5 + + self.norm_conv = nn.LayerNorm(d_model) # for the CNN module + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block + + self.dropout = nn.Dropout(dropout) + + self.normalize_before = normalize_before + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) + + # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) + + # convolution module + residual = src + if self.normalize_before: + src = self.norm_conv(src) + src = residual + self.dropout(self.conv_module(src)) + if not self.normalize_before: + src = self.norm_conv(src) + + # feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) + + if self.normalize_before: + src = self.norm_final(src) + + return src + + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + encoder_cache: Optional[Tensor] = None, + conv_cache: Optional[Tensor] = None, + offset=0, + ) -> Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + pos_emb: Positional embedding tensor (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, N is the batch size, E is the feature number + """ + + # macaron style feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff_macaron(src) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) + if not self.normalize_before: + src = self.norm_ff_macaron(src) + + # multi-headed self-attention module + residual = src + if self.normalize_before: + src = self.norm_mha(src) + if encoder_cache is None: + # src: [chunk_size, N, F] e.g. [8, 41, 512] + key = src + val = key + encoder_cache = key + else: + key = torch.cat([encoder_cache, src], dim=0) + val = key + encoder_cache = key + src_att = self.self_attn( + src, + key, + val, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + offset=offset, + )[0] + src = residual + self.dropout(src_att) + if not self.normalize_before: + src = self.norm_mha(src) + + # convolution module + residual = src # [chunk_size, N, F] e.g. [8, 41, 512] + if self.normalize_before: + src = self.norm_conv(src) + if conv_cache is not None: + src = torch.cat([conv_cache, src], dim=0) + conv_cache = src + + src = self.conv_module(src) + src = src[-residual.size(0) :, :, :] # noqa: E203 + + src = residual + self.dropout(src) + if not self.normalize_before: + src = self.norm_conv(src) + + # feed forward module + residual = src + if self.normalize_before: + src = self.norm_ff(src) + src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) + if not self.normalize_before: + src = self.norm_ff(src) + + if self.normalize_before: + src = self.norm_final(src) + + return src, encoder_cache, conv_cache + + +class ConformerEncoder(nn.TransformerEncoder): + r"""ConformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the ConformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + + Examples:: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__( + self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None + ) -> None: + super(ConformerEncoder, self).__init__( + encoder_layer=encoder_layer, num_layers=num_layers, norm=norm + ) + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for mod in self.layers: + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + def chunk_forward( + self, + src: Tensor, + pos_emb: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + encoder_cache=None, + conv_cache=None, + offset=0, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + pos_emb: Positional embedding tensor (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + src: (S, N, E). + pos_emb: (N, 2*S-1, E) + mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + + """ + output = src + + for layer_index, mod in enumerate(self.layers): + output, e_cache, c_cache = mod.chunk_forward( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + encoder_cache=encoder_cache[layer_index], + conv_cache=conv_cache[layer_index], + offset=offset, + ) + encoder_cache[layer_index] = e_cache + conv_cache[layer_index] = c_cache + + if self.norm is not None: + output = self.norm(output) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """Relative positional encoding module. + + See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: + """Construct an PositionalEncoding object.""" + super(RelPositionalEncoding, self).__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: Tensor, offset: int = 0) -> None: + """Reset the positional encodings.""" + x_size_1 = offset + x.size(1) + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x_size_1 * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: + """Add positional encoding. + + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). + + """ + self.extend_pe(x, offset) + x = x * self.xscale + x_size_1 = offset + x.size(1) + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x_size_1 + + 1 : self.pe.size(1) // 2 # noqa E203 + + x_size_1, + ] + x_T = x.size(1) + if offset > 0: + pos_emb = torch.cat([pos_emb[:, :x_T], pos_emb[:, -x_T:]], dim=1) + else: + pos_emb = torch.cat( + [ + pos_emb[:, : (x_T - 1)], + self.pe[0, self.pe.size(1) // 2].view( + 1, 1, self.pe.size(-1) + ), + pos_emb[:, -(x_T - 1) :], # noqa: E203 + ], + dim=1, + ) + + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + r"""Multi-Head Attention layer with relative position encoding + + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + + Args: + embed_dim: total dimension of the model. + num_heads: parallel attention heads. + dropout: a Dropout layer on attn_output_weights. Default: 0.0. + + Examples:: + + >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + super(RelPositionMultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + # linear transformation for positional encoding. + self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + + self._reset_parameters() + + def _reset_parameters(self) -> None: + nn.init.xavier_uniform_(self.in_proj.weight) + nn.init.constant_(self.in_proj.bias, 0.0) + nn.init.constant_(self.out_proj.bias, 0.0) + + nn.init.xavier_uniform_(self.pos_bias_u) + nn.init.xavier_uniform_(self.pos_bias_v) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + offset=0, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask and a value is True, + the corresponding value on the attention layer will be ignored. When given + a byte mask and a value is non-zero, the corresponding value on the attention + layer will be ignored + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.weight, + self.in_proj.bias, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + offset=offset, + ) + + def rel_shift(self, x: Tensor, offset=0) -> Tensor: + """Compute relative positional encoding. + + Args: + x: Input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 == time1 + offset, since it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + time2 = time1 + offset + assert n == 2 * time2 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + + return x.as_strided( + (batch_size, num_heads, time1, time2), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_emb: Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: Tensor, + in_proj_bias: Tensor, + dropout_p: float, + out_proj_weight: Tensor, + out_proj_bias: Tensor, + training: bool = True, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + offset=0, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. When the value is True, + the corresponding value on the attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all + the batches while a 3D mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." + ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) + else: + raise RuntimeError( + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + + # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) + p = p.permute(0, 2, 3, 1) + + q_with_bias_u = (q + self.pos_bias_u).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self.pos_bias_v).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift( + matrix_bd, offset=offset + ) # [B, head, time1, time2] + attn_output_weights = ( + matrix_ac + matrix_bd + ) * scaling # (batch, head, time1, time2) + + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) + + assert list(attn_output_weights.size()) == [ + bsz * num_heads, + tgt_len, + src_len, + ] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) + + if need_weights: + # average attention weights over heads + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + """ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + + Args: + channels (int): The number of channels of conv layers. + kernel_size (int): Kernerl size of conv layers. + bias (bool): Whether to use bias in conv layers (default=True). + + """ + + def __init__( + self, + channels: int, + kernel_size: int, + bias: bool = True, + causal: bool = False, + ) -> None: + """Construct an ConvolutionModule object.""" + super(ConvolutionModule, self).__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = nn.Conv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # from https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/convolution.py#L41 + if causal: + self.lorder = kernel_size - 1 + padding = 0 # manualy padding self.lorder zeros to the left later + else: + assert (kernel_size - 1) % 2 == 0 + self.lorder = 0 + padding = (kernel_size - 1) // 2 + self.depthwise_conv = nn.Conv1d( + channels, + channels, + kernel_size, + stride=1, + padding=padding, + groups=channels, + bias=bias, + ) + self.norm = nn.BatchNorm1d(channels) + self.pointwise_conv2 = nn.Conv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + self.activation = Swish() + + def forward(self, x: Tensor) -> Tensor: + """Compute convolution module. + + Args: + x: Input tensor (#time, batch, channels). + + Returns: + Tensor: Output tensor (#time, batch, channels). + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). + + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + if self.lorder > 0: + # manualy padding self.lorder zeros to the left + # make depthwise_conv causal + x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) + x = self.depthwise_conv(x) + x = self.activation(self.norm(x)) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) + + +class Swish(torch.nn.Module): + """Construct an Swish object.""" + + def forward(self, x: Tensor) -> Tensor: + """Return Swich activation function.""" + return x * torch.sigmoid(x) + + +def identity(x): + return x diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/label_smoothing.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/label_smoothing.py new file mode 120000 index 000000000..08734abd7 --- /dev/null +++ b/egs/fisher_swbd/ASR/streaming_conformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/streaming_decode.py new file mode 100755 index 000000000..a74c51836 --- /dev/null +++ b/egs/fisher_swbd/ASR/streaming_conformer_ctc/streaming_decode.py @@ -0,0 +1,521 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py#L166 +def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: + new_hyp: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=34, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--chunk-size", + type=int, + default=8, + help="Frames of right context" + "-1 for whole right context, i.e. non-streaming decoding", + ) + + parser.add_argument( + "--tailing-num-frames", + type=int, + default=20, + help="tailing dummy frames padded to the right," + "only used during decoding", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="simulate chunk by chunk decoding", + ) + parser.add_argument( + "--method", + type=str, + default="ctc-greedy-search", + help="Streaming Decoding method", + ) + + parser.add_argument( + "--export", + type=str2bool, + default=False, + help="""When enabled, the averaged model is saved to + conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved. + pretrained.pt contains a dict {"model": model.state_dict()}, + which can be loaded by `icefall.checkpoint.load_checkpoint()`. + """, + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default="streaming_conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--trained-dir", + type=Path, + default=None, + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe", + help="The lang dir", + ) + + parser.add_argument( + "--avg-models", + type=str, + default=None, + help="Manually select models to average, seperated by comma;" + "e.g. 60,62,63,72", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "exp_dir": Path("conformer_ctc/exp"), + "lang_dir": Path("data/lang_bpe"), + # parameters for conformer + "causal": True, + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + chunk_size: int = -1, + simulate_streaming=False, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + model: + The neural model. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + feature = batch["inputs"] + device = torch.device("cuda") + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + # Extra dummy tailing frames my reduce deletion error + # example WITHOUT padding: + # CHAPTER SEVEN ON THE RACES OF MAN + # example WITH padding: + # CHAPTER SEVEN ON THE RACES OF (MAN->*) + tailing_frames = ( + torch.tensor([-23.0259]) + .expand([feature.size(0), params.tailing_num_frames, 80]) + .to(feature.device) + ) + feature = torch.cat([feature, tailing_frames], dim=1) + supervisions["num_frames"] += params.tailing_num_frames + + nnet_output, memory, memory_key_padding_mask = model( + feature, + supervisions, + chunk_size=chunk_size, + simulate_streaming=simulate_streaming, + ) + + assert params.method == "ctc-greedy-search" + key = "ctc-greedy-search" + batch_size = nnet_output.size(0) + maxlen = nnet_output.size(1) + topk_prob, topk_index = nnet_output.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) + topk_index = topk_index.masked_fill_( + memory_key_padding_mask, 0 + ) # (B, maxlen) + token_ids = [token_id.tolist() for token_id in topk_index] + token_ids = [ + remove_duplicates_and_blank(token_id) for token_id in token_ids + ] + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + return {key: hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + chunk_size: int = -1, + simulate_streaming=False, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + chunk_size: + right context to simulate streaming decoding + -1 for whole right context, i.e. non-stream decoding + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + sos_id=sos_id, + eos_id=eos_id, + chunk_size=chunk_size, + simulate_streaming=simulate_streaming, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + if params.method == "attention-decoder": + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + if params.avg_models is not None: + avg_models = params.avg_models.replace(",", "_") + result_file_prefix = f"epoch-avg-{avg_models}-chunksize \ + -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" + else: + result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \ + -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" + for key, results in results_dict.items(): + recog_path = ( + params.exp_dir + / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.exp_dir + / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + if params.trained_dir is not None: + params.lang_dir = Path(params.trained_dir) / "lang_bpe" + # used naming result files + params.epoch = "trained_model" + params.avg = 1 + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + causal=params.causal, + ) + + if params.trained_dir is not None: + model_name = f"{params.trained_dir}/trained_streaming_conformer.pt" + load_checkpoint(model_name, model) + elif params.avg == 1 and params.avg_models is not None: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + filenames = [] + if params.avg_models is not None: + model_ids = params.avg_models.split(",") + for i in model_ids: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + else: + start = params.epoch - params.avg + 1 + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames)) + + if params.export: + logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) + return + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + # CAUTION: `test_sets` is for displaying only. + # If you want to skip test-clean, you have to skip + # it inside the for loop. That is, use + # + # if test_set == 'test-clean': continue + # + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + test_sets = ["test-clean", "test-other"] + for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + bpe_model=bpe_model, + word_table=lexicon.word_table, + sos_id=sos_id, + eos_id=eos_id, + chunk_size=params.chunk_size, + simulate_streaming=params.simulate_streaming, + ) + + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/subsampling.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/subsampling.py new file mode 120000 index 000000000..6fee09e58 --- /dev/null +++ b/egs/fisher_swbd/ASR/streaming_conformer_ctc/subsampling.py @@ -0,0 +1 @@ +../conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/train.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/train.py new file mode 100755 index 000000000..8b4d6701e --- /dev/null +++ b/egs/fisher_swbd/ASR/streaming_conformer_ctc/train.py @@ -0,0 +1,745 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=78, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=True, + help="Whether to use dynamic right context during training.", + ) + + parser.add_argument( + "--short-chunk-proportion", + type=float, + default=0.7, + help="Proportion of samples trained with short right context", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - use_feat_batchnorm: Whether to do batch normalization for the + input features. + + - attention_dim: Hidden dim for multi-head attention model. + + - head: Number of heads of multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + + - weight_decay: The weight_decay for the optimizer. + + - lr_factor: The lr_factor for Noam optimizer. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "use_feat_batchnorm": True, + "attention_dim": 512, + "nhead": 8, + "num_decoder_layers": 6, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "weight_decay": 1e-6, + "lr_factor": 5.0, + "warm_step": 80000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model( + feature, + supervisions, + dynamic_chunk_training=params.dynamic_chunk_training, + short_chunk_proportion=params.short_chunk_proportion, + ) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + token_ids = graph_compiler.texts_to_ids(texts) + + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + if params.att_rate != 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + att_loss = torch.tensor([0]) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(42) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=False, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = Noam( + model.parameters(), + model_size=params.attention_dim, + factor=params.lr_factor, + warm_step=params.warm_step, + weight_decay=params.weight_decay, + ) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + + librispeech = LibriSpeechAsrDataModule(args) + train_dl = librispeech.train_dataloaders() + valid_dl = librispeech.valid_dataloaders() + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/transformer.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/transformer.py new file mode 100644 index 000000000..bc78e4a41 --- /dev/null +++ b/egs/fisher_swbd/ASR/streaming_conformer_ctc/transformer.py @@ -0,0 +1,966 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn as nn +from label_smoothing import LabelSmoothingLoss +from subsampling import Conv2dSubsampling, VggSubsampling +from torch.nn.utils.rnn import pad_sequence + +# Note: TorchScript requires Dict/List/etc. to be fully typed. +Supervisions = Dict[str, torch.Tensor] + + +class Transformer(nn.Module): + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + normalize_before: bool = True, + vgg_frontend: bool = False, + use_feat_batchnorm: bool = False, + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + num_classes: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + num_decoder_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + normalize_before: + If True, use pre-layer norm; False to use post-layer norm. + vgg_frontend: + True to use vgg style frontend for subsampling. + use_feat_batchnorm: + True to use batchnorm for the input layer. + """ + super().__init__() + self.use_feat_batchnorm = use_feat_batchnorm + if use_feat_batchnorm: + self.feat_batchnorm = nn.BatchNorm1d(num_features) + + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_classes) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + encoder_norm = nn.LayerNorm(d_model) + else: + encoder_norm = None + + self.encoder = nn.TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) + ) + + if num_decoder_layers > 0: + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol + + self.decoder_embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + normalize_before=normalize_before, + ) + + if normalize_before: + decoder_norm = nn.LayerNorm(d_model) + else: + decoder_norm = None + + self.decoder = nn.TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + norm=decoder_norm, + ) + + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) + + self.decoder_criterion = LabelSmoothingLoss() + else: + self.decoder_criterion = None + + def forward( + self, + x: torch.Tensor, + supervision: Optional[Supervisions] = None, + dynamic_chunk_training: bool = False, + short_chunk_proportion: float = 0.5, + chunk_size: int = -1, + simulate_streaming=False, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, T, C). + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is (N, T, C) + - Encoder output with shape (T, N, C). It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is (N, T). + It is None if `supervision` is None. + """ + if self.use_feat_batchnorm: + x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) + x = self.feat_batchnorm(x) + x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, + supervision, + dynamic_chunk_training=dynamic_chunk_training, + short_chunk_proportion=short_chunk_proportion, + chunk_size=chunk_size, + simulate_streaming=simulate_streaming, + ) + x = self.ctc_output(encoder_memory) + return x, encoder_memory, memory_key_padding_mask + + def run_encoder( + self, + x: torch.Tensor, + supervisions: Optional[Supervisions] = None, + chunk_size: int = -1, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run the transformer encoder. + + Args: + x: + The model input. Its shape is (N, T, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute the encoder padding mask, which is used as memory key + padding mask for the decoder. + chunk_size: right chunk_size to simulate streaming decoding + -1 for whole right context + Returns: + Return a tuple with two tensors: + - The encoder output, with shape (T, N, C) + - encoder padding mask, with shape (N, T). + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. + """ + # streaming decoding(chunk_size >= 0) is only verified with Conformer + assert chunk_size == -1 + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + mask = encoder_padding_mask(x.size(0), supervisions) + mask = mask.to(x.device) if mask is not None else None + x = self.encoder( + x, src_key_padding_mask=mask, chunk_size=chunk_size + ) # (T, N, C) + + return x, mask + + def ctc_output(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + The output tensor from the transformer encoder. + Its shape is (T, N, C) + + Returns: + Return a tensor that can be used for CTC decoding. + Its shape is (N, T, C) + """ + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) + return x + + @torch.jit.export + def decoder_forward( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id + + Returns: + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. + """ + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) + + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + + decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) + + return decoder_loss + + @torch.jit.export + def decoder_nll( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[torch.Tensor], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder with shape (T, N, C) + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + Returns: + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). + """ + # The common part between this function and decoder_forward could be + # extracted as a separate function. + if isinstance(token_ids[0], torch.Tensor): + # This branch is executed by torchscript in C++. + # See https://github.com/k2-fsa/k2/pull/870 + # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 + token_ids = [tolist(t) for t in token_ids] + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. + tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) + pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerEncoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerEncoderLayer, self).__setstate__(state) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) + + Shape: + src: (S, N, E). + src_mask: (S, S). + src_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = src + if self.normalize_before: + src = self.norm1(src) + src2 = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = residual + self.dropout1(src2) + if not self.normalize_before: + src = self.norm1(src) + + residual = src + if self.normalize_before: + src = self.norm2(src) + src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) + src = residual + self.dropout2(src2) + if not self.normalize_before: + src = self.norm2(src) + return src + + +class TransformerDecoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerDecoderLayer. + Add support of normalize_before, + i.e., use layer_norm before the first block. + + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + + Examples:: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: str = "relu", + normalize_before: bool = True, + ) -> None: + super(TransformerDecoderLayer, self).__init__() + self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) + # Implementation of Feedforward model + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.dropout3 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + self.normalize_before = normalize_before + + def __setstate__(self, state): + if "activation" not in state: + state["activation"] = nn.functional.relu + super(TransformerDecoderLayer, self).__setstate__(state) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: + the sequence to the decoder layer (required). + memory: + the sequence from the last layer of the encoder (required). + tgt_mask: + the mask for the tgt sequence (optional). + memory_mask: + the mask for the memory sequence (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch (optional). + memory_key_padding_mask: + the mask for the memory keys per batch (optional). + + Shape: + tgt: (T, N, E). + memory: (S, N, E). + tgt_mask: (T, T). + memory_mask: (T, S). + tgt_key_padding_mask: (N, T). + memory_key_padding_mask: (N, S). + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number + """ + residual = tgt + if self.normalize_before: + tgt = self.norm1(tgt) + tgt2 = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = residual + self.dropout1(tgt2) + if not self.normalize_before: + tgt = self.norm1(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm2(tgt) + tgt2 = self.src_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = residual + self.dropout2(tgt2) + if not self.normalize_before: + tgt = self.norm2(tgt) + + residual = tgt + if self.normalize_before: + tgt = self.norm3(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = residual + self.dropout3(tgt2) + if not self.normalize_before: + tgt = self.norm3(tgt) + return tgt + + +def _get_activation_fn(activation: str): + if activation == "relu": + return nn.functional.relu + elif activation == "gelu": + return nn.functional.gelu + + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) + + +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf + + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) + + Note:: + + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) + """ + + def __init__(self, d_model: int, dropout: float = 0.1) -> None: + """ + Args: + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. + """ + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = nn.Dropout(p=dropout) + # not doing: self.pe = None because of errors thrown by torchscript + self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) + + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is (1, T1, d_model). The shape of the input x + is (N, T, d_model). If T > T1, then we change the shape of self.pe + to (N, T, d_model). Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape (N, T, C). + Returns: + Return None. + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + # Now pe is of shape (1, T, d_model), where T is x.size(1) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Add positional encoding. + + Args: + x: + Its shape is (N, T, C) + + Returns: + Return a tensor of shape (N, T, C) + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1), :] + return self.dropout(x) + + +class Noam(object): + """ + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa + + Args: + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps + """ + + def __init__( + self, + params, + model_size: int = 256, + factor: float = 10.0, + warm_step: int = 25000, + weight_decay=0, + ) -> None: + """Construct an Noam object.""" + self.optimizer = torch.optim.Adam( + params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay + ) + self._step = 0 + self.warmup = warm_step + self.factor = factor + self.model_size = model_size + self._rate = 0 + + @property + def param_groups(self): + """Return param_groups.""" + return self.optimizer.param_groups + + def step(self): + """Update parameters and rate.""" + self._step += 1 + rate = self.rate() + for p in self.optimizer.param_groups: + p["lr"] = rate + self._rate = rate + self.optimizer.step() + + def rate(self, step=None): + """Implement `lrate` above.""" + if step is None: + step = self._step + return ( + self.factor + * self.model_size ** (-0.5) + * min(step ** (-0.5), step * self.warmup ** (-1.5)) + ) + + def zero_grad(self): + """Reset gradient.""" + self.optimizer.zero_grad() + + def state_dict(self): + """Return state_dict.""" + return { + "_step": self._step, + "warmup": self.warmup, + "factor": self.factor, + "model_size": self.model_size, + "_rate": self._rate, + "optimizer": self.optimizer.state_dict(), + } + + def load_state_dict(self, state_dict): + """Load state_dict.""" + for key, value in state_dict.items(): + if key == "optimizer": + self.optimizer.load_state_dict(state_dict["optimizer"]) + else: + setattr(self, key, value) + + +def encoder_padding_mask( + max_len: int, supervisions: Optional[Supervisions] = None +) -> Optional[torch.Tensor]: + """Make mask tensor containing indexes of padded part. + + TODO:: + This function **assumes** that the model uses + a subsampling factor of 4. We should remove that + assumption later. + + Args: + max_len: + Maximum length of input features. + CAUTION: It is the length after subsampling. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + + Returns: + Tensor: Mask tensor of dimension (batch_size, input_length), + True denote the masked indices. + """ + if supervisions is None: + return None + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"], + supervisions["num_frames"], + ), + 1, + ).to(torch.int32) + + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] + for idx in range(supervision_segments.size(0)): + # Note: TorchScript doesn't allow to unpack tensors as tuples + sequence_idx = supervision_segments[idx, 0].item() + start_frame = supervision_segments[idx, 1].item() + num_frames = supervision_segments[idx, 2].item() + lengths[sequence_idx] = start_frame + num_frames + + lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] + bs = int(len(lengths)) + seq_range = torch.arange(0, max_len, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) + # Note: TorchScript doesn't implement Tensor.new() + seq_length_expand = torch.tensor( + lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype + ).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + return mask + + +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with True, + Unmasked positions are filled with False. + + Args: + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad + + Returns: + Tensor: + a bool tensor of the same shape as the input tensor. + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + + Args: + sz: mask size + + Returns: + A square mask of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + return [[sos_id] + utt for utt in token_ids] + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. + """ + return [utt + [eos_id] for utt in token_ids] + + +def tolist(t: torch.Tensor) -> List[int]: + """Used by jit""" + return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/README.md b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/README.md new file mode 100644 index 000000000..94d4ed6a3 --- /dev/null +++ b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/README.md @@ -0,0 +1,4 @@ + +Please visit + +for how to run this recipe. diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/__init__.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py new file mode 100644 index 000000000..e075a2d03 --- /dev/null +++ b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -0,0 +1,385 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from lhotse.dataset import ( + BucketingSampler, + CutConcatenate, + CutMix, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class LibriSpeechAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--full-libri", + type=str2bool, + default=True, + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: + logging.info("About to get Musan cuts") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.json.gz" + ) + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + transforms.append( + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using BucketingSampler.") + train_sampler = BucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + bucket_method="equal_duration", + drop_last=True, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = BucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = BucketingSampler( + cuts, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_clean_100_cuts(self) -> CutSet: + logging.info("About to get train-clean-100 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-100.json.gz" + ) + + @lru_cache() + def train_clean_360_cuts(self) -> CutSet: + logging.info("About to get train-clean-360 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-360.json.gz" + ) + + @lru_cache() + def train_other_500_cuts(self) -> CutSet: + logging.info("About to get train-other-500 cuts") + return load_manifest( + self.args.manifest_dir / "cuts_train-other-500.json.gz" + ) + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz") + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz") + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz") + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz") diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/decode.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/decode.py new file mode 100755 index 000000000..9c964c2aa --- /dev/null +++ b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/decode.py @@ -0,0 +1,505 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from model import TdnnLstm + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.decode import ( + get_lattice, + nbest_decoding, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=19, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=5, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + parser.add_argument( + "--method", + type=str, + default="whole-lattice-rescoring", + help="""Decoding method. + Supported values are: + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--export", + type=str2bool, + default=False, + help="""When enabled, the averaged model is saved to + tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved. + pretrained.pt contains a dict {"model": model.state_dict()}, + which can be loaded by `icefall.checkpoint.load_checkpoint()`. + """, + ) + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "exp_dir": Path("tdnn_lstm_ctc/exp/"), + "lang_dir": Path("data/lang_phone"), + "lm_dir": Path("data/lm"), + "feature_dim": 80, + "subsampling_factor": 3, + "search_beam": 20, + "output_beam": 5, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: k2.Fsa, + batch: dict, + lexicon: Lexicon, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + lexicon: + It contains word symbol table. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = HLG.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + feature = feature.permute(0, 2, 1) # now feature is (N, C, T) + + nnet_output = model(feature) + # nnet_output is (N, T, C) + + supervisions = batch["supervisions"] + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + ) + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-{params.num_paths}" + hyps = get_texts(best_path) + hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + else: + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + + ans = dict() + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: k2.Fsa, + lexicon: Lexicon, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. + lexicon: + It contains word symbol table. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + results = [] + + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + batch=batch, + lexicon=lexicon, + G=G, + ) + + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + + num_cuts += len(batch["supervisions"]["text"]) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats(f, f"{test_set_name}-{key}", results) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_phone_id = max(lexicon.tokens) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) + HLG = HLG.to(device) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]: + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") + G = k2.Fsa.from_dict(d).to(device) + + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + model = TdnnLstm( + num_features=params.feature_dim, + num_classes=max_phone_id + 1, # +1 for the blank symbol + subsampling_factor=params.subsampling_factor, + ) + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + if params.export: + logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) + return + + model.to(device) + model.eval() + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + lexicon=lexicon, + G=G, + ) + + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/model.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/model.py new file mode 100644 index 000000000..5e04c11b4 --- /dev/null +++ b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/model.py @@ -0,0 +1,103 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.nn as nn + + +class TdnnLstm(nn.Module): + def __init__( + self, num_features: int, num_classes: int, subsampling_factor: int = 3 + ) -> None: + """ + Args: + num_features: + The input dimension of the model. + num_classes: + The output dimension of the model. + subsampling_factor: + It reduces the number of output frames by this factor. + """ + super().__init__() + self.num_features = num_features + self.num_classes = num_classes + self.subsampling_factor = subsampling_factor + self.tdnn = nn.Sequential( + nn.Conv1d( + in_channels=num_features, + out_channels=500, + kernel_size=3, + stride=1, + padding=1, + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=500, affine=False), + nn.Conv1d( + in_channels=500, + out_channels=500, + kernel_size=3, + stride=1, + padding=1, + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=500, affine=False), + nn.Conv1d( + in_channels=500, + out_channels=500, + kernel_size=3, + stride=self.subsampling_factor, # stride: subsampling_factor! + padding=1, + ), + nn.ReLU(inplace=True), + nn.BatchNorm1d(num_features=500, affine=False), + ) + self.lstms = nn.ModuleList( + [ + nn.LSTM(input_size=500, hidden_size=500, num_layers=1) + for _ in range(5) + ] + ) + self.lstm_bnorms = nn.ModuleList( + [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] + ) + self.dropout = nn.Dropout(0.2) + self.linear = nn.Linear(in_features=500, out_features=self.num_classes) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + Its shape is [N, C, T] + + Returns: + The output tensor has shape [N, T, C] + """ + x = self.tdnn(x) + x = x.permute(2, 0, 1) # (N, C, T) -> (T, N, C) -> how LSTM expects it + for lstm, bnorm in zip(self.lstms, self.lstm_bnorms): + x_new, _ = lstm(x) + x_new = bnorm(x_new.permute(1, 2, 0)).permute( + 2, 0, 1 + ) # (T, N, C) -> (N, C, T) -> (T, N, C) + x_new = self.dropout(x_new) + x = x_new + x # skip connections + x = x.transpose( + 1, 0 + ) # (T, N, C) -> (N, T, C) -> linear expects "features" in the last dim + x = self.linear(x) + x = nn.functional.log_softmax(x, dim=-1) + return x diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/pretrained.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/pretrained.py new file mode 100755 index 000000000..2baeb6bba --- /dev/null +++ b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/pretrained.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from model import TdnnLstm +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_whole_lattice, +) +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + required=True, + help="Path to words.txt", + ) + + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.8, + help=""" + Used only when method is whole-lattice-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "feature_dim": 80, + "subsampling_factor": 3, + "num_classes": 72, + "sample_rate": 16000, + "search_beam": 20, + "output_beam": 5, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = TdnnLstm( + num_features=params.feature_dim, + num_classes=params.num_classes, + subsampling_factor=params.subsampling_factor, + ) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"]) + model.to(device) + model.eval() + + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method == "whole-lattice-rescoring": + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + features = features.permute(0, 2, 1) # now features is (N, C, T) + + with torch.no_grad(): + nnet_output = model(features) + # nnet_output is (N, T, C) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/train.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/train.py new file mode 100755 index 000000000..7439e157a --- /dev/null +++ b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/train.py @@ -0,0 +1,603 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from asr_datamodule import LibriSpeechAsrDataModule +from lhotse.utils import fix_random_seed +from model import TdnnLstm +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.optim.lr_scheduler import StepLR +from torch.utils.tensorboard import SummaryWriter + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.graph_compiler import CtcTrainingGraphCompiler +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + is saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - exp_dir: It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + + - lang_dir: It contains language related input files such as + "lexicon.txt" + + - lr: It specifies the initial learning rate + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - weight_decay: The weight_decay for the optimizer. + + - subsampling_factor: The subsampling factor for the model. + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval` is 0 + + - beam_size: It is used in k2.ctc_loss + + - reduction: It is used in k2.ctc_loss + + - use_double_scores: It is used in k2.ctc_loss + """ + params = AttributeDict( + { + "exp_dir": Path("tdnn_lstm_ctc/exp"), + "lang_dir": Path("data/lang_phone"), + "lr": 1e-3, + "feature_dim": 80, + "weight_decay": 5e-4, + "subsampling_factor": 3, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 10, + "reset_interval": 200, + "valid_interval": 1000, + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + scheduler: torch.optim.lr_scheduler._LRScheduler, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: CtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of TdnnLstm in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + feature = feature.permute(0, 2, 1) # now feature is (N, C, T) + assert feature.ndim == 3 + feature = feature.to(device) + + with torch.set_grad_enabled(is_training): + nnet_output = model(feature) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervisions = batch["supervisions"] + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + decoding_graph = graph_compiler.compile(texts) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: CtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. + """ + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: CtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats. + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, + "train/valid_", + params.batch_idx_train, + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(42) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_phone_id = max(lexicon.tokens) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) + + model = TdnnLstm( + num_features=params.feature_dim, + num_classes=max_phone_id + 1, # +1 for the blank symbol + subsampling_factor=params.subsampling_factor, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if world_size > 1: + model = DDP(model, device_ids=[rank]) + + optimizer = optim.AdamW( + model.parameters(), + lr=params.lr, + weight_decay=params.weight_decay, + ) + scheduler = StepLR(optimizer, step_size=8, gamma=0.1) + + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + scheduler.load_state_dict(checkpoints["scheduler"]) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + train_dl = librispeech.train_dataloaders(train_cuts) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + for epoch in range(params.start_epoch, params.num_epochs): + train_dl.sampler.set_epoch(epoch) + + if epoch > params.start_epoch: + logging.info(f"epoch {epoch}, lr: {scheduler.get_last_lr()[0]}") + + if tb_writer is not None: + tb_writer.add_scalar( + "train/lr", + scheduler.get_last_lr()[0], + params.batch_idx_train, + ) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + scheduler.step() + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + logging.info("Done!") + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +if __name__ == "__main__": + main() From ccab93e8e2092128b7911281ba7d8a71faf3bdef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sat, 15 Jan 2022 00:59:26 +0000 Subject: [PATCH 02/12] Working datamodule --- egs/fisher_swbd/ASR/conformer_ctc/train.py | 16 +- egs/fisher_swbd/ASR/prepare.sh | 4 +- .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 294 ++++++------------ 3 files changed, 105 insertions(+), 209 deletions(-) diff --git a/egs/fisher_swbd/ASR/conformer_ctc/train.py b/egs/fisher_swbd/ASR/conformer_ctc/train.py index c1fa814c0..29f9f6cb6 100755 --- a/egs/fisher_swbd/ASR/conformer_ctc/train.py +++ b/egs/fisher_swbd/ASR/conformer_ctc/train.py @@ -27,7 +27,7 @@ import k2 import torch import torch.multiprocessing as mp import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import AsrDataModule from conformer import Conformer from lhotse.utils import fix_random_seed from torch import Tensor @@ -620,17 +620,13 @@ def run(rank, world_size, args): if checkpoints: optimizer.load_state_dict(checkpoints["optimizer"]) - librispeech = LibriSpeechAsrDataModule(args) + datamodule = AsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() - if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() - train_dl = librispeech.train_dataloaders(train_cuts) + train_cuts = datamodule.train_cuts() + train_dl = datamodule.train_dataloaders(train_cuts) - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) + valid_cuts = datamodule.dev_cuts() + valid_dl = datamodule.valid_dataloaders(valid_cuts) scan_pessimistic_batches_for_oom( model=model, diff --git a/egs/fisher_swbd/ASR/prepare.sh b/egs/fisher_swbd/ASR/prepare.sh index bf61b1be9..dc89dfdac 100755 --- a/egs/fisher_swbd/ASR/prepare.sh +++ b/egs/fisher_swbd/ASR/prepare.sh @@ -92,7 +92,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # We assume that you have downloaded the LibriSpeech corpus # to $dl_dir/LibriSpeech mkdir -p data/manifests/fisher - lhotse prepare fisher-english $dl_dir data/manifests/fisher + lhotse prepare fisher-english --absolute-paths 1 $dl_dir data/manifests/fisher fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then @@ -100,7 +100,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then # We assume that you have downloaded the LibriSpeech corpus # to $dl_dir/LibriSpeech mkdir -p data/manifests/swbd - lhotse prepare switchboard --omit-silence $dl_dir/LDC97S62 data/manifests/swbd + lhotse prepare switchboard --absolute-paths 1 --omit-silence $dl_dir/LDC97S62 data/manifests/swbd fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py index e075a2d03..40c8468f6 100644 --- a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -20,14 +20,16 @@ import logging from functools import lru_cache from pathlib import Path -from lhotse import CutSet, Fbank, FbankConfig, load_manifest +from tqdm import tqdm + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy from lhotse.dataset import ( BucketingSampler, - CutConcatenate, CutMix, + DynamicBucketingSampler, K2SpeechRecognitionDataset, + PerturbSpeed, PrecomputedFeatures, - SingleCutSampler, SpecAugment, ) from lhotse.dataset.input_strategies import OnTheFlyFeatures @@ -36,7 +38,12 @@ from torch.utils.data import DataLoader from icefall.utils import str2bool -class LibriSpeechAsrDataModule: +class Resample16kHz: + def __call__(self, cuts: CutSet) -> CutSet: + return cuts.resample(16000).with_recording_path_prefix('download') + + +class AsrDataModule: """ DataModule for k2 ASR experiments. It assumes there is always one train and valid dataloader, @@ -66,17 +73,10 @@ class LibriSpeechAsrDataModule: "effective batch sizes, sampling strategies, applied data " "augmentations, etc.", ) - group.add_argument( - "--full-libri", - type=str2bool, - default=True, - help="When enabled, use 960h LibriSpeech. " - "Otherwise, use 100h subset.", - ) group.add_argument( "--manifest-dir", type=Path, - default=Path("data/fbank"), + default=Path("data/manifests"), help="Path to directory with train/valid/test cuts.", ) group.add_argument( @@ -86,13 +86,6 @@ class LibriSpeechAsrDataModule: help="Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.", ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) group.add_argument( "--num-buckets", type=int, @@ -100,32 +93,10 @@ class LibriSpeechAsrDataModule: help="The number of buckets for the BucketingSampler" "(you might want to increase it for larger datasets).", ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) group.add_argument( "--on-the-fly-feats", type=str2bool, - default=False, + default=True, help="When enabled, use on-the-fly cut mixing and feature " "extraction. Will drop existing precomputed feature manifests " "if available.", @@ -137,30 +108,15 @@ class LibriSpeechAsrDataModule: help="When enabled (=default), the examples will be " "shuffled for each epoch.", ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) group.add_argument( "--num-workers", type=int, - default=2, + default=8, help="The number of training dataloader workers that " "collect the batches.", ) - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - group.add_argument( "--spec-aug-time-warp-factor", type=int, @@ -171,52 +127,28 @@ class LibriSpeechAsrDataModule: "A value less than 1 means to disable time warp.", ) - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") cuts_musan = load_manifest( - self.args.manifest_dir / "cuts_musan.json.gz" + self.args.manifest_dir / "musan_cuts.jsonl.gz" ) - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( + input_strategy = PrecomputedFeatures() + if self.args.on_the_fly_feats: + input_strategy = OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000)), + ) + + train = K2SpeechRecognitionDataset( + input_strategy=input_strategy, + cut_transforms=[ + PerturbSpeed(factors=[0.9, 1.1], p=2 / 3, preserve_id=True), + Resample16kHz(), CutMix( cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info( - f"Time warp factor: {self.args.spec_aug_time_warp_factor}" - ) - input_transforms.append( + ), + ], + input_transforms=[ SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, num_frame_masks=2, @@ -224,56 +156,19 @@ class LibriSpeechAsrDataModule: num_feature_masks=2, frames_mask_size=100, ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, + ], + return_cuts=True, ) - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=True, + ) - if self.args.bucketing_sampler: - logging.info("Using BucketingSampler.") - train_sampler = BucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - bucket_method="equal_duration", - drop_last=True, - ) - else: - logging.info("Using SingleCutSampler.") - train_sampler = SingleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) logging.info("About to create train dataloader") - train_dl = DataLoader( train, sampler=train_sampler, @@ -285,39 +180,34 @@ class LibriSpeechAsrDataModule: return train_dl def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms logging.info("About to create dev dataset") + input_strategy = PrecomputedFeatures() if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80)) - ), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, + input_strategy = OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000)), ) + + validate = K2SpeechRecognitionDataset( + return_cuts=True, + input_strategy=input_strategy, + cut_transforms=[ + Resample16kHz(), + ], + ) + valid_sampler = BucketingSampler( cuts_valid, max_duration=self.args.max_duration, shuffle=False, ) + logging.info("About to create dev dataloader") valid_dl = DataLoader( validate, sampler=valid_sampler, batch_size=None, - num_workers=2, + num_workers=self.args.num_workers, persistent_workers=False, ) @@ -325,11 +215,19 @@ class LibriSpeechAsrDataModule: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") + + input_strategy = PrecomputedFeatures() + if self.args.on_the_fly_feats: + input_strategy = OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000)), + ) + test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, + return_cuts=True, + input_strategy=input_strategy, + cut_transforms=[ + Resample16kHz(), + ], ) sampler = BucketingSampler( cuts, max_duration=self.args.max_duration, shuffle=False @@ -344,42 +242,44 @@ class LibriSpeechAsrDataModule: return test_dl @lru_cache() - def train_clean_100_cuts(self) -> CutSet: - logging.info("About to get train-clean-100 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-100.json.gz" + def train_cuts(self) -> CutSet: + logging.info("About to get train Fisher + SWBD cuts") + return load_manifest_lazy( + self.args.manifest_dir + / "train_utterances_fisher-swbd_cuts.jsonl.gz" ) @lru_cache() - def train_clean_360_cuts(self) -> CutSet: - logging.info("About to get train-clean-360 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-clean-360.json.gz" + def dev_cuts(self) -> CutSet: + logging.info("About to get dev Fisher + SWBD cuts") + return load_manifest_lazy( + self.args.manifest_dir / "dev_utterances_fisher-swbd_cuts.jsonl.gz" ) @lru_cache() - def train_other_500_cuts(self) -> CutSet: - logging.info("About to get train-other-500 cuts") - return load_manifest( - self.args.manifest_dir / "cuts_train-other-500.json.gz" - ) - - @lru_cache() - def dev_clean_cuts(self) -> CutSet: - logging.info("About to get dev-clean cuts") - return load_manifest(self.args.manifest_dir / "cuts_dev-clean.json.gz") - - @lru_cache() - def dev_other_cuts(self) -> CutSet: - logging.info("About to get dev-other cuts") - return load_manifest(self.args.manifest_dir / "cuts_dev-other.json.gz") - - @lru_cache() - def test_clean_cuts(self) -> CutSet: + def test_cuts(self) -> CutSet: logging.info("About to get test-clean cuts") - return load_manifest(self.args.manifest_dir / "cuts_test-clean.json.gz") + raise NotImplemented - @lru_cache() - def test_other_cuts(self) -> CutSet: - logging.info("About to get test-other cuts") - return load_manifest(self.args.manifest_dir / "cuts_test-other.json.gz") + +def test(): + parser = argparse.ArgumentParser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + adm = AsrDataModule(args) + + cuts = adm.train_cuts() + dl = adm.train_dataloaders(cuts) + for i, batch in tqdm(enumerate(dl)): + if i == 100: + break + + cuts = adm.dev_cuts() + dl = adm.valid_dataloaders(cuts) + for i, batch in tqdm(enumerate(dl)): + if i == 100: + break + + +if __name__ == '__main__': + test() From a7666d864c96b9c3bdf1a980535232b8f2db088c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sat, 15 Jan 2022 01:27:31 +0000 Subject: [PATCH 03/12] Missing steps in prepare.sh --- .../normalize_and_filter_supervisions.py | 51 +++++++++++------ egs/fisher_swbd/ASR/prepare.sh | 55 ++++++++++++------- 2 files changed, 68 insertions(+), 38 deletions(-) diff --git a/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py index be9715578..9933d3a4f 100644 --- a/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py +++ b/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py @@ -17,15 +17,34 @@ def get_args(): return parser.parse_args() -# Note: the functions "normalize" and "keep" implement the logic similar to -# Kaldi's data prep scripts for Fisher: -# https://github.com/kaldi-asr/kaldi/blob/master/egs/fisher_swbd/s5/local/fisher_data_prep.sh -# and for SWBD: -# https://github.com/kaldi-asr/kaldi/blob/master/egs/fisher_swbd/s5/local/swbd1_data_prep.sh +class FisherSwbdNormalizer: + """ + Note: the functions "normalize" and "keep" implement the logic similar to + Kaldi's data prep scripts for Fisher: + https://github.com/kaldi-asr/kaldi/blob/master/egs/fisher_swbd/s5/local/fisher_data_prep.sh + and for SWBD: + https://github.com/kaldi-asr/kaldi/blob/master/egs/fisher_swbd/s5/local/swbd1_data_prep.sh + + One notable difference is that we don't change [cough], [lipsmack], etc. to [noise]. + We also don't implement all the edge cases of normalization from Kaldi + (hopefully won't make too much difference). + """ -class Normalizer: def __init__(self) -> None: + + self.remove_regexp_before = re.compile( + r"|".join([ + # special symbols + r"\[\[SKIP.*\]\]", + r"\[SKIP.*\]", + r"\[PAUSE.*\]", + r"\[SILENCE\]", + r"", + r"", + ]) + ) + # tuples of (pattern, replacement) # note: Kaldi replaces sighs, coughs, etc with [noise]. # We don't do that here. @@ -63,19 +82,12 @@ class Normalizer: (re.compile(r"\s-\s"), r" "), (re.compile(r"\s-\s"), r" "), # special symbol with trailing dash - (re.compile(r"(\[\w+\])-"), r"\1"), + (re.compile(r"(\[.*?\])-"), r"\1"), ] # unwanted symbols in the transcripts - self.remove_regexp = re.compile( + self.remove_regexp_after = re.compile( r"|".join([ - # special symbols - r"\[\[SKIP.*\]\]", - r"\[SKIP.*\]", - r"\[PAUSE.*\]", - r"\[SILENCE\]", - r"", - r"", # remaining punctuation r"\.", r",", @@ -92,12 +104,15 @@ class Normalizer: def normalize(self, text: str) -> str: text = text.upper() - # first replace + # first remove + text = self.remove_regexp_before.sub("", text) + + # then replace for pattern, sub in self.replace_regexps: text = pattern.sub(sub, text) # then remove - text = self.remove_regexp.sub("", text) + text = self.remove_regexp_after.sub("", text) # then clean up whitespace text = self.whitespace_regexp.sub(" ", text).strip() @@ -159,6 +174,8 @@ def test(): "-[ADV]AN[TAGE]", "-[ADV]AN[TAGE]-", "[WEA[SONABLE]-/REASONABLE]", + "[VOCALIZED-NOISE]-", + "~BULL", ]: print(text) print(normalizer.normalize(text)) diff --git a/egs/fisher_swbd/ASR/prepare.sh b/egs/fisher_swbd/ASR/prepare.sh index dc89dfdac..859e0d34e 100755 --- a/egs/fisher_swbd/ASR/prepare.sh +++ b/egs/fisher_swbd/ASR/prepare.sh @@ -56,12 +56,6 @@ log() { log "dl_dir: $dl_dir" -if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then - log "Stage -1: Download LM" - #[ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm - #./local/download_lm.py --out-dir=$dl_dir/lm -fi - if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Download data" @@ -116,35 +110,60 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then set -x + # Combine Fisher and SWBD recordings and supervisions lhotse combine \ data/manifests/fisher/recordings.jsonl.gz \ data/manifests/swbd/swbd_recordings.jsonl \ data/manifests/fisher-swbd_recordings.jsonl.gz - lhotse combine \ data/manifests/fisher/supervisions.jsonl.gz \ data/manifests/swbd/swbd_supervisions.jsonl \ data/manifests/fisher-swbd_supervisions.jsonl.gz + # Normalize text and remove supervisions that are not useful / hard to handle. python local/normalize_and_filter_supervisions.py \ data/manifests/fisher-swbd_supervisions.jsonl.gz \ data/manifests/fisher-swbd_supervisions_norm.jsonl.gz \ + # Create cuts that span whole recording sessions. lhotse cut simple \ -r data/manifests/fisher-swbd_recordings.jsonl.gz \ -s data/manifests/fisher-swbd_supervisions_norm.jsonl.gz \ data/manifests/fisher-swbd_cuts_unshuf.jsonl.gz + # Shuffle the cuts (pure bash pipes are fast). + # We could technically skip this step but this helps ensure + # SWBD is not only seen towards the end of training. gunzip -c data/manifests/fisher-swbd_cuts_unshuf.jsonl.gz \ | shuf \ | gzip -c \ > data/manifests/fisher-swbd_cuts.jsonl.gz + # Create train/dev split -- 20 sessions for dev is about ~2h, should be good. + num_cuts="$(gunzip -c data/manifests/fisher-swbd_cuts.jsonl.gz | wc -l)" + num_dev_sessions=20 + lhotse subset --first $num_dev_sessions \ + data/manifests/fisher-swbd_cuts.jsonl.gz \ + data/manifests/dev_fisher-swbd_cuts.jsonl.gz + lhotse subset --last $((num_cuts-num_dev_sessions)) \ + data/manifests/fisher-swbd_cuts.jsonl.gz \ + data/manifests/train_fisher-swbd_cuts.jsonl.gz + + # Finally, split the full-session cuts into one cut per supervision segment. + # In case any segments are overlapping we would discard the info about overlaps. + # (overlaps are unlikely for this dataset because each cut sees only one channel). + lhotse cut trim-to-supervisions \ + --discard-overlapping \ + data/manifests/train_fisher-swbd_cuts.jsonl.gz \ + data/manifests/train_utterances_fisher-swbd_cuts.jsonl.gz + lhotse cut trim-to-supervisions \ + --discard-overlapping \ + data/manifests/dev_fisher-swbd_cuts.jsonl.gz \ + data/manifests/dev_utterances_fisher-swbd_cuts.jsonl.gz + set +x fi -# TODO: optional stage 5, compute features - if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Dump transcripts for LM training" mkdir -p data/lm @@ -154,18 +173,6 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then > data/lm/transcript_words.txt fi -#if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then -# log "Stage 3: Compute fbank for librispeech" -# mkdir -p data/fbank -# ./local/compute_fbank_librispeech.py -#fi -# -#if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then -# log "Stage 4: Compute fbank for musan" -# mkdir -p data/fbank -# ./local/compute_fbank_musan.py -#fi - if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then log "Stage 7: Prepare lexicon using g2p_en" lang_dir=data/lang_phone @@ -186,6 +193,12 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then | awk '{print $0,NR+2}' \ >> $lang_dir/words.txt + # Add remaining special word symbols expected by LM scripts. + num_words=$(wc -l $lang_dir/words.txt) + echo " $((num_words))" + echo " $((num_words+1))" + echo "#0 $((num_words+2))" + if [ ! -f $lang_dir/L_disambig.pt ]; then pip install g2p_en ./local/prepare_lang_g2pen.py --lang-dir $lang_dir From 3582599a330aea462f609653c8883a8ce5260f67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sat, 15 Jan 2022 02:05:27 +0000 Subject: [PATCH 04/12] fix --- .../ASR/local/normalize_and_filter_supervisions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py index 9933d3a4f..22f265c9d 100644 --- a/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py +++ b/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py @@ -135,7 +135,7 @@ def main(): sups = load_manifest_lazy_or_eager(args.input_sups) assert isinstance(sups, SupervisionSet) - normalizer = Normalizer() + normalizer = FisherSwbdNormalizer() tot, skip = 0, 0 with SupervisionSet.open_writer(args.output_sups) as writer: @@ -155,7 +155,7 @@ def main(): def test(): - normalizer = Normalizer() + normalizer = FisherSwbdNormalizer() for text in [ "[laughterr]", "[laugh] oh this is great [silence] yes", From 186f5f1ba4040c371fe9dfb6c351d2c43bfca8de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sat, 15 Jan 2022 05:03:45 +0000 Subject: [PATCH 05/12] fix --- egs/fisher_swbd/ASR/prepare.sh | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/egs/fisher_swbd/ASR/prepare.sh b/egs/fisher_swbd/ASR/prepare.sh index 859e0d34e..ca94a1dc8 100755 --- a/egs/fisher_swbd/ASR/prepare.sh +++ b/egs/fisher_swbd/ASR/prepare.sh @@ -181,7 +181,7 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then # Add special words to words.txt echo " 0" > $lang_dir/words.txt echo "!SIL 1" >> $lang_dir/words.txt - echo " 2" >> $lang_dir/words.txt + echo "[UNK] 2" >> $lang_dir/words.txt # Add regular words to words.txt gunzip -c data/manifests/fisher-swbd_supervisions_norm.jsonl.gz \ @@ -195,9 +195,11 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then # Add remaining special word symbols expected by LM scripts. num_words=$(wc -l $lang_dir/words.txt) - echo " $((num_words))" - echo " $((num_words+1))" - echo "#0 $((num_words+2))" + echo " ${num_words}" >> $lang_dir/words.txt + num_words=$(wc -l $lang_dir/words.txt) + echo " ${num_words}" >> $lang_dir/words.txt + num_words=$(wc -l $lang_dir/words.txt) + echo "#0 ${num_words}" >> $lang_dir/words.txt if [ ! -f $lang_dir/L_disambig.pt ]; then pip install g2p_en From 2027b55233016b23d351af9b3d9ffccd8e16a775 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Sun, 16 Jan 2022 00:03:17 +0000 Subject: [PATCH 06/12] Fixes --- egs/fisher_swbd/ASR/conformer_ctc/train.py | 2 +- egs/fisher_swbd/ASR/local/prepare_lang_bpe.py | 4 ++-- egs/fisher_swbd/ASR/prepare.sh | 8 +++++--- egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py | 1 + 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/egs/fisher_swbd/ASR/conformer_ctc/train.py b/egs/fisher_swbd/ASR/conformer_ctc/train.py index 29f9f6cb6..bda74451f 100755 --- a/egs/fisher_swbd/ASR/conformer_ctc/train.py +++ b/egs/fisher_swbd/ASR/conformer_ctc/train.py @@ -717,7 +717,7 @@ def scan_pessimistic_batches_for_oom( def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) diff --git a/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py b/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py index cf32f308d..be92710d2 100755 --- a/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py +++ b/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py @@ -152,7 +152,7 @@ def generate_lexicon( lexicon.append((word, pieces)) # The OOV word is - lexicon.append(("", [sp.id_to_piece(sp.unk_id())])) + lexicon.append(("[UNK]", [sp.id_to_piece(sp.unk_id())])) token2id: Dict[str, int] = dict() for i in range(sp.vocab_size()): @@ -197,7 +197,7 @@ def main(): words = word_sym_table.symbols - excluded = ["", "!SIL", "", "", "#0", "", ""] + excluded = ["", "!SIL", "", "[UNK]", "#0", "", ""] for w in excluded: if w in words: words.remove(w) diff --git a/egs/fisher_swbd/ASR/prepare.sh b/egs/fisher_swbd/ASR/prepare.sh index ca94a1dc8..cfc621f23 100755 --- a/egs/fisher_swbd/ASR/prepare.sh +++ b/egs/fisher_swbd/ASR/prepare.sh @@ -103,6 +103,8 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then # to data/musan mkdir -p data/manifests lhotse prepare musan $dl_dir/musan data/manifests + lhotse combine data/manifests/recordings_{music,speech,noise}.json data/manifests/recordings_musan.jsonl.gz + lhotse cut simple -r data/manifests/recordings_musan.jsonl.gz data/manifests/musan_cuts.jsonl.gz fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then @@ -194,11 +196,11 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then >> $lang_dir/words.txt # Add remaining special word symbols expected by LM scripts. - num_words=$(wc -l $lang_dir/words.txt) + num_words=$(cat $lang_dir/words.txt | wc -l) echo " ${num_words}" >> $lang_dir/words.txt - num_words=$(wc -l $lang_dir/words.txt) + num_words=$(cat $lang_dir/words.txt | wc -l) echo " ${num_words}" >> $lang_dir/words.txt - num_words=$(wc -l $lang_dir/words.txt) + num_words=$(cat $lang_dir/words.txt | wc -l) echo "#0 ${num_words}" >> $lang_dir/words.txt if [ ! -f $lang_dir/L_disambig.pt ]; then diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py index 40c8468f6..7abe169d1 100644 --- a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -167,6 +167,7 @@ class AsrDataModule: num_buckets=self.args.num_buckets, drop_last=True, ) + train_sampler.filter(lambda cut: 1.0 <= cut.duration <= 15.0) logging.info("About to create train dataloader") train_dl = DataLoader( From 4426715bc81fa6f33b7c58c5fd5ce52390e063d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 17 Jan 2022 22:37:45 +0000 Subject: [PATCH 07/12] Fix decode, remove unwanted files --- .../ASR/conformer_ctc/asr_datamodule.py | 287 +++- egs/fisher_swbd/ASR/conformer_ctc/decode.py | 17 +- .../ASR/local/compute_fbank_librispeech.py | 100 -- .../ASR/local/compute_fbank_musan.py | 97 -- .../convert_transcript_words_to_tokens.py | 107 -- .../ASR/local/display_manifest_statistics.py | 215 --- egs/fisher_swbd/ASR/local/download_lm.py | 97 -- egs/fisher_swbd/ASR/local/prepare_lang.py | 413 ----- .../ASR/local/test_prepare_lang.py | 106 -- egs/fisher_swbd/ASR/prepare.sh | 9 +- .../ASR/streaming_conformer_ctc/README.md | 90 -- .../streaming_conformer_ctc/asr_datamodule.py | 1 - .../ASR/streaming_conformer_ctc/conformer.py | 1355 ----------------- .../label_smoothing.py | 1 - .../streaming_decode.py | 521 ------- .../streaming_conformer_ctc/subsampling.py | 1 - .../ASR/streaming_conformer_ctc/train.py | 745 --------- .../streaming_conformer_ctc/transformer.py | 966 ------------ egs/fisher_swbd/ASR/tdnn_lstm_ctc/README.md | 4 - egs/fisher_swbd/ASR/tdnn_lstm_ctc/__init__.py | 0 .../ASR/tdnn_lstm_ctc/asr_datamodule.py | 286 ---- egs/fisher_swbd/ASR/tdnn_lstm_ctc/decode.py | 505 ------ egs/fisher_swbd/ASR/tdnn_lstm_ctc/model.py | 103 -- .../ASR/tdnn_lstm_ctc/pretrained.py | 277 ---- egs/fisher_swbd/ASR/tdnn_lstm_ctc/train.py | 603 -------- 25 files changed, 301 insertions(+), 6605 deletions(-) mode change 120000 => 100644 egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py delete mode 100755 egs/fisher_swbd/ASR/local/compute_fbank_librispeech.py delete mode 100755 egs/fisher_swbd/ASR/local/compute_fbank_musan.py delete mode 100755 egs/fisher_swbd/ASR/local/convert_transcript_words_to_tokens.py delete mode 100755 egs/fisher_swbd/ASR/local/display_manifest_statistics.py delete mode 100755 egs/fisher_swbd/ASR/local/download_lm.py delete mode 100755 egs/fisher_swbd/ASR/local/prepare_lang.py delete mode 100755 egs/fisher_swbd/ASR/local/test_prepare_lang.py delete mode 100644 egs/fisher_swbd/ASR/streaming_conformer_ctc/README.md delete mode 120000 egs/fisher_swbd/ASR/streaming_conformer_ctc/asr_datamodule.py delete mode 100644 egs/fisher_swbd/ASR/streaming_conformer_ctc/conformer.py delete mode 120000 egs/fisher_swbd/ASR/streaming_conformer_ctc/label_smoothing.py delete mode 100755 egs/fisher_swbd/ASR/streaming_conformer_ctc/streaming_decode.py delete mode 120000 egs/fisher_swbd/ASR/streaming_conformer_ctc/subsampling.py delete mode 100755 egs/fisher_swbd/ASR/streaming_conformer_ctc/train.py delete mode 100644 egs/fisher_swbd/ASR/streaming_conformer_ctc/transformer.py delete mode 100644 egs/fisher_swbd/ASR/tdnn_lstm_ctc/README.md delete mode 100644 egs/fisher_swbd/ASR/tdnn_lstm_ctc/__init__.py delete mode 100644 egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py delete mode 100755 egs/fisher_swbd/ASR/tdnn_lstm_ctc/decode.py delete mode 100644 egs/fisher_swbd/ASR/tdnn_lstm_ctc/model.py delete mode 100755 egs/fisher_swbd/ASR/tdnn_lstm_ctc/pretrained.py delete mode 100755 egs/fisher_swbd/ASR/tdnn_lstm_ctc/train.py diff --git a/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py deleted file mode 120000 index fa1b8cca3..000000000 --- a/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../tdnn_lstm_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py new file mode 100644 index 000000000..7abe169d1 --- /dev/null +++ b/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py @@ -0,0 +1,286 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path + +from tqdm import tqdm + +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( + BucketingSampler, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PerturbSpeed, + PrecomputedFeatures, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class Resample16kHz: + def __call__(self, cuts: CutSet) -> CutSet: + return cuts.resample(16000).with_recording_path_prefix('download') + + +class AsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=True, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=8, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: + logging.info("About to get Musan cuts") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) + + input_strategy = PrecomputedFeatures() + if self.args.on_the_fly_feats: + input_strategy = OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000)), + ) + + train = K2SpeechRecognitionDataset( + input_strategy=input_strategy, + cut_transforms=[ + PerturbSpeed(factors=[0.9, 1.1], p=2 / 3, preserve_id=True), + Resample16kHz(), + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ), + ], + input_transforms=[ + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ], + return_cuts=True, + ) + + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + train_sampler.filter(lambda cut: 1.0 <= cut.duration <= 15.0) + + logging.info("About to create train dataloader") + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + + logging.info("About to create dev dataset") + input_strategy = PrecomputedFeatures() + if self.args.on_the_fly_feats: + input_strategy = OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000)), + ) + + validate = K2SpeechRecognitionDataset( + return_cuts=True, + input_strategy=input_strategy, + cut_transforms=[ + Resample16kHz(), + ], + ) + + valid_sampler = BucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + + input_strategy = PrecomputedFeatures() + if self.args.on_the_fly_feats: + input_strategy = OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000)), + ) + + test = K2SpeechRecognitionDataset( + return_cuts=True, + input_strategy=input_strategy, + cut_transforms=[ + Resample16kHz(), + ], + ) + sampler = BucketingSampler( + cuts, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train Fisher + SWBD cuts") + return load_manifest_lazy( + self.args.manifest_dir + / "train_utterances_fisher-swbd_cuts.jsonl.gz" + ) + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev Fisher + SWBD cuts") + return load_manifest_lazy( + self.args.manifest_dir / "dev_utterances_fisher-swbd_cuts.jsonl.gz" + ) + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + raise NotImplemented + + +def test(): + parser = argparse.ArgumentParser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + adm = AsrDataModule(args) + + cuts = adm.train_cuts() + dl = adm.train_dataloaders(cuts) + for i, batch in tqdm(enumerate(dl)): + if i == 100: + break + + cuts = adm.dev_cuts() + dl = adm.valid_dataloaders(cuts) + for i, batch in tqdm(enumerate(dl)): + if i == 100: + break + + +if __name__ == '__main__': + test() diff --git a/egs/fisher_swbd/ASR/conformer_ctc/decode.py b/egs/fisher_swbd/ASR/conformer_ctc/decode.py index 177e33a6e..3185ff777 100755 --- a/egs/fisher_swbd/ASR/conformer_ctc/decode.py +++ b/egs/fisher_swbd/ASR/conformer_ctc/decode.py @@ -26,7 +26,7 @@ import k2 import sentencepiece as spm import torch import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule +from asr_datamodule import AsrDataModule from conformer import Conformer from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler @@ -534,7 +534,7 @@ def save_results( @torch.no_grad() def main(): parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) + AsrDataModule.add_arguments(parser) args = parser.parse_args() args.exp_dir = Path(args.exp_dir) args.lang_dir = Path(args.lang_dir) @@ -662,16 +662,13 @@ def main(): num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") - librispeech = LibriSpeechAsrDataModule(args) + datamodule = AsrDataModule(args) - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() + fisher_swbd_dev_cuts = datamodule.dev_cuts() + fisher_swbd_dev_dataloader = datamodule.test_dataloaders(fisher_swbd_dev_cuts) - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] + test_sets = ["dev-fisher-swbd"] + test_dl = [fisher_swbd_dev_dataloader] for test_set, test_dl in zip(test_sets, test_dl): results_dict = decode_dataset( diff --git a/egs/fisher_swbd/ASR/local/compute_fbank_librispeech.py b/egs/fisher_swbd/ASR/local/compute_fbank_librispeech.py deleted file mode 100755 index b26034eb2..000000000 --- a/egs/fisher_swbd/ASR/local/compute_fbank_librispeech.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the LibriSpeech dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_librispeech(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 - - dataset_parts = ( - "dev-clean", - "dev-other", - "test-clean", - "test-other", - "train-clean-100", - "train-clean-360", - "train-other-500", - ) - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, output_dir=src_dir - ) - assert manifests is not None - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - for partition, m in manifests.items(): - if (output_dir / f"cuts_{partition}.json.gz").is_file(): - logging.info(f"{partition} already exists - skipping.") - continue - logging.info(f"Processing {partition}") - cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], - ) - if "train" in partition: - cut_set = ( - cut_set - + cut_set.perturb_speed(0.9) - + cut_set.perturb_speed(1.1) - ) - cut_set = cut_set.compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/feats_{partition}", - # when an executor is specified, make more partitions - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomHdf5Writer, - ) - cut_set.to_json(output_dir / f"cuts_{partition}.json.gz") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - - compute_fbank_librispeech() diff --git a/egs/fisher_swbd/ASR/local/compute_fbank_musan.py b/egs/fisher_swbd/ASR/local/compute_fbank_musan.py deleted file mode 100755 index d44524e70..000000000 --- a/egs/fisher_swbd/ASR/local/compute_fbank_musan.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the musan dataset. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" - -import logging -import os -from pathlib import Path - -import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine -from lhotse.recipes.utils import read_manifests_if_cached - -from icefall.utils import get_executor - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - - -def compute_fbank_musan(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - num_jobs = min(15, os.cpu_count()) - num_mel_bins = 80 - - dataset_parts = ( - "music", - "speech", - "noise", - ) - manifests = read_manifests_if_cached( - dataset_parts=dataset_parts, output_dir=src_dir - ) - assert manifests is not None - - musan_cuts_path = output_dir / "cuts_musan.json.gz" - - if musan_cuts_path.is_file(): - logging.info(f"{musan_cuts_path} already exists - skipping") - return - - logging.info("Extracting features for Musan") - - extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) - - with get_executor() as ex: # Initialize the executor only once. - # create chunks of Musan with duration 5 - 10 seconds - musan_cuts = ( - CutSet.from_manifests( - recordings=combine( - part["recordings"] for part in manifests.values() - ) - ) - .cut_into_windows(10.0) - .filter(lambda c: c.duration > 5) - .compute_and_store_features( - extractor=extractor, - storage_path=f"{output_dir}/feats_musan", - num_jobs=num_jobs if ex is None else 80, - executor=ex, - storage_type=LilcomHdf5Writer, - ) - ) - musan_cuts.to_json(musan_cuts_path) - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_musan() diff --git a/egs/fisher_swbd/ASR/local/convert_transcript_words_to_tokens.py b/egs/fisher_swbd/ASR/local/convert_transcript_words_to_tokens.py deleted file mode 100755 index 133499c8b..000000000 --- a/egs/fisher_swbd/ASR/local/convert_transcript_words_to_tokens.py +++ /dev/null @@ -1,107 +0,0 @@ -#!/usr/bin/env python3 - -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -""" -Convert a transcript file containing words to a corpus file containing tokens -for LM training with the help of a lexicon. - -If the lexicon contains phones, the resulting LM will be a phone LM; If the -lexicon contains word pieces, the resulting LM will be a word piece LM. - -If a word has multiple pronunciations, the one that appears first in the lexicon -is kept; others are removed. - -If the input transcript is: - - hello zoo world hello - world zoo - foo zoo world hellO - -and if the lexicon is - - SPN - hello h e l l o 2 - hello h e l l o - world w o r l d - zoo z o o - -Then the output is - - h e l l o 2 z o o w o r l d h e l l o 2 - w o r l d z o o - SPN z o o w o r l d SPN -""" - -import argparse -from pathlib import Path -from typing import Dict, List - -from generate_unique_lexicon import filter_multiple_pronunications - -from icefall.lexicon import read_lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--transcript", - type=str, - help="The input transcript file." - "We assume that the transcript file consists of " - "lines. Each line consists of space separated words.", - ) - parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument( - "--oov", type=str, default="", help="The OOV word." - ) - - return parser.parse_args() - - -def process_line( - lexicon: Dict[str, List[str]], line: str, oov_token: str -) -> None: - """ - Args: - lexicon: - A dict containing pronunciations. Its keys are words and values - are pronunciations (i.e., tokens). - line: - A line of transcript consisting of space(s) separated words. - oov_token: - The pronunciation of the oov word if a word in `line` is not present - in the lexicon. - Returns: - Return None. - """ - s = "" - words = line.strip().split() - for i, w in enumerate(words): - tokens = lexicon.get(w, oov_token) - s += " ".join(tokens) - s += " " - print(s.strip()) - - -def main(): - args = get_args() - assert Path(args.lexicon).is_file() - assert Path(args.transcript).is_file() - assert len(args.oov) > 0 - - # Only the first pronunciation of a word is kept - lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) - - lexicon = dict(lexicon) - - assert args.oov in lexicon - - oov_token = lexicon[args.oov] - - with open(args.transcript) as f: - for line in f: - process_line(lexicon=lexicon, line=line, oov_token=oov_token) - - -if __name__ == "__main__": - main() diff --git a/egs/fisher_swbd/ASR/local/display_manifest_statistics.py b/egs/fisher_swbd/ASR/local/display_manifest_statistics.py deleted file mode 100755 index 15bd206fa..000000000 --- a/egs/fisher_swbd/ASR/local/display_manifest_statistics.py +++ /dev/null @@ -1,215 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file displays duration statistics of utterances in a manifest. -You can use the displayed value to choose minimum/maximum duration -to remove short and long utterances during the training. - -See the function `remove_short_and_long_utt()` in transducer/train.py -for usage. -""" - - -from lhotse import load_manifest - - -def main(): - path = "./data/fbank/cuts_train-clean-100.json.gz" - path = "./data/fbank/cuts_train-clean-360.json.gz" - path = "./data/fbank/cuts_train-other-500.json.gz" - path = "./data/fbank/cuts_dev-clean.json.gz" - path = "./data/fbank/cuts_dev-other.json.gz" - path = "./data/fbank/cuts_test-clean.json.gz" - path = "./data/fbank/cuts_test-other.json.gz" - - cuts = load_manifest(path) - cuts.describe() - - -if __name__ == "__main__": - main() - -""" -## train-clean-100 -Cuts count: 85617 -Total duration (hours): 303.8 -Speech duration (hours): 303.8 (100.0%) -*** -Duration statistics (seconds): -mean 12.8 -std 3.8 -min 1.3 -0.1% 1.9 -0.5% 2.2 -1% 2.5 -5% 4.2 -10% 6.4 -25% 11.4 -50% 13.8 -75% 15.3 -90% 16.7 -95% 17.3 -99% 18.1 -99.5% 18.4 -99.9% 18.8 -max 27.2 - -## train-clean-360 -Cuts count: 312042 -Total duration (hours): 1098.2 -Speech duration (hours): 1098.2 (100.0%) -*** -Duration statistics (seconds): -mean 12.7 -std 3.8 -min 1.0 -0.1% 1.8 -0.5% 2.2 -1% 2.5 -5% 4.2 -10% 6.2 -25% 11.2 -50% 13.7 -75% 15.3 -90% 16.6 -95% 17.3 -99% 18.1 -99.5% 18.4 -99.9% 18.8 -max 33.0 - -## train-other 500 -Cuts count: 446064 -Total duration (hours): 1500.6 -Speech duration (hours): 1500.6 (100.0%) -*** -Duration statistics (seconds): -mean 12.1 -std 4.2 -min 0.8 -0.1% 1.7 -0.5% 2.1 -1% 2.3 -5% 3.5 -10% 5.0 -25% 9.8 -50% 13.4 -75% 15.1 -90% 16.5 -95% 17.2 -99% 18.1 -99.5% 18.4 -99.9% 18.9 -max 31.0 - -## dev-clean -Cuts count: 2703 -Total duration (hours): 5.4 -Speech duration (hours): 5.4 (100.0%) -*** -Duration statistics (seconds): -mean 7.2 -std 4.7 -min 1.4 -0.1% 1.6 -0.5% 1.8 -1% 1.9 -5% 2.4 -10% 2.7 -25% 3.8 -50% 5.9 -75% 9.3 -90% 13.3 -95% 16.4 -99% 23.8 -99.5% 28.5 -99.9% 32.3 -max 32.6 - -## dev-other -Cuts count: 2864 -Total duration (hours): 5.1 -Speech duration (hours): 5.1 (100.0%) -*** -Duration statistics (seconds): -mean 6.4 -std 4.3 -min 1.1 -0.1% 1.3 -0.5% 1.7 -1% 1.8 -5% 2.2 -10% 2.6 -25% 3.5 -50% 5.3 -75% 7.9 -90% 12.0 -95% 15.0 -99% 22.2 -99.5% 27.1 -99.9% 32.4 -max 35.2 - -## test-clean -Cuts count: 2620 -Total duration (hours): 5.4 -Speech duration (hours): 5.4 (100.0%) -*** -Duration statistics (seconds): -mean 7.4 -std 5.2 -min 1.3 -0.1% 1.6 -0.5% 1.8 -1% 2.0 -5% 2.3 -10% 2.7 -25% 3.7 -50% 5.8 -75% 9.6 -90% 14.6 -95% 17.8 -99% 25.5 -99.5% 28.4 -99.9% 32.8 -max 35.0 - -## test-other -Cuts count: 2939 -Total duration (hours): 5.3 -Speech duration (hours): 5.3 (100.0%) -*** -Duration statistics (seconds): -mean 6.5 -std 4.4 -min 1.2 -0.1% 1.5 -0.5% 1.8 -1% 1.9 -5% 2.3 -10% 2.6 -25% 3.4 -50% 5.2 -75% 8.2 -90% 12.6 -95% 15.8 -99% 21.4 -99.5% 23.8 -99.9% 33.5 -max 34.5 -""" diff --git a/egs/fisher_swbd/ASR/local/download_lm.py b/egs/fisher_swbd/ASR/local/download_lm.py deleted file mode 100755 index 94d23afed..000000000 --- a/egs/fisher_swbd/ASR/local/download_lm.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file downloads the following LibriSpeech LM files: - - - 3-gram.pruned.1e-7.arpa.gz - - 4-gram.arpa.gz - - librispeech-vocab.txt - - librispeech-lexicon.txt - -from http://www.openslr.org/resources/11 -and save them in the user provided directory. - -Files are not re-downloaded if they already exist. - -Usage: - ./local/download_lm.py --out-dir ./download/lm -""" - -import argparse -import gzip -import logging -import os -import shutil -from pathlib import Path - -from lhotse.utils import urlretrieve_progress -from tqdm.auto import tqdm - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument("--out-dir", type=str, help="Output directory.") - - args = parser.parse_args() - return args - - -def main(out_dir: str): - url = "http://www.openslr.org/resources/11" - out_dir = Path(out_dir) - - files_to_download = ( - "3-gram.pruned.1e-7.arpa.gz", - "4-gram.arpa.gz", - "librispeech-vocab.txt", - "librispeech-lexicon.txt", - ) - - for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"): - filename = out_dir / f - if filename.is_file() is False: - urlretrieve_progress( - f"{url}/{f}", - filename=filename, - desc=f"Downloading {filename}", - ) - else: - logging.info(f"{filename} already exists - skipping") - - if ".gz" in str(filename): - unzipped = Path(os.path.splitext(filename)[0]) - if unzipped.is_file() is False: - with gzip.open(filename, "rb") as f_in: - with open(unzipped, "wb") as f_out: - shutil.copyfileobj(f_in, f_out) - else: - logging.info(f"{unzipped} already exist - skipping") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - - args = get_args() - logging.info(f"out_dir: {args.out_dir}") - - main(out_dir=args.out_dir) diff --git a/egs/fisher_swbd/ASR/local/prepare_lang.py b/egs/fisher_swbd/ASR/local/prepare_lang.py deleted file mode 100755 index d913756a1..000000000 --- a/egs/fisher_swbd/ASR/local/prepare_lang.py +++ /dev/null @@ -1,413 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. -""" -import argparse -import math -from collections import defaultdict -from pathlib import Path -from typing import Any, Dict, List, Tuple - -import k2 -import torch - -from icefall.lexicon import read_lexicon, write_lexicon -from icefall.utils import str2bool - -Lexicon = List[Tuple[str, List[str]]] - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain a file lexicon.txt. - Generated files by this script are saved into this directory. - """, - ) - - parser.add_argument( - "--debug", - type=str2bool, - default=False, - help="""True for debugging, which will generate - a visualization of the lexicon FST. - - Caution: If your lexicon contains hundreds of thousands - of lines, please set it to False! - """, - ) - - return parser.parse_args() - - -def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: - """Write a symbol to ID mapping to a file. - - Note: - No need to implement `read_mapping` as it can be done - through :func:`k2.SymbolTable.from_file`. - - Args: - filename: - Filename to save the mapping. - sym2id: - A dict mapping symbols to IDs. - Returns: - Return None. - """ - with open(filename, "w", encoding="utf-8") as f: - for sym, i in sym2id.items(): - f.write(f"{sym} {i}\n") - - -def get_tokens(lexicon: Lexicon) -> List[str]: - """Get tokens from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique tokens. - """ - ans = set() - for _, tokens in lexicon: - ans.update(tokens) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def get_words(lexicon: Lexicon) -> List[str]: - """Get words from a lexicon. - - Args: - lexicon: - It is the return value of :func:`read_lexicon`. - Returns: - Return a list of unique words. - """ - ans = set() - for word, _ in lexicon: - ans.add(word) - sorted_ans = sorted(list(ans)) - return sorted_ans - - -def add_disambig_symbols(lexicon: Lexicon) -> Tuple[Lexicon, int]: - """It adds pseudo-token disambiguation symbols #1, #2 and so on - at the ends of tokens to ensure that all pronunciations are different, - and that none is a prefix of another. - - See also add_lex_disambig.pl from kaldi. - - Args: - lexicon: - It is returned by :func:`read_lexicon`. - Returns: - Return a tuple with two elements: - - - The output lexicon with disambiguation symbols - - The ID of the max disambiguation symbol that appears - in the lexicon - """ - - # (1) Work out the count of each token-sequence in the - # lexicon. - count = defaultdict(int) - for _, tokens in lexicon: - count[" ".join(tokens)] += 1 - - # (2) For each left sub-sequence of each token-sequence, note down - # that it exists (for identifying prefixes of longer strings). - issubseq = defaultdict(int) - for _, tokens in lexicon: - tokens = tokens.copy() - tokens.pop() - while tokens: - issubseq[" ".join(tokens)] = 1 - tokens.pop() - - # (3) For each entry in the lexicon: - # if the token sequence is unique and is not a - # prefix of another word, no disambig symbol. - # Else output #1, or #2, #3, ... if the same token-seq - # has already been assigned a disambig symbol. - ans = [] - - # We start with #1 since #0 has its own purpose - first_allowed_disambig = 1 - max_disambig = first_allowed_disambig - 1 - last_used_disambig_symbol_of = defaultdict(int) - - for word, tokens in lexicon: - tokenseq = " ".join(tokens) - assert tokenseq != "" - if issubseq[tokenseq] == 0 and count[tokenseq] == 1: - ans.append((word, tokens)) - continue - - cur_disambig = last_used_disambig_symbol_of[tokenseq] - if cur_disambig == 0: - cur_disambig = first_allowed_disambig - else: - cur_disambig += 1 - - if cur_disambig > max_disambig: - max_disambig = cur_disambig - last_used_disambig_symbol_of[tokenseq] = cur_disambig - tokenseq += f" #{cur_disambig}" - ans.append((word, tokenseq.split())) - return ans, max_disambig - - -def generate_id_map(symbols: List[str]) -> Dict[str, int]: - """Generate ID maps, i.e., map a symbol to a unique ID. - - Args: - symbols: - A list of unique symbols. - Returns: - A dict containing the mapping between symbols and IDs. - """ - return {sym: i for i, sym in enumerate(symbols)} - - -def add_self_loops( - arcs: List[List[Any]], disambig_token: int, disambig_word: int -) -> List[List[Any]]: - """Adds self-loops to states of an FST to propagate disambiguation symbols - through it. They are added on each state with non-epsilon output symbols - on at least one arc out of the state. - - See also fstaddselfloops.pl from Kaldi. One difference is that - Kaldi uses OpenFst style FSTs and it has multiple final states. - This function uses k2 style FSTs and it does not need to add self-loops - to the final state. - - The input label of a self-loop is `disambig_token`, while the output - label is `disambig_word`. - - Args: - arcs: - A list-of-list. The sublist contains - `[src_state, dest_state, label, aux_label, score]` - disambig_token: - It is the token ID of the symbol `#0`. - disambig_word: - It is the word ID of the symbol `#0`. - - Return: - Return new `arcs` containing self-loops. - """ - states_needs_self_loops = set() - for arc in arcs: - src, dst, ilabel, olabel, score = arc - if olabel != 0: - states_needs_self_loops.add(src) - - ans = [] - for s in states_needs_self_loops: - ans.append([s, s, disambig_token, disambig_word, 0]) - - return arcs + ans - - -def lexicon_to_fst( - lexicon: Lexicon, - token2id: Dict[str, int], - word2id: Dict[str, int], - sil_token: str = "SIL", - sil_prob: float = 0.5, - need_self_loops: bool = False, -) -> k2.Fsa: - """Convert a lexicon to an FST (in k2 format) with optional silence at - the beginning and end of each word. - - Args: - lexicon: - The input lexicon. See also :func:`read_lexicon` - token2id: - A dict mapping tokens to IDs. - word2id: - A dict mapping words to IDs. - sil_token: - The silence token. - sil_prob: - The probability for adding a silence at the beginning and end - of the word. - need_self_loops: - If True, add self-loop to states with non-epsilon output symbols - on at least one arc out of the state. The input label for this - self loop is `token2id["#0"]` and the output label is `word2id["#0"]`. - Returns: - Return an instance of `k2.Fsa` representing the given lexicon. - """ - assert sil_prob > 0.0 and sil_prob < 1.0 - # CAUTION: we use score, i.e, negative cost. - sil_score = math.log(sil_prob) - no_sil_score = math.log(1.0 - sil_prob) - - start_state = 0 - loop_state = 1 # words enter and leave from here - sil_state = 2 # words terminate here when followed by silence; this state - # has a silence transition to loop_state. - next_state = 3 # the next un-allocated state, will be incremented as we go. - arcs = [] - - assert token2id[""] == 0 - assert word2id[""] == 0 - - eps = 0 - - sil_token = token2id[sil_token] - - arcs.append([start_state, loop_state, eps, eps, no_sil_score]) - arcs.append([start_state, sil_state, eps, eps, sil_score]) - arcs.append([sil_state, loop_state, sil_token, eps, 0]) - - for word, tokens in lexicon: - assert len(tokens) > 0, f"{word} has no pronunciations" - cur_state = loop_state - - word = word2id[word] - tokens = [token2id[i] for i in tokens] - - for i in range(len(tokens) - 1): - w = word if i == 0 else eps - arcs.append([cur_state, next_state, tokens[i], w, 0]) - - cur_state = next_state - next_state += 1 - - # now for the last token of this word - # It has two out-going arcs, one to the loop state, - # the other one to the sil_state. - i = len(tokens) - 1 - w = word if i == 0 else eps - arcs.append([cur_state, loop_state, tokens[i], w, no_sil_score]) - arcs.append([cur_state, sil_state, tokens[i], w, sil_score]) - - if need_self_loops: - disambig_token = token2id["#0"] - disambig_word = word2id["#0"] - arcs = add_self_loops( - arcs, - disambig_token=disambig_token, - disambig_word=disambig_word, - ) - - final_state = next_state - arcs.append([loop_state, final_state, -1, -1, 0]) - arcs.append([final_state]) - - arcs = sorted(arcs, key=lambda arc: arc[0]) - arcs = [[str(i) for i in arc] for arc in arcs] - arcs = [" ".join(arc) for arc in arcs] - arcs = "\n".join(arcs) - - fsa = k2.Fsa.from_str(arcs, acceptor=False) - return fsa - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - lexicon_filename = lang_dir / "lexicon.txt" - sil_token = "SIL" - sil_prob = 0.5 - - lexicon = read_lexicon(lexicon_filename) - tokens = get_tokens(lexicon) - words = get_words(lexicon) - - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - - for i in range(max_disambig + 1): - disambig = f"#{i}" - assert disambig not in tokens - tokens.append(f"#{i}") - - assert "" not in tokens - tokens = [""] + tokens - - assert "" not in words - assert "#0" not in words - assert "" not in words - assert "" not in words - - words = [""] + words + ["#0", "", ""] - - token2id = generate_id_map(tokens) - word2id = generate_id_map(words) - - write_mapping(lang_dir / "tokens.txt", token2id) - write_mapping(lang_dir / "words.txt", word2id) - write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) - - L = lexicon_to_fst( - lexicon, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - ) - - L_disambig = lexicon_to_fst( - lexicon_disambig, - token2id=token2id, - word2id=word2id, - sil_token=sil_token, - sil_prob=sil_prob, - need_self_loops=True, - ) - torch.save(L.as_dict(), lang_dir / "L.pt") - torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt") - - if args.debug: - labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt") - aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt") - - L.labels_sym = labels_sym - L.aux_labels_sym = aux_labels_sym - L.draw(f"{lang_dir / 'L.svg'}", title="L.pt") - - L_disambig.labels_sym = labels_sym - L_disambig.aux_labels_sym = aux_labels_sym - L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt") - - -if __name__ == "__main__": - main() diff --git a/egs/fisher_swbd/ASR/local/test_prepare_lang.py b/egs/fisher_swbd/ASR/local/test_prepare_lang.py deleted file mode 100755 index d4cf62bba..000000000 --- a/egs/fisher_swbd/ASR/local/test_prepare_lang.py +++ /dev/null @@ -1,106 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) - -import os -import tempfile - -import k2 -from prepare_lang import ( - add_disambig_symbols, - generate_id_map, - get_phones, - get_words, - lexicon_to_fst, - read_lexicon, - write_lexicon, - write_mapping, -) - - -def generate_lexicon_file() -> str: - fd, filename = tempfile.mkstemp() - os.close(fd) - s = """ - !SIL SIL - SPN - SPN - f f - a a - foo f o o - bar b a r - bark b a r k - food f o o d - food2 f o o d - fo f o - """.strip() - with open(filename, "w") as f: - f.write(s) - return filename - - -def test_read_lexicon(filename: str): - lexicon = read_lexicon(filename) - phones = get_phones(lexicon) - words = get_words(lexicon) - print(lexicon) - print(phones) - print(words) - lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) - print(lexicon_disambig) - print("max disambig:", f"#{max_disambig}") - - phones = ["", "SIL", "SPN"] + phones - for i in range(max_disambig + 1): - phones.append(f"#{i}") - words = [""] + words - - phone2id = generate_id_map(phones) - word2id = generate_id_map(words) - - print(phone2id) - print(word2id) - - write_mapping("phones.txt", phone2id) - write_mapping("words.txt", word2id) - - write_lexicon("a.txt", lexicon) - write_lexicon("a_disambig.txt", lexicon_disambig) - - fsa = lexicon_to_fst(lexicon, phone2id=phone2id, word2id=word2id) - fsa.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa.draw("L.pdf", title="L") - - fsa_disambig = lexicon_to_fst( - lexicon_disambig, phone2id=phone2id, word2id=word2id - ) - fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") - fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") - fsa_disambig.draw("L_disambig.pdf", title="L_disambig") - - -def main(): - filename = generate_lexicon_file() - test_read_lexicon(filename) - os.remove(filename) - - -if __name__ == "__main__": - main() diff --git a/egs/fisher_swbd/ASR/prepare.sh b/egs/fisher_swbd/ASR/prepare.sh index cfc621f23..9ef2d7363 100755 --- a/egs/fisher_swbd/ASR/prepare.sh +++ b/egs/fisher_swbd/ASR/prepare.sh @@ -135,7 +135,8 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then # Shuffle the cuts (pure bash pipes are fast). # We could technically skip this step but this helps ensure - # SWBD is not only seen towards the end of training. + # SWBD is not only seen towards the end of training + # (we concatenated it after Fisher). gunzip -c data/manifests/fisher-swbd_cuts_unshuf.jsonl.gz \ | shuf \ | gzip -c \ @@ -163,6 +164,9 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then data/manifests/dev_fisher-swbd_cuts.jsonl.gz \ data/manifests/dev_utterances_fisher-swbd_cuts.jsonl.gz + # Display some statistics about the data. + lhotse cut describe data/manifests/train_utterances_fisher-swbd_cuts.jsonl.gz + lhotse cut describe data/manifests/dev_utterances_fisher-swbd_cuts.jsonl.gz set +x fi @@ -204,6 +208,9 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then echo "#0 ${num_words}" >> $lang_dir/words.txt if [ ! -f $lang_dir/L_disambig.pt ]; then + # We discard SWBD's lexicon and just use g2p_en + # It was trained on CMUdict and looks it up before + # resorting to an LSTM G2P model. pip install g2p_en ./local/prepare_lang_g2pen.py --lang-dir $lang_dir fi diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/README.md b/egs/fisher_swbd/ASR/streaming_conformer_ctc/README.md deleted file mode 100644 index 01be7090b..000000000 --- a/egs/fisher_swbd/ASR/streaming_conformer_ctc/README.md +++ /dev/null @@ -1,90 +0,0 @@ -## Train and Decode -Commands of data preparation/train/decode steps are almost the same with -../conformer_ctc experiment except some options. - -Please read the code and understand following new added options before running this experiment: - - For data preparation: - - Nothing new. - - For streaming_conformer_ctc/train.py: - - --dynamic-chunk-training - --short-chunk-proportion - - For streaming_conformer_ctc/streaming_decode.py: - - --chunk-size - --tailing-num-frames - --simulate-streaming - -## Performence and a trained model. - -The latest results with this streaming code is shown in following table: - -chunk size | wer on test-clean | wer on test-other --- | -- | -- -full | 3.53 | 8.52 -40(1.96s) | 3.78 | 9.38 -32(1.28s) | 3.82 | 9.44 -24(0.96s) | 3.95 | 9.76 -16(0.64s) | 4.06 | 9.98 -8(0.32s) | 4.30 | 10.55 -4(0.16s) | 5.88 | 12.01 - - -A trained model is also provided. -By run -``` -git clone https://huggingface.co/GuoLiyong/streaming_conformer - -# You may want to manually check md5sum values of downloaded files -# 8e633bc1de37f5ae57a2694ceee32a93 trained_streaming_conformer.pt -# 4c0aeefe26c784ec64873cc9b95420f1 L.pt -# d1f91d81005fb8ce4d65953a4a984ee7 Linv.pt -# e1c1902feb7b9fc69cd8d26e663c2608 bpe.model -# 8617e67159b0ff9118baa54f04db24cc tokens.txt -# 72b075ab5e851005cd854e666c82c3bb words.txt -``` - -If there is any different md5sum values, please run -``` -cd streaming_models -git lfs pull -``` -And check md5sum values again. - -Finally, following files will be downloaded: -
-streaming_models/  
-|-- lang_bpe  
-|   |-- L.pt  
-|   |-- Linv.pt  
-|   |-- bpe.model
-|   |-- tokens.txt
-|   `-- words.txt
-`-- trained_streaming_conformer.pt
-
- - -And run commands you will get the same results of previous table: -``` -trained_models=/path/to/downloaded/streaming_models/ -for chunk_size in 4 8 16 24 36 40 -1; do - ./streaming_conformer_ctc/streaming_decode.py \ - --chunk-size=${chunk_size} \ - --trained-dir=${trained_models} -done -``` -Results of following command is indentical to previous one, -but model consumes features chunk_by_chunk, i.e. a streaming way. -``` -trained_models=/path/to/downloaded/streaming_models/ -for chunk_size in 4 8 16 24 36 40 -1; do - ./streaming_conformer_ctc/streaming_decode.py \ - --simulate-streaming=True \ - --chunk-size=${chunk_size} \ - --trained-dir=${trained_models} -done -``` diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/asr_datamodule.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/asr_datamodule.py deleted file mode 120000 index a73848de9..000000000 --- a/egs/fisher_swbd/ASR/streaming_conformer_ctc/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/conformer.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/conformer.py deleted file mode 100644 index 33f8e2e15..000000000 --- a/egs/fisher_swbd/ASR/streaming_conformer_ctc/conformer.py +++ /dev/null @@ -1,1355 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import math -import warnings -from typing import Optional, Tuple - -import torch -from torch import Tensor, nn -from transformer import Supervisions, Transformer, encoder_padding_mask - - -# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/mask.py#L42 -def subsequent_chunk_mask( - size: int, - chunk_size: int, - num_left_chunks: int = -1, - device: torch.device = torch.device("cpu"), -) -> torch.Tensor: - """Create mask for subsequent steps (size, size) with chunk size, - this is for streaming encoder - Args: - size (int): size of mask - chunk_size (int): size of chunk - num_left_chunks (int): number of left chunks - <0: use full chunk - >=0: use num_left_chunks - device (torch.device): "cpu" or "cuda" or torch.Tensor.device - Returns: - torch.Tensor: mask - Examples: - >>> subsequent_chunk_mask(4, 2) - [[1, 1, 0, 0], - [1, 1, 0, 0], - [1, 1, 1, 1], - [1, 1, 1, 1]] - """ - ret = torch.zeros(size, size, device=device, dtype=torch.bool) - for i in range(size): - if num_left_chunks < 0: - start = 0 - else: - start = max((i // chunk_size - num_left_chunks) * chunk_size, 0) - ending = min((i // chunk_size + 1) * chunk_size, size) - ret[i, start:ending] = True - return ret - - -class Conformer(Transformer): - """ - Args: - num_features (int): Number of input features - num_classes (int): Number of output classes - subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension - nhead (int): number of head - dim_feedforward (int): feedforward dimention - num_encoder_layers (int): number of encoder layers - num_decoder_layers (int): number of decoder layers - dropout (float): dropout rate - cnn_module_kernel (int): Kernel size of convolution module - normalize_before (bool): whether to use layer_norm before the first block. - vgg_frontend (bool): whether to use vgg frontend. - """ - - def __init__( - self, - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - normalize_before: bool = True, - vgg_frontend: bool = False, - use_feat_batchnorm: bool = False, - causal: bool = False, - ) -> None: - super(Conformer, self).__init__( - num_features=num_features, - num_classes=num_classes, - subsampling_factor=subsampling_factor, - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - num_encoder_layers=num_encoder_layers, - num_decoder_layers=num_decoder_layers, - dropout=dropout, - normalize_before=normalize_before, - vgg_frontend=vgg_frontend, - use_feat_batchnorm=use_feat_batchnorm, - ) - - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - encoder_layer = ConformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, - cnn_module_kernel, - normalize_before, - causal, - ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - self.normalize_before = normalize_before - if self.normalize_before: - self.after_norm = nn.LayerNorm(d_model) - else: - # Note: TorchScript detects that self.after_norm could be used inside forward() - # and throws an error without this change. - self.after_norm = identity - - def run_encoder( - self, - x: Tensor, - supervisions: Optional[Supervisions] = None, - dynamic_chunk_training: bool = False, - short_chunk_proportion: float = 0.5, - chunk_size: int = -1, - simulate_streaming: bool = False, - ) -> Tuple[Tensor, Optional[Tensor]]: - """ - Args: - x: - The model input. Its shape is (N, T, C). - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute encoder padding mask, which is used as memory key padding - mask for the decoder. - dynamic_chunk_training: - For training only. - IF True, train with dynamic right context for some batches - sampled with a distribution - if False, train with full right context all the time. - short_chunk_proportion: - For training only. - Proportion of samples that will be trained with dynamic chunk. - chunk_size: - For eval only. - right context when evaluating test utts. - -1 means all right context. - simulate_streaming=False, - For eval only. - If true, the feature will be feeded into the model chunk by chunk. - If false, the whole utts if feeded into the model together i.e. the - model only foward once. - - - Returns: - Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). - Tensor: Mask tensor of dimension (batch_size, input_length) - """ - if self.encoder.training: - return self.train_run_encoder( - x, supervisions, dynamic_chunk_training, short_chunk_proportion - ) - else: - return self.eval_run_encoder( - x, supervisions, chunk_size, simulate_streaming - ) - - def train_run_encoder( - self, - x: Tensor, - supervisions: Optional[Supervisions] = None, - dynamic_chunk_training: bool = False, - short_chunk_threshold: float = 0.5, - ) -> Tuple[Tensor, Optional[Tensor]]: - """ - Args: - x: - The model input. Its shape is (N, T, C). - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute encoder padding mask, which is used as memory key padding - mask for the decoder. - dynamic_chunk_training: - IF True, train with dynamic right context for some batches - sampled with a distribution - if False, train with full right context all the time. - short_chunk_proportion: - Proportion of samples that will be trained with dynamic chunk. - """ - x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - src_key_padding_mask = encoder_padding_mask(x.size(0), supervisions) - if src_key_padding_mask is not None: - src_key_padding_mask = src_key_padding_mask.to(x.device) - - if dynamic_chunk_training: - max_len = x.size(0) - chunk_size = torch.randint(1, max_len, (1,)).item() - if chunk_size > (max_len * short_chunk_threshold): - chunk_size = max_len - else: - chunk_size = chunk_size % 25 + 1 - mask = ~subsequent_chunk_mask( - size=x.size(0), chunk_size=chunk_size, device=x.device - ) - x = self.encoder( - x, pos_emb, mask=mask, src_key_padding_mask=src_key_padding_mask - ) # (T, B, F) - else: - x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F) - - if self.normalize_before: - x = self.after_norm(x) - - return x, src_key_padding_mask - - def eval_run_encoder( - self, - feature: Tensor, - supervisions: Optional[Supervisions] = None, - chunk_size: int = -1, - simulate_streaming=False, - ) -> Tuple[Tensor, Optional[Tensor]]: - """ - Args: - feature: - The model input. Its shape is (N, T, C). - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute encoder padding mask, which is used as memory key padding - mask for the decoder. - - Returns: - Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). - Tensor: Mask tensor of dimension (batch_size, input_length) - """ - # feature.shape: N T C - num_frames = feature.size(1) - - # As temporarily in icefall only subsampling_rate == 4 is supported, - # following parameters are hard-coded here. - # Change it accordingly if other subsamling_rate are supported. - embed_left_context = 7 - embed_conv_right_context = 3 - subsampling_rate = 4 - stride = chunk_size * subsampling_rate - decoding_window = embed_conv_right_context + stride - - # This is also only compatible to sumsampling_rate == 4 - length_after_subsampling = ((feature.size(1) - 1) // 2 - 1) // 2 - src_key_padding_mask = encoder_padding_mask( - length_after_subsampling, supervisions - ) - if src_key_padding_mask is not None: - src_key_padding_mask = src_key_padding_mask.to(feature.device) - - if chunk_size < 0: - # non-streaming decoding - x = self.encoder_embed(feature) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - x = self.encoder( - x, pos_emb, src_key_padding_mask=src_key_padding_mask - ) # (T, B, F) - else: - if simulate_streaming: - # simulate chunk_by_chunk streaming decoding - # Results of this branch should be identical to following - # "else" branch. - # But this branch is a little slower - # as the feature is feeded chunk by chunk - - # store the result of chunk_by_chunk decoding - encoder_output = [] - - # caches - pos_emb_positive = [] - pos_emb_negative = [] - pos_emb_central = None - encoder_cache = [None for i in range(len(self.encoder.layers))] - conv_cache = [None for i in range(len(self.encoder.layers))] - - # start chunk_by_chunk decoding - offset = 0 - for cur in range( - 0, num_frames - embed_left_context + 1, stride - ): - end = min(cur + decoding_window, num_frames) - cur_feature = feature[:, cur:end, :] - cur_feature = self.encoder_embed(cur_feature) - cur_embed, cur_pos_emb = self.encoder_pos( - cur_feature, offset - ) - cur_embed = cur_embed.permute( - 1, 0, 2 - ) # (B, T, F) -> (T, B, F) - - cur_T = cur_feature.size(1) - if cur == 0: - # for first chunk extract the central pos embedding - pos_emb_central = cur_pos_emb[ - 0, (chunk_size - 1), : - ].view(1, 1, -1) - cur_T -= 1 - pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0)) - pos_emb_negative.append(cur_pos_emb[0, -cur_T:]) - assert pos_emb_positive[-1].size(0) == cur_T - - pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze( - 0 - ) - pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze( - 0 - ) - cur_pos_emb = torch.cat( - [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg], - dim=1, - ) - - x = self.encoder.chunk_forward( - cur_embed, - cur_pos_emb, - src_key_padding_mask=src_key_padding_mask[ - :, : offset + cur_embed.size(0) - ], - encoder_cache=encoder_cache, - conv_cache=conv_cache, - offset=offset, - ) # (T, B, F) - encoder_output.append(x) - offset += cur_embed.size(0) - - x = torch.cat(encoder_output, dim=0) - else: - # NOT simulate chunk_by_chunk decoding - # Results of this branch should be identical to previous - # simulate chunk_by_chunk decoding branch. - # But this branch is faster. - x = self.encoder_embed(feature) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - mask = ~subsequent_chunk_mask( - size=x.size(0), chunk_size=chunk_size, device=x.device - ) - x = self.encoder( - x, - pos_emb, - mask=mask, - src_key_padding_mask=src_key_padding_mask, - ) # (T, B, F) - - if self.normalize_before: - x = self.after_norm(x) - - return x, src_key_padding_mask - - -class ConformerEncoderLayer(nn.Module): - """ - ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - See: "Conformer: Convolution-augmented Transformer for Speech Recognition" - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - normalize_before: whether to use layer_norm before the first block. - - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - normalize_before: bool = True, - causal: bool = False, - ) -> None: - super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention( - d_model, nhead, dropout=0.0 - ) - - self.feed_forward = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.feed_forward_macaron = nn.Sequential( - nn.Linear(d_model, dim_feedforward), - Swish(), - nn.Dropout(dropout), - nn.Linear(dim_feedforward, d_model), - ) - - self.conv_module = ConvolutionModule( - d_model, cnn_module_kernel, causal=causal - ) - - self.norm_ff_macaron = nn.LayerNorm( - d_model - ) # for the macaron style FNN module - self.norm_ff = nn.LayerNorm(d_model) # for the FNN module - self.norm_mha = nn.LayerNorm(d_model) # for the MHA module - - self.ff_scale = 0.5 - - self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm( - d_model - ) # for the final output of the block - - self.dropout = nn.Dropout(dropout) - - self.normalize_before = normalize_before - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - - # macaron style feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) - if not self.normalize_before: - src = self.norm_ff_macaron(src) - - # multi-headed self-attention module - residual = src - if self.normalize_before: - src = self.norm_mha(src) - src_att = self.self_attn( - src, - src, - src, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout(src_att) - if not self.normalize_before: - src = self.norm_mha(src) - - # convolution module - residual = src - if self.normalize_before: - src = self.norm_conv(src) - src = residual + self.dropout(self.conv_module(src)) - if not self.normalize_before: - src = self.norm_conv(src) - - # feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) - if not self.normalize_before: - src = self.norm_ff(src) - - if self.normalize_before: - src = self.norm_final(src) - - return src - - def chunk_forward( - self, - src: Tensor, - pos_emb: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - encoder_cache: Optional[Tensor] = None, - conv_cache: Optional[Tensor] = None, - offset=0, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - - # macaron style feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout( - self.feed_forward_macaron(src) - ) - if not self.normalize_before: - src = self.norm_ff_macaron(src) - - # multi-headed self-attention module - residual = src - if self.normalize_before: - src = self.norm_mha(src) - if encoder_cache is None: - # src: [chunk_size, N, F] e.g. [8, 41, 512] - key = src - val = key - encoder_cache = key - else: - key = torch.cat([encoder_cache, src], dim=0) - val = key - encoder_cache = key - src_att = self.self_attn( - src, - key, - val, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - offset=offset, - )[0] - src = residual + self.dropout(src_att) - if not self.normalize_before: - src = self.norm_mha(src) - - # convolution module - residual = src # [chunk_size, N, F] e.g. [8, 41, 512] - if self.normalize_before: - src = self.norm_conv(src) - if conv_cache is not None: - src = torch.cat([conv_cache, src], dim=0) - conv_cache = src - - src = self.conv_module(src) - src = src[-residual.size(0) :, :, :] # noqa: E203 - - src = residual + self.dropout(src) - if not self.normalize_before: - src = self.norm_conv(src) - - # feed forward module - residual = src - if self.normalize_before: - src = self.norm_ff(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward(src)) - if not self.normalize_before: - src = self.norm_ff(src) - - if self.normalize_before: - src = self.norm_final(src) - - return src, encoder_cache, conv_cache - - -class ConformerEncoder(nn.TransformerEncoder): - r"""ConformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the ConformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - norm: the layer normalization component (optional). - - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = conformer_encoder(src, pos_emb) - """ - - def __init__( - self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None - ) -> None: - super(ConformerEncoder, self).__init__( - encoder_layer=encoder_layer, num_layers=num_layers, norm=norm - ) - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - """ - output = src - - for mod in self.layers: - output = mod( - output, - pos_emb, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - ) - - if self.norm is not None: - output = self.norm(output) - - return output - - def chunk_forward( - self, - src: Tensor, - pos_emb: Tensor, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - encoder_cache=None, - conv_cache=None, - offset=0, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - """ - output = src - - for layer_index, mod in enumerate(self.layers): - output, e_cache, c_cache = mod.chunk_forward( - output, - pos_emb, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - encoder_cache=encoder_cache[layer_index], - conv_cache=conv_cache[layer_index], - offset=offset, - ) - encoder_cache[layer_index] = e_cache - conv_cache[layer_index] = c_cache - - if self.norm is not None: - output = self.norm(output) - - return output - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. - - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. - - """ - - def __init__( - self, d_model: int, dropout_rate: float, max_len: int = 5000 - ) -> None: - """Construct an PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - - def extend_pe(self, x: Tensor, offset: int = 0) -> None: - """Reset the positional encodings.""" - x_size_1 = offset + x.size(1) - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x_size_1 * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str( - x.device - ): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vecotr and `j` means the - # position of key vector. We use position relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). - - """ - self.extend_pe(x, offset) - x = x * self.xscale - x_size_1 = offset + x.size(1) - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x_size_1 - + 1 : self.pe.size(1) // 2 # noqa E203 - + x_size_1, - ] - x_T = x.size(1) - if offset > 0: - pos_emb = torch.cat([pos_emb[:, :x_T], pos_emb[:, -x_T:]], dim=1) - else: - pos_emb = torch.cat( - [ - pos_emb[:, : (x_T - 1)], - self.pe[0, self.pe.size(1) // 2].view( - 1, 1, self.pe.size(-1) - ), - pos_emb[:, -(x_T - 1) :], # noqa: E203 - ], - dim=1, - ) - - return self.dropout(x), self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - - Args: - embed_dim: total dimension of the model. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) - - # linear transformation for positional encoding. - self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - - self._reset_parameters() - - def _reset_parameters(self) -> None: - nn.init.xavier_uniform_(self.in_proj.weight) - nn.init.constant_(self.in_proj.bias, 0.0) - nn.init.constant_(self.out_proj.bias, 0.0) - - nn.init.xavier_uniform_(self.pos_bias_u) - nn.init.xavier_uniform_(self.pos_bias_v) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - offset=0, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - return self.multi_head_attention_forward( - query, - key, - value, - pos_emb, - self.embed_dim, - self.num_heads, - self.in_proj.weight, - self.in_proj.bias, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - offset=offset, - ) - - def rel_shift(self, x: Tensor, offset=0) -> Tensor: - """Compute relative positional encoding. - - Args: - x: Input tensor (batch, head, time1, 2*time1-1). - time1 means the length of query vector. - - Returns: - Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 == time1 + offset, since it is for - the key, while time1 is for the query). - """ - (batch_size, num_heads, time1, n) = x.shape - time2 = time1 + offset - assert n == 2 * time2 - 1 - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - - return x.as_strided( - (batch_size, num_heads, time1, time2), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) - - def multi_head_attention_forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = True, - attn_mask: Optional[Tensor] = None, - offset=0, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - - tgt_len, bsz, embed_dim = query.size() - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) - - head_dim = embed_dim // num_heads - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" - scaling = float(head_dim) ** -0.5 - - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = nn.functional.linear( - query, in_proj_weight, in_proj_bias - ).chunk(3, dim=-1) - - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError( - "The size of the 2D attn_mask is not correct." - ) - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - query.size(0), - key.size(0), - ]: - raise RuntimeError( - "The size of the 3D attn_mask is not correct." - ) - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format( - attn_mask.dim() - ) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if ( - key_padding_mask is not None - and key_padding_mask.dtype == torch.uint8 - ): - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." - ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim) - k = k.contiguous().view(-1, bsz, num_heads, head_dim) - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - src_len = k.size(0) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == src_len, "{} == {}".format( - key_padding_mask.size(1), src_len - ) - - q = q.transpose(0, 1) # (batch, time1, head, d_k) - - pos_emb_bsz = pos_emb.size(0) - assert pos_emb_bsz in (1, bsz) # actually it is 1 - p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - - # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) - p = p.permute(0, 2, 3, 1) - - q_with_bias_u = (q + self.pos_bias_u).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - q_with_bias_v = (q + self.pos_bias_v).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul( - q_with_bias_u, k - ) # (batch, head, time1, time2) - - # compute matrix b and matrix d - matrix_bd = torch.matmul( - q_with_bias_v, p - ) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift( - matrix_bd, offset=offset - ) # [B, head, time1, time2] - attn_output_weights = ( - matrix_ac + matrix_bd - ) * scaling # (batch, head, time1, time2) - - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, -1 - ) - - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) - else: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear( - attn_output, out_proj_weight, out_proj_bias - ) - - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Conformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__( - self, - channels: int, - kernel_size: int, - bias: bool = True, - causal: bool = False, - ) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - # from https://github.com/wenet-e2e/wenet/blob/main/wenet/transformer/convolution.py#L41 - if causal: - self.lorder = kernel_size - 1 - padding = 0 # manualy padding self.lorder zeros to the left later - else: - assert (kernel_size - 1) % 2 == 0 - self.lorder = 0 - padding = (kernel_size - 1) // 2 - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=padding, - groups=channels, - bias=bias, - ) - self.norm = nn.BatchNorm1d(channels) - self.pointwise_conv2 = nn.Conv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - self.activation = Swish() - - def forward(self, x: Tensor) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - - Returns: - Tensor: Output tensor (#time, batch, channels). - - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - if self.lorder > 0: - # manualy padding self.lorder zeros to the left - # make depthwise_conv causal - x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) - x = self.depthwise_conv(x) - x = self.activation(self.norm(x)) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - -class Swish(torch.nn.Module): - """Construct an Swish object.""" - - def forward(self, x: Tensor) -> Tensor: - """Return Swich activation function.""" - return x * torch.sigmoid(x) - - -def identity(x): - return x diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/label_smoothing.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/label_smoothing.py deleted file mode 120000 index 08734abd7..000000000 --- a/egs/fisher_swbd/ASR/streaming_conformer_ctc/label_smoothing.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/streaming_decode.py deleted file mode 100755 index a74c51836..000000000 --- a/egs/fisher_swbd/ASR/streaming_conformer_ctc/streaming_decode.py +++ /dev/null @@ -1,521 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from conformer import Conformer -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py#L166 -def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: - new_hyp: List[int] = [] - cur = 0 - while cur < len(hyp): - if hyp[cur] != 0: - new_hyp.append(hyp[cur]) - prev = cur - while cur < len(hyp) and hyp[cur] == hyp[prev]: - cur += 1 - return new_hyp - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=34, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--chunk-size", - type=int, - default=8, - help="Frames of right context" - "-1 for whole right context, i.e. non-streaming decoding", - ) - - parser.add_argument( - "--tailing-num-frames", - type=int, - default=20, - help="tailing dummy frames padded to the right," - "only used during decoding", - ) - - parser.add_argument( - "--simulate-streaming", - type=str2bool, - default=False, - help="simulate chunk by chunk decoding", - ) - parser.add_argument( - "--method", - type=str, - default="ctc-greedy-search", - help="Streaming Decoding method", - ) - - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved. - pretrained.pt contains a dict {"model": model.state_dict()}, - which can be loaded by `icefall.checkpoint.load_checkpoint()`. - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="streaming_conformer_ctc/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--trained-dir", - type=Path, - default=None, - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe", - help="The lang dir", - ) - - parser.add_argument( - "--avg-models", - type=str, - default=None, - help="Manually select models to average, seperated by comma;" - "e.g. 60,62,63,72", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang_bpe"), - # parameters for conformer - "causal": True, - "subsampling_factor": 4, - "vgg_frontend": False, - "use_feat_batchnorm": True, - "feature_dim": 80, - "nhead": 8, - "attention_dim": 512, - "num_decoder_layers": 6, - # parameters for decoding - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - bpe_model: Optional[spm.SentencePieceProcessor], - batch: dict, - word_table: k2.SymbolTable, - sos_id: int, - eos_id: int, - chunk_size: int = -1, - simulate_streaming=False, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - - model: - The neural model. - bpe_model: - The BPE model. Used only when params.method is ctc-decoding. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - sos_id: - The token ID of the SOS. - eos_id: - The token ID of the EOS. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - feature = batch["inputs"] - device = torch.device("cuda") - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - # Extra dummy tailing frames my reduce deletion error - # example WITHOUT padding: - # CHAPTER SEVEN ON THE RACES OF MAN - # example WITH padding: - # CHAPTER SEVEN ON THE RACES OF (MAN->*) - tailing_frames = ( - torch.tensor([-23.0259]) - .expand([feature.size(0), params.tailing_num_frames, 80]) - .to(feature.device) - ) - feature = torch.cat([feature, tailing_frames], dim=1) - supervisions["num_frames"] += params.tailing_num_frames - - nnet_output, memory, memory_key_padding_mask = model( - feature, - supervisions, - chunk_size=chunk_size, - simulate_streaming=simulate_streaming, - ) - - assert params.method == "ctc-greedy-search" - key = "ctc-greedy-search" - batch_size = nnet_output.size(0) - maxlen = nnet_output.size(1) - topk_prob, topk_index = nnet_output.topk(1, dim=2) # (B, maxlen, 1) - topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) - topk_index = topk_index.masked_fill_( - memory_key_padding_mask, 0 - ) # (B, maxlen) - token_ids = [token_id.tolist() for token_id in topk_index] - token_ids = [ - remove_duplicates_and_blank(token_id) for token_id in token_ids - ] - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] - return {key: hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - bpe_model: Optional[spm.SentencePieceProcessor], - word_table: k2.SymbolTable, - sos_id: int, - eos_id: int, - chunk_size: int = -1, - simulate_streaming=False, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - bpe_model: - The BPE model. Used only when params.method is ctc-decoding. - word_table: - It is the word symbol table. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - chunk_size: - right context to simulate streaming decoding - -1 for whole right context, i.e. non-stream decoding - Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - - hyps_dict = decode_one_batch( - params=params, - model=model, - bpe_model=bpe_model, - batch=batch, - word_table=word_table, - sos_id=sos_id, - eos_id=eos_id, - chunk_size=chunk_size, - simulate_streaming=simulate_streaming, - ) - - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - if params.method == "attention-decoder": - # Set it to False since there are too many logs. - enable_log = False - else: - enable_log = True - test_set_wers = dict() - if params.avg_models is not None: - avg_models = params.avg_models.replace(",", "_") - result_file_prefix = f"epoch-avg-{avg_models}-chunksize \ - -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" - else: - result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \ - -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" - for key, results in results_dict.items(): - recog_path = ( - params.exp_dir - / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt" - ) - store_transcripts(filename=recog_path, texts=results) - if enable_log: - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.exp_dir - / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=enable_log - ) - test_set_wers[key] = wer - - if enable_log: - logging.info( - "Wrote detailed error stats to {}".format(errs_filename) - ) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") - logging.info("Decoding started") - logging.info(params) - - if params.trained_dir is not None: - params.lang_dir = Path(params.trained_dir) / "lang_bpe" - # used naming result files - params.epoch = "trained_model" - params.avg = 1 - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - sos_id = graph_compiler.sos_id - eos_id = graph_compiler.eos_id - - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, - causal=params.causal, - ) - - if params.trained_dir is not None: - model_name = f"{params.trained_dir}/trained_streaming_conformer.pt" - load_checkpoint(model_name, model) - elif params.avg == 1 and params.avg_models is not None: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - filenames = [] - if params.avg_models is not None: - model_ids = params.avg_models.split(",") - for i in model_ids: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - else: - start = params.epoch - params.avg + 1 - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict(average_checkpoints(filenames)) - - if params.export: - logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) - return - - model.to(device) - model.eval() - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - librispeech = LibriSpeechAsrDataModule(args) - # CAUTION: `test_sets` is for displaying only. - # If you want to skip test-clean, you have to skip - # it inside the for loop. That is, use - # - # if test_set == 'test-clean': continue - # - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(str(params.lang_dir / "bpe.model")) - test_sets = ["test-clean", "test-other"] - for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - bpe_model=bpe_model, - word_table=lexicon.word_table, - sos_id=sos_id, - eos_id=eos_id, - chunk_size=params.chunk_size, - simulate_streaming=params.simulate_streaming, - ) - - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) - - logging.info("Done!") - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/subsampling.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/subsampling.py deleted file mode 120000 index 6fee09e58..000000000 --- a/egs/fisher_swbd/ASR/streaming_conformer_ctc/subsampling.py +++ /dev/null @@ -1 +0,0 @@ -../conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/train.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/train.py deleted file mode 100755 index 8b4d6701e..000000000 --- a/egs/fisher_swbd/ASR/streaming_conformer_ctc/train.py +++ /dev/null @@ -1,745 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from conformer import Conformer -from lhotse.utils import fix_random_seed -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.utils.tensorboard import SummaryWriter -from transformer import Noam - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - encode_supervisions, - setup_logger, - str2bool, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=78, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - conformer_ctc/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--att-rate", - type=float, - default=0.8, - help="""The attention rate. - The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss - """, - ) - - parser.add_argument( - "--dynamic-chunk-training", - type=str2bool, - default=True, - help="Whether to use dynamic right context during training.", - ) - - parser.add_argument( - "--short-chunk-proportion", - type=float, - default=0.7, - help="Proportion of samples trained with short right context", - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - use_feat_batchnorm: Whether to do batch normalization for the - input features. - - - attention_dim: Hidden dim for multi-head attention model. - - - head: Number of heads of multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - - - weight_decay: The weight_decay for the optimizer. - - - lr_factor: The lr_factor for Noam optimizer. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 3000, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "use_feat_batchnorm": True, - "attention_dim": 512, - "nhead": 8, - "num_decoder_layers": 6, - # parameters for loss - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - # parameters for Noam - "weight_decay": 1e-6, - "lr_factor": 5.0, - "warm_step": 80000, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: BpeCtcTrainingGraphCompiler, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - with torch.set_grad_enabled(is_training): - nnet_output, encoder_memory, memory_mask = model( - feature, - supervisions, - dynamic_chunk_training=params.dynamic_chunk_training, - short_chunk_proportion=params.short_chunk_proportion, - ) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - - token_ids = graph_compiler.texts_to_ids(texts) - - decoding_graph = graph_compiler.compile(token_ids) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - if params.att_rate != 0.0: - with torch.set_grad_enabled(is_training): - mmodel = model.module if hasattr(model, "module") else model - # Note: We need to generate an unsorted version of token_ids - # `encode_supervisions()` called above sorts text, but - # encoder_memory and memory_mask are not sorted, so we - # use an unsorted version `supervisions["text"]` to regenerate - # the token_ids - # - # See https://github.com/k2-fsa/icefall/issues/97 - # for more details - unsorted_token_ids = graph_compiler.texts_to_ids( - supervisions["text"] - ) - att_loss = mmodel.decoder_forward( - encoder_memory, - memory_mask, - token_ids=unsorted_token_ids, - sos_id=graph_compiler.sos_id, - eos_id=graph_compiler.eos_id, - ) - loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss - else: - loss = ctc_loss - att_loss = torch.tensor([0]) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["ctc_loss"] = ctc_loss.detach().cpu().item() - if params.att_rate != 0.0: - info["att_loss"] = att_loss.detach().cpu().item() - - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: BpeCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: BpeCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(42) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - - logging.info("About to create model") - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=False, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = Noam( - model.parameters(), - model_size=params.attention_dim, - factor=params.lr_factor, - warm_step=params.warm_step, - weight_decay=params.weight_decay, - ) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - - librispeech = LibriSpeechAsrDataModule(args) - train_dl = librispeech.train_dataloaders() - valid_dl = librispeech.valid_dataloaders() - - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - for epoch in range(params.start_epoch, params.num_epochs): - train_dl.sampler.set_epoch(epoch) - - cur_lr = optimizer._rate - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - if rank == 0: - logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: BpeCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - optimizer.zero_grad() - loss, _ = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - - -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/fisher_swbd/ASR/streaming_conformer_ctc/transformer.py b/egs/fisher_swbd/ASR/streaming_conformer_ctc/transformer.py deleted file mode 100644 index bc78e4a41..000000000 --- a/egs/fisher_swbd/ASR/streaming_conformer_ctc/transformer.py +++ /dev/null @@ -1,966 +0,0 @@ -# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import math -from typing import Dict, List, Optional, Tuple - -import torch -import torch.nn as nn -from label_smoothing import LabelSmoothingLoss -from subsampling import Conv2dSubsampling, VggSubsampling -from torch.nn.utils.rnn import pad_sequence - -# Note: TorchScript requires Dict/List/etc. to be fully typed. -Supervisions = Dict[str, torch.Tensor] - - -class Transformer(nn.Module): - def __init__( - self, - num_features: int, - num_classes: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - num_decoder_layers: int = 6, - dropout: float = 0.1, - normalize_before: bool = True, - vgg_frontend: bool = False, - use_feat_batchnorm: bool = False, - ) -> None: - """ - Args: - num_features: - The input dimension of the model. - num_classes: - The output dimension of the model. - subsampling_factor: - Number of output frames is num_in_frames // subsampling_factor. - Currently, subsampling_factor MUST be 4. - d_model: - Attention dimension. - nhead: - Number of heads in multi-head attention. - Must satisfy d_model // nhead == 0. - dim_feedforward: - The output dimension of the feedforward layers in encoder/decoder. - num_encoder_layers: - Number of encoder layers. - num_decoder_layers: - Number of decoder layers. - dropout: - Dropout in encoder/decoder. - normalize_before: - If True, use pre-layer norm; False to use post-layer norm. - vgg_frontend: - True to use vgg style frontend for subsampling. - use_feat_batchnorm: - True to use batchnorm for the input layer. - """ - super().__init__() - self.use_feat_batchnorm = use_feat_batchnorm - if use_feat_batchnorm: - self.feat_batchnorm = nn.BatchNorm1d(num_features) - - self.num_features = num_features - self.num_classes = num_classes - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_classes) - # to the shape (N, T//subsampling_factor, d_model). - # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_classes -> d_model - if vgg_frontend: - self.encoder_embed = VggSubsampling(num_features, d_model) - else: - self.encoder_embed = Conv2dSubsampling(num_features, d_model) - - self.encoder_pos = PositionalEncoding(d_model, dropout) - - encoder_layer = TransformerEncoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - encoder_norm = nn.LayerNorm(d_model) - else: - encoder_norm = None - - self.encoder = nn.TransformerEncoder( - encoder_layer=encoder_layer, - num_layers=num_encoder_layers, - norm=encoder_norm, - ) - - # TODO(fangjun): remove dropout - self.encoder_output_layer = nn.Sequential( - nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) - ) - - if num_decoder_layers > 0: - self.decoder_num_class = ( - self.num_classes - ) # bpe model already has sos/eos symbol - - self.decoder_embed = nn.Embedding( - num_embeddings=self.decoder_num_class, embedding_dim=d_model - ) - self.decoder_pos = PositionalEncoding(d_model, dropout) - - decoder_layer = TransformerDecoderLayer( - d_model=d_model, - nhead=nhead, - dim_feedforward=dim_feedforward, - dropout=dropout, - normalize_before=normalize_before, - ) - - if normalize_before: - decoder_norm = nn.LayerNorm(d_model) - else: - decoder_norm = None - - self.decoder = nn.TransformerDecoder( - decoder_layer=decoder_layer, - num_layers=num_decoder_layers, - norm=decoder_norm, - ) - - self.decoder_output_layer = torch.nn.Linear( - d_model, self.decoder_num_class - ) - - self.decoder_criterion = LabelSmoothingLoss() - else: - self.decoder_criterion = None - - def forward( - self, - x: torch.Tensor, - supervision: Optional[Supervisions] = None, - dynamic_chunk_training: bool = False, - short_chunk_proportion: float = 0.5, - chunk_size: int = -1, - simulate_streaming=False, - ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (N, T, C). - supervision: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - - Returns: - Return a tuple containing 3 tensors: - - CTC output for ctc decoding. Its shape is (N, T, C) - - Encoder output with shape (T, N, C). It can be used as key and - value for the decoder. - - Encoder output padding mask. It can be used as - memory_key_padding_mask for the decoder. Its shape is (N, T). - It is None if `supervision` is None. - """ - if self.use_feat_batchnorm: - x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) - x = self.feat_batchnorm(x) - x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder( - x, - supervision, - dynamic_chunk_training=dynamic_chunk_training, - short_chunk_proportion=short_chunk_proportion, - chunk_size=chunk_size, - simulate_streaming=simulate_streaming, - ) - x = self.ctc_output(encoder_memory) - return x, encoder_memory, memory_key_padding_mask - - def run_encoder( - self, - x: torch.Tensor, - supervisions: Optional[Supervisions] = None, - chunk_size: int = -1, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Run the transformer encoder. - - Args: - x: - The model input. Its shape is (N, T, C). - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - CAUTION: It contains length information, i.e., start and number of - frames, before subsampling - It is read directly from the batch, without any sorting. It is used - to compute the encoder padding mask, which is used as memory key - padding mask for the decoder. - chunk_size: right chunk_size to simulate streaming decoding - -1 for whole right context - Returns: - Return a tuple with two tensors: - - The encoder output, with shape (T, N, C) - - encoder padding mask, with shape (N, T). - The mask is None if `supervisions` is None. - It is used as memory key padding mask in the decoder. - """ - # streaming decoding(chunk_size >= 0) is only verified with Conformer - assert chunk_size == -1 - x = self.encoder_embed(x) - x = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - mask = encoder_padding_mask(x.size(0), supervisions) - mask = mask.to(x.device) if mask is not None else None - x = self.encoder( - x, src_key_padding_mask=mask, chunk_size=chunk_size - ) # (T, N, C) - - return x, mask - - def ctc_output(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - The output tensor from the transformer encoder. - Its shape is (T, N, C) - - Returns: - Return a tensor that can be used for CTC decoding. - Its shape is (N, T, C) - """ - x = self.encoder_output_layer(x) - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) - return x - - @torch.jit.export - def decoder_forward( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[List[int]], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape (T, N, C) - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs. Each sublist contains IDs for an utterance. - The IDs can be either phone IDs or word piece IDs. - sos_id: - sos token id - eos_id: - eos token id - - Returns: - A scalar, the **sum** of label smoothing loss over utterances - in the batch without any normalization. - """ - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) - - device = memory.device - ys_in_pad = ys_in_pad.to(device) - ys_out_pad = ys_out_pad.to(device) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, N, C) - pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) - pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) - - decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) - - return decoder_loss - - @torch.jit.export - def decoder_nll( - self, - memory: torch.Tensor, - memory_key_padding_mask: torch.Tensor, - token_ids: List[torch.Tensor], - sos_id: int, - eos_id: int, - ) -> torch.Tensor: - """ - Args: - memory: - It's the output of the encoder with shape (T, N, C) - memory_key_padding_mask: - The padding mask from the encoder. - token_ids: - A list-of-list IDs (e.g., word piece IDs). - Each sublist represents an utterance. - sos_id: - The token ID for SOS. - eos_id: - The token ID for EOS. - Returns: - A 2-D tensor of shape (len(token_ids), max_token_length) - representing the cross entropy loss (i.e., negative log-likelihood). - """ - # The common part between this function and decoder_forward could be - # extracted as a separate function. - if isinstance(token_ids[0], torch.Tensor): - # This branch is executed by torchscript in C++. - # See https://github.com/k2-fsa/k2/pull/870 - # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 - token_ids = [tolist(t) for t in token_ids] - - ys_in = add_sos(token_ids, sos_id=sos_id) - ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence( - ys_in, batch_first=True, padding_value=float(eos_id) - ) - - ys_out = add_eos(token_ids, eos_id=eos_id) - ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence( - ys_out, batch_first=True, padding_value=float(-1) - ) - - device = memory.device - ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) - ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - device - ) - - tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) - # TODO: Use length information to create the decoder padding mask - # We set the first column to False since the first column in ys_in_pad - # contains sos_id, which is the same as eos_id in our current setting. - tgt_key_padding_mask[:, 0] = False - - tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) - tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - pred_pad = self.decoder( - tgt=tgt, - memory=memory, - tgt_mask=tgt_mask, - tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=memory_key_padding_mask, - ) # (T, B, F) - pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) - pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) - # nll: negative log-likelihood - nll = torch.nn.functional.cross_entropy( - pred_pad.view(-1, self.decoder_num_class), - ys_out_pad.view(-1), - ignore_index=-1, - reduction="none", - ) - - nll = nll.view(pred_pad.shape[0], -1) - - return nll - - -class TransformerEncoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerEncoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - normalize_before: - whether to use layer_norm before the first block. - - Examples:: - >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> out = encoder_layer(src) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerEncoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerEncoderLayer, self).__setstate__(state) - - def forward( - self, - src: torch.Tensor, - src_mask: Optional[torch.Tensor] = None, - src_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional) - - Shape: - src: (S, N, E). - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = src - if self.normalize_before: - src = self.norm1(src) - src2 = self.self_attn( - src, - src, - src, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - src = residual + self.dropout1(src2) - if not self.normalize_before: - src = self.norm1(src) - - residual = src - if self.normalize_before: - src = self.norm2(src) - src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) - src = residual + self.dropout2(src2) - if not self.normalize_before: - src = self.norm2(src) - return src - - -class TransformerDecoderLayer(nn.Module): - """ - Modified from torch.nn.TransformerDecoderLayer. - Add support of normalize_before, - i.e., use layer_norm before the first block. - - Args: - d_model: - the number of expected features in the input (required). - nhead: - the number of heads in the multiheadattention models (required). - dim_feedforward: - the dimension of the feedforward network model (default=2048). - dropout: - the dropout value (default=0.1). - activation: - the activation function of intermediate layer, relu or - gelu (default=relu). - - Examples:: - >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) - >>> memory = torch.rand(10, 32, 512) - >>> tgt = torch.rand(20, 32, 512) - >>> out = decoder_layer(tgt, memory) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - activation: str = "relu", - normalize_before: bool = True, - ) -> None: - super(TransformerDecoderLayer, self).__init__() - self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0) - # Implementation of Feedforward model - self.linear1 = nn.Linear(d_model, dim_feedforward) - self.dropout = nn.Dropout(dropout) - self.linear2 = nn.Linear(dim_feedforward, d_model) - - self.norm1 = nn.LayerNorm(d_model) - self.norm2 = nn.LayerNorm(d_model) - self.norm3 = nn.LayerNorm(d_model) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.dropout3 = nn.Dropout(dropout) - - self.activation = _get_activation_fn(activation) - - self.normalize_before = normalize_before - - def __setstate__(self, state): - if "activation" not in state: - state["activation"] = nn.functional.relu - super(TransformerDecoderLayer, self).__setstate__(state) - - def forward( - self, - tgt: torch.Tensor, - memory: torch.Tensor, - tgt_mask: Optional[torch.Tensor] = None, - memory_mask: Optional[torch.Tensor] = None, - tgt_key_padding_mask: Optional[torch.Tensor] = None, - memory_key_padding_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - """Pass the inputs (and mask) through the decoder layer. - - Args: - tgt: - the sequence to the decoder layer (required). - memory: - the sequence from the last layer of the encoder (required). - tgt_mask: - the mask for the tgt sequence (optional). - memory_mask: - the mask for the memory sequence (optional). - tgt_key_padding_mask: - the mask for the tgt keys per batch (optional). - memory_key_padding_mask: - the mask for the memory keys per batch (optional). - - Shape: - tgt: (T, N, E). - memory: (S, N, E). - tgt_mask: (T, T). - memory_mask: (T, S). - tgt_key_padding_mask: (N, T). - memory_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, - N is the batch size, E is the feature number - """ - residual = tgt - if self.normalize_before: - tgt = self.norm1(tgt) - tgt2 = self.self_attn( - tgt, - tgt, - tgt, - attn_mask=tgt_mask, - key_padding_mask=tgt_key_padding_mask, - )[0] - tgt = residual + self.dropout1(tgt2) - if not self.normalize_before: - tgt = self.norm1(tgt) - - residual = tgt - if self.normalize_before: - tgt = self.norm2(tgt) - tgt2 = self.src_attn( - tgt, - memory, - memory, - attn_mask=memory_mask, - key_padding_mask=memory_key_padding_mask, - )[0] - tgt = residual + self.dropout2(tgt2) - if not self.normalize_before: - tgt = self.norm2(tgt) - - residual = tgt - if self.normalize_before: - tgt = self.norm3(tgt) - tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) - tgt = residual + self.dropout3(tgt2) - if not self.normalize_before: - tgt = self.norm3(tgt) - return tgt - - -def _get_activation_fn(activation: str): - if activation == "relu": - return nn.functional.relu - elif activation == "gelu": - return nn.functional.gelu - - raise RuntimeError( - "activation should be relu/gelu, not {}".format(activation) - ) - - -class PositionalEncoding(nn.Module): - """This class implements the positional encoding - proposed in the following paper: - - - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf - - PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) - PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) - - Note:: - - 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) - = exp(-1* 2i / d_model * log(100000)) - = exp(2i * -(log(10000) / d_model)) - """ - - def __init__(self, d_model: int, dropout: float = 0.1) -> None: - """ - Args: - d_model: - Embedding dimension. - dropout: - Dropout probability to be applied to the output of this module. - """ - super().__init__() - self.d_model = d_model - self.xscale = math.sqrt(self.d_model) - self.dropout = nn.Dropout(p=dropout) - # not doing: self.pe = None because of errors thrown by torchscript - self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32) - - def extend_pe(self, x: torch.Tensor) -> None: - """Extend the time t in the positional encoding if required. - - The shape of `self.pe` is (1, T1, d_model). The shape of the input x - is (N, T, d_model). If T > T1, then we change the shape of self.pe - to (N, T, d_model). Otherwise, nothing is done. - - Args: - x: - It is a tensor of shape (N, T, C). - Returns: - Return None. - """ - if self.pe is not None: - if self.pe.size(1) >= x.size(1): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) - position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) - div_term = torch.exp( - torch.arange(0, self.d_model, 2, dtype=torch.float32) - * -(math.log(10000.0) / self.d_model) - ) - pe[:, 0::2] = torch.sin(position * div_term) - pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) - # Now pe is of shape (1, T, d_model), where T is x.size(1) - self.pe = pe.to(device=x.device, dtype=x.dtype) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Add positional encoding. - - Args: - x: - Its shape is (N, T, C) - - Returns: - Return a tensor of shape (N, T, C) - """ - self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1), :] - return self.dropout(x) - - -class Noam(object): - """ - Implements Noam optimizer. - - Proposed in - "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa - - Args: - params: - iterable of parameters to optimize or dicts defining parameter groups - model_size: - attention dimension of the transformer model - factor: - learning rate factor - warm_step: - warmup steps - """ - - def __init__( - self, - params, - model_size: int = 256, - factor: float = 10.0, - warm_step: int = 25000, - weight_decay=0, - ) -> None: - """Construct an Noam object.""" - self.optimizer = torch.optim.Adam( - params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay - ) - self._step = 0 - self.warmup = warm_step - self.factor = factor - self.model_size = model_size - self._rate = 0 - - @property - def param_groups(self): - """Return param_groups.""" - return self.optimizer.param_groups - - def step(self): - """Update parameters and rate.""" - self._step += 1 - rate = self.rate() - for p in self.optimizer.param_groups: - p["lr"] = rate - self._rate = rate - self.optimizer.step() - - def rate(self, step=None): - """Implement `lrate` above.""" - if step is None: - step = self._step - return ( - self.factor - * self.model_size ** (-0.5) - * min(step ** (-0.5), step * self.warmup ** (-1.5)) - ) - - def zero_grad(self): - """Reset gradient.""" - self.optimizer.zero_grad() - - def state_dict(self): - """Return state_dict.""" - return { - "_step": self._step, - "warmup": self.warmup, - "factor": self.factor, - "model_size": self.model_size, - "_rate": self._rate, - "optimizer": self.optimizer.state_dict(), - } - - def load_state_dict(self, state_dict): - """Load state_dict.""" - for key, value in state_dict.items(): - if key == "optimizer": - self.optimizer.load_state_dict(state_dict["optimizer"]) - else: - setattr(self, key, value) - - -def encoder_padding_mask( - max_len: int, supervisions: Optional[Supervisions] = None -) -> Optional[torch.Tensor]: - """Make mask tensor containing indexes of padded part. - - TODO:: - This function **assumes** that the model uses - a subsampling factor of 4. We should remove that - assumption later. - - Args: - max_len: - Maximum length of input features. - CAUTION: It is the length after subsampling. - supervisions: - Supervision in lhotse format. - See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa - (CAUTION: It contains length information, i.e., start and number of - frames, before subsampling) - - Returns: - Tensor: Mask tensor of dimension (batch_size, input_length), - True denote the masked indices. - """ - if supervisions is None: - return None - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"], - supervisions["num_frames"], - ), - 1, - ).to(torch.int32) - - lengths = [ - 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) - ] - for idx in range(supervision_segments.size(0)): - # Note: TorchScript doesn't allow to unpack tensors as tuples - sequence_idx = supervision_segments[idx, 0].item() - start_frame = supervision_segments[idx, 1].item() - num_frames = supervision_segments[idx, 2].item() - lengths[sequence_idx] = start_frame + num_frames - - lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] - bs = int(len(lengths)) - seq_range = torch.arange(0, max_len, dtype=torch.int64) - seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) - # Note: TorchScript doesn't implement Tensor.new() - seq_length_expand = torch.tensor( - lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype - ).unsqueeze(-1) - mask = seq_range_expand >= seq_length_expand - - return mask - - -def decoder_padding_mask( - ys_pad: torch.Tensor, ignore_id: int = -1 -) -> torch.Tensor: - """Generate a length mask for input. - - The masked position are filled with True, - Unmasked positions are filled with False. - - Args: - ys_pad: - padded tensor of dimension (batch_size, input_length). - ignore_id: - the ignored number (the padding number) in ys_pad - - Returns: - Tensor: - a bool tensor of the same shape as the input tensor. - """ - ys_mask = ys_pad == ignore_id - return ys_mask - - -def generate_square_subsequent_mask(sz: int) -> torch.Tensor: - """Generate a square mask for the sequence. The masked positions are - filled with float('-inf'). Unmasked positions are filled with float(0.0). - The mask can be used for masked self-attention. - - For instance, if sz is 3, it returns:: - - tensor([[0., -inf, -inf], - [0., 0., -inf], - [0., 0., 0]]) - - Args: - sz: mask size - - Returns: - A square mask of dimension (sz, sz) - """ - mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = ( - mask.float() - .masked_fill(mask == 0, float("-inf")) - .masked_fill(mask == 1, float(0.0)) - ) - return mask - - -def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: - """Prepend sos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - sos_id: - The ID of the SOS token. - - Return: - Return a new list-of-list, where each sublist starts - with SOS ID. - """ - return [[sos_id] + utt for utt in token_ids] - - -def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: - """Append eos_id to each utterance. - - Args: - token_ids: - A list-of-list of token IDs. Each sublist contains - token IDs (e.g., word piece IDs) of an utterance. - eos_id: - The ID of the EOS token. - - Return: - Return a new list-of-list, where each sublist ends - with EOS ID. - """ - return [utt + [eos_id] for utt in token_ids] - - -def tolist(t: torch.Tensor) -> List[int]: - """Used by jit""" - return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/README.md b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/README.md deleted file mode 100644 index 94d4ed6a3..000000000 --- a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/README.md +++ /dev/null @@ -1,4 +0,0 @@ - -Please visit - -for how to run this recipe. diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/__init__.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py deleted file mode 100644 index 7abe169d1..000000000 --- a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ /dev/null @@ -1,286 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from functools import lru_cache -from pathlib import Path - -from tqdm import tqdm - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - BucketingSampler, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PerturbSpeed, - PrecomputedFeatures, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class Resample16kHz: - def __call__(self, cuts: CutSet) -> CutSet: - return cuts.resample(16000).with_recording_path_prefix('download') - - -class AsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/manifests"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=True, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=8, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: - logging.info("About to get Musan cuts") - cuts_musan = load_manifest( - self.args.manifest_dir / "musan_cuts.jsonl.gz" - ) - - input_strategy = PrecomputedFeatures() - if self.args.on_the_fly_feats: - input_strategy = OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000)), - ) - - train = K2SpeechRecognitionDataset( - input_strategy=input_strategy, - cut_transforms=[ - PerturbSpeed(factors=[0.9, 1.1], p=2 / 3, preserve_id=True), - Resample16kHz(), - CutMix( - cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True - ), - ], - input_transforms=[ - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=2, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ], - return_cuts=True, - ) - - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=True, - ) - train_sampler.filter(lambda cut: 1.0 <= cut.duration <= 15.0) - - logging.info("About to create train dataloader") - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - - logging.info("About to create dev dataset") - input_strategy = PrecomputedFeatures() - if self.args.on_the_fly_feats: - input_strategy = OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000)), - ) - - validate = K2SpeechRecognitionDataset( - return_cuts=True, - input_strategy=input_strategy, - cut_transforms=[ - Resample16kHz(), - ], - ) - - valid_sampler = BucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - - input_strategy = PrecomputedFeatures() - if self.args.on_the_fly_feats: - input_strategy = OnTheFlyFeatures( - Fbank(FbankConfig(num_mel_bins=80, sampling_rate=16000)), - ) - - test = K2SpeechRecognitionDataset( - return_cuts=True, - input_strategy=input_strategy, - cut_transforms=[ - Resample16kHz(), - ], - ) - sampler = BucketingSampler( - cuts, max_duration=self.args.max_duration, shuffle=False - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train Fisher + SWBD cuts") - return load_manifest_lazy( - self.args.manifest_dir - / "train_utterances_fisher-swbd_cuts.jsonl.gz" - ) - - @lru_cache() - def dev_cuts(self) -> CutSet: - logging.info("About to get dev Fisher + SWBD cuts") - return load_manifest_lazy( - self.args.manifest_dir / "dev_utterances_fisher-swbd_cuts.jsonl.gz" - ) - - @lru_cache() - def test_cuts(self) -> CutSet: - logging.info("About to get test-clean cuts") - raise NotImplemented - - -def test(): - parser = argparse.ArgumentParser() - AsrDataModule.add_arguments(parser) - args = parser.parse_args() - adm = AsrDataModule(args) - - cuts = adm.train_cuts() - dl = adm.train_dataloaders(cuts) - for i, batch in tqdm(enumerate(dl)): - if i == 100: - break - - cuts = adm.dev_cuts() - dl = adm.valid_dataloaders(cuts) - for i, batch in tqdm(enumerate(dl)): - if i == 100: - break - - -if __name__ == '__main__': - test() diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/decode.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/decode.py deleted file mode 100755 index 9c964c2aa..000000000 --- a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/decode.py +++ /dev/null @@ -1,505 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from model import TdnnLstm - -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import ( - get_lattice, - nbest_decoding, - one_best_decoding, - rescore_with_n_best_list, - rescore_with_whole_lattice, -) -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - get_texts, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=19, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=5, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - parser.add_argument( - "--method", - type=str, - default="whole-lattice-rescoring", - help="""Decoding method. - Supported values are: - - (1) 1best. Extract the best path from the decoding lattice as the - decoding result. - - (2) nbest. Extract n paths from the decoding lattice; the path - with the highest score is the decoding result. - - (3) nbest-rescoring. Extract n paths from the decoding lattice, - rescore them with an n-gram LM (e.g., a 4-gram LM), the path with - the highest score is the decoding result. - - (4) whole-lattice-rescoring. Rescore the decoding lattice with an - n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice - is the decoding result. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help="""Number of paths for n-best based decoding method. - Used only when "method" is one of the following values: - nbest, nbest-rescoring - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""The scale to be applied to `lattice.scores`. - It's needed if you use any kinds of n-best based rescoring. - Used only when "method" is one of the following values: - nbest, nbest-rescoring - A smaller value results in more unique paths. - """, - ) - - parser.add_argument( - "--export", - type=str2bool, - default=False, - help="""When enabled, the averaged model is saved to - tdnn/exp/pretrained.pt. Note: only model.state_dict() is saved. - pretrained.pt contains a dict {"model": model.state_dict()}, - which can be loaded by `icefall.checkpoint.load_checkpoint()`. - """, - ) - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "exp_dir": Path("tdnn_lstm_ctc/exp/"), - "lang_dir": Path("data/lang_phone"), - "lm_dir": Path("data/lm"), - "feature_dim": 80, - "subsampling_factor": 3, - "search_beam": 20, - "output_beam": 5, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - return params - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - batch: dict, - lexicon: Lexicon, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if no rescoring is used, the key is the string `no_rescore`. - If LM rescoring is used, the key is the string `lm_scale_xxx`, - where `xxx` is the value of `lm_scale`. An example key is - `lm_scale_0.7` - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - - - params.method is "1best", it uses 1best decoding without LM rescoring. - - params.method is "nbest", it uses nbest decoding without LM rescoring. - - params.method is "nbest-rescoring", it uses nbest LM rescoring. - - params.method is "whole-lattice-rescoring", it uses whole lattice LM - rescoring. - - model: - The neural model. - HLG: - The decoding graph. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - lexicon: - It contains word symbol table. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = HLG.device - feature = batch["inputs"] - assert feature.ndim == 3 - feature = feature.to(device) - # at entry, feature is (N, T, C) - - feature = feature.permute(0, 2, 1) # now feature is (N, C, T) - - nnet_output = model(feature) - # nnet_output is (N, T, C) - - supervisions = batch["supervisions"] - - supervision_segments = torch.stack( - ( - supervisions["sequence_idx"], - supervisions["start_frame"] // params.subsampling_factor, - supervisions["num_frames"] // params.subsampling_factor, - ), - 1, - ).to(torch.int32) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - ) - - if params.method in ["1best", "nbest"]: - if params.method == "1best": - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - key = "no_rescore" - else: - best_path = nbest_decoding( - lattice=lattice, - num_paths=params.num_paths, - use_double_scores=params.use_double_scores, - nbest_scale=params.nbest_scale, - ) - key = f"no_rescore-{params.num_paths}" - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - return {key: hyps} - - assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"] - - lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] - lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] - lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] - - if params.method == "nbest-rescoring": - best_path_dict = rescore_with_n_best_list( - lattice=lattice, - G=G, - num_paths=params.num_paths, - lm_scale_list=lm_scale_list, - nbest_scale=params.nbest_scale, - ) - else: - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=lm_scale_list, - ) - - ans = dict() - for lm_scale_str, best_path in best_path_dict.items(): - hyps = get_texts(best_path) - hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps] - ans[lm_scale_str] = hyps - return ans - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - HLG: k2.Fsa, - lexicon: Lexicon, - G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - HLG: - The decoding graph. - lexicon: - It contains word symbol table. - G: - An LM. It is not None when params.method is "nbest-rescoring" - or "whole-lattice-rescoring". In general, the G in HLG - is a 3-gram LM, while this G is a 4-gram LM. - Returns: - Return a dict, whose key may be "no-rescore" if no LM rescoring - is used, or it may be "lm_scale_0.7" if LM rescoring is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - results = [] - - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - - hyps_dict = decode_one_batch( - params=params, - model=model, - HLG=HLG, - batch=batch, - lexicon=lexicon, - G=G, - ) - - for lm_scale, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for hyp_words, ref_text in zip(hyps, texts): - ref_words = ref_text.split() - this_batch.append((ref_words, hyp_words)) - - results[lm_scale].extend(this_batch) - - num_cuts += len(batch["supervisions"]["text"]) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[int], List[int]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log/log-decode") - logging.info("Decoding started") - logging.info(params) - - lexicon = Lexicon(params.lang_dir) - max_phone_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) - HLG = HLG.to(device) - assert HLG.requires_grad is False - - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - - if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]: - if not (params.lm_dir / "G_4_gram.pt").is_file(): - logging.info("Loading G_4_gram.fst.txt") - logging.warning("It may take 8 minutes.") - with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.word_table["#0"] - - G = k2.Fsa.from_openfst(f.read(), acceptor=False) - # G.aux_labels is not needed in later computations, so - # remove it here. - del G.aux_labels - # CAUTION: The following line is crucial. - # Arcs entering the back-off state have label equal to #0. - # We have to change it to 0 here. - G.labels[G.labels >= first_word_disambig_id] = 0 - # See https://github.com/k2-fsa/k2/issues/874 - # for why we need to set G.properties to None - G.__dict__["_properties"] = None - G = k2.Fsa.from_fsas([G]).to(device) - G = k2.arc_sort(G) - torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") - else: - logging.info("Loading pre-compiled G_4_gram.pt") - d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu") - G = k2.Fsa.from_dict(d).to(device) - - if params.method == "whole-lattice-rescoring": - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G = G.to(device) - - # G.lm_scores is used to replace HLG.lm_scores during - # LM rescoring. - G.lm_scores = G.scores.clone() - else: - G = None - - model = TdnnLstm( - num_features=params.feature_dim, - num_classes=max_phone_id + 1, # +1 for the blank symbol - subsampling_factor=params.subsampling_factor, - ) - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - if params.export: - logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save( - {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" - ) - return - - model.to(device) - model.eval() - - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) - test_other_dl = librispeech.test_dataloaders(test_other_cuts) - - test_sets = ["test-clean", "test-other"] - test_dl = [test_clean_dl, test_other_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - HLG=HLG, - lexicon=lexicon, - G=G, - ) - - save_results( - params=params, test_set_name=test_set, results_dict=results_dict - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/model.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/model.py deleted file mode 100644 index 5e04c11b4..000000000 --- a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/model.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn - - -class TdnnLstm(nn.Module): - def __init__( - self, num_features: int, num_classes: int, subsampling_factor: int = 3 - ) -> None: - """ - Args: - num_features: - The input dimension of the model. - num_classes: - The output dimension of the model. - subsampling_factor: - It reduces the number of output frames by this factor. - """ - super().__init__() - self.num_features = num_features - self.num_classes = num_classes - self.subsampling_factor = subsampling_factor - self.tdnn = nn.Sequential( - nn.Conv1d( - in_channels=num_features, - out_channels=500, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=500, affine=False), - nn.Conv1d( - in_channels=500, - out_channels=500, - kernel_size=3, - stride=1, - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=500, affine=False), - nn.Conv1d( - in_channels=500, - out_channels=500, - kernel_size=3, - stride=self.subsampling_factor, # stride: subsampling_factor! - padding=1, - ), - nn.ReLU(inplace=True), - nn.BatchNorm1d(num_features=500, affine=False), - ) - self.lstms = nn.ModuleList( - [ - nn.LSTM(input_size=500, hidden_size=500, num_layers=1) - for _ in range(5) - ] - ) - self.lstm_bnorms = nn.ModuleList( - [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] - ) - self.dropout = nn.Dropout(0.2) - self.linear = nn.Linear(in_features=500, out_features=self.num_classes) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x: - Its shape is [N, C, T] - - Returns: - The output tensor has shape [N, T, C] - """ - x = self.tdnn(x) - x = x.permute(2, 0, 1) # (N, C, T) -> (T, N, C) -> how LSTM expects it - for lstm, bnorm in zip(self.lstms, self.lstm_bnorms): - x_new, _ = lstm(x) - x_new = bnorm(x_new.permute(1, 2, 0)).permute( - 2, 0, 1 - ) # (T, N, C) -> (N, C, T) -> (T, N, C) - x_new = self.dropout(x_new) - x = x_new + x # skip connections - x = x.transpose( - 1, 0 - ) # (T, N, C) -> (N, T, C) -> linear expects "features" in the last dim - x = self.linear(x) - x = nn.functional.log_softmax(x, dim=-1) - return x diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/pretrained.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/pretrained.py deleted file mode 100755 index 2baeb6bba..000000000 --- a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/pretrained.py +++ /dev/null @@ -1,277 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from model import TdnnLstm -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import ( - get_lattice, - one_best_decoding, - rescore_with_whole_lattice, -) -from icefall.utils import AttributeDict, get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--words-file", - type=str, - required=True, - help="Path to words.txt", - ) - - parser.add_argument( - "--HLG", type=str, required=True, help="Path to HLG.pt." - ) - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Possible values are: - (1) 1best - Use the best path as decoding output. Only - the transformer encoder output is used for decoding. - We call it HLG decoding. - (2) whole-lattice-rescoring - Use an LM to rescore the - decoding lattice and then use 1best to decode the - rescored lattice. - We call it HLG decoding + n-gram LM rescoring. - """, - ) - - parser.add_argument( - "--G", - type=str, - help="""An LM for rescoring. - Used only when method is - whole-lattice-rescoring. - It's usually a 4-gram LM. - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.8, - help=""" - Used only when method is whole-lattice-rescoring. - It specifies the scale for n-gram LM scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "feature_dim": 80, - "subsampling_factor": 3, - "num_classes": 72, - "sample_rate": 16000, - "search_beam": 20, - "output_beam": 5, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0]) - return ans - - -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = TdnnLstm( - num_features=params.feature_dim, - num_classes=params.num_classes, - subsampling_factor=params.subsampling_factor, - ) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"]) - model.to(device) - model.eval() - - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method == "whole-lattice-rescoring": - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = G.to(device) - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G.lm_scores = G.scores.clone() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) - features = features.permute(0, 2, 1) # now features is (N, C, T) - - with torch.no_grad(): - nnet_output = model(features) - # nnet_output is (N, T, C) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/train.py b/egs/fisher_swbd/ASR/tdnn_lstm_ctc/train.py deleted file mode 100755 index 7439e157a..000000000 --- a/egs/fisher_swbd/ASR/tdnn_lstm_ctc/train.py +++ /dev/null @@ -1,603 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -from pathlib import Path -from shutil import copyfile -from typing import Optional, Tuple - -import k2 -import torch -import torch.multiprocessing as mp -import torch.nn as nn -import torch.optim as optim -from asr_datamodule import LibriSpeechAsrDataModule -from lhotse.utils import fix_random_seed -from model import TdnnLstm -from torch import Tensor -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.nn.utils import clip_grad_norm_ -from torch.optim.lr_scheduler import StepLR -from torch.utils.tensorboard import SummaryWriter - -from icefall.checkpoint import load_checkpoint -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.graph_compiler import CtcTrainingGraphCompiler -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - MetricsTracker, - encode_supervisions, - setup_logger, - str2bool, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=20, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - tdnn_lstm_ctc/exp/epoch-{start_epoch-1}.pt - """, - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - is saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - exp_dir: It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - - - lang_dir: It contains language related input files such as - "lexicon.txt" - - - lr: It specifies the initial learning rate - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - weight_decay: The weight_decay for the optimizer. - - - subsampling_factor: The subsampling factor for the model. - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval` is 0 - - - beam_size: It is used in k2.ctc_loss - - - reduction: It is used in k2.ctc_loss - - - use_double_scores: It is used in k2.ctc_loss - """ - params = AttributeDict( - { - "exp_dir": Path("tdnn_lstm_ctc/exp"), - "lang_dir": Path("data/lang_phone"), - "lr": 1e-3, - "feature_dim": 80, - "weight_decay": 5e-4, - "subsampling_factor": 3, - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 10, - "reset_interval": 200, - "valid_interval": 1000, - "beam_size": 10, - "reduction": "sum", - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - - return params - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, -) -> None: - """Load checkpoint from file. - - If params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. Otherwise, this function does nothing. - - Apart from loading state dict for `model`, `optimizer` and `scheduler`, - it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The learning rate scheduler we are using. - Returns: - Return None. - """ - if params.start_epoch <= 0: - return - - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: torch.optim.lr_scheduler._LRScheduler, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - batch: dict, - graph_compiler: CtcTrainingGraphCompiler, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute CTC loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of TdnnLstm in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - graph_compiler: - It is used to build a decoding graph from a ctc topo and training - transcript. The training transcript is contained in the given `batch`, - while the ctc topo is built when this compiler is instantiated. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = graph_compiler.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - feature = feature.permute(0, 2, 1) # now feature is (N, C, T) - assert feature.ndim == 3 - feature = feature.to(device) - - with torch.set_grad_enabled(is_training): - nnet_output = model(feature) - # nnet_output is (N, T, C) - - # NOTE: We need `encode_supervisions` to sort sequences with - # different duration in decreasing order, required by - # `k2.intersect_dense` called in `k2.ctc_loss` - supervisions = batch["supervisions"] - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - decoding_graph = graph_compiler.compile(texts) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=params.beam_size, - reduction=params.reduction, - use_double_scores=params.use_double_scores, - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - info["frames"] = supervision_segments[:, 2].sum().item() - info["loss"] = loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process. The validation loss - is saved in `params.valid_loss`. - """ - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=False, - ) - assert loss.requires_grad is False - - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - graph_compiler: CtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - graph_compiler: - It is used to convert transcripts to FSAs. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - loss, loss_info = compute_loss( - params=params, - model=model, - batch=batch, - graph_compiler=graph_compiler, - is_training=True, - ) - # summary stats. - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - optimizer.zero_grad() - loss.backward() - clip_grad_norm_(model.parameters(), 5.0, 2.0) - optimizer.step() - - if batch_idx % params.log_interval == 0: - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}" - ) - if batch_idx % params.log_interval == 0: - - if tb_writer is not None: - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary( - tb_writer, "train/tot_", params.batch_idx_train - ) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, - "train/valid_", - params.batch_idx_train, - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(42) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - logging.info(params) - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - lexicon = Lexicon(params.lang_dir) - max_phone_id = max(lexicon.tokens) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - - graph_compiler = CtcTrainingGraphCompiler(lexicon=lexicon, device=device) - - model = TdnnLstm( - num_features=params.feature_dim, - num_classes=max_phone_id + 1, # +1 for the blank symbol - subsampling_factor=params.subsampling_factor, - ) - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - model = DDP(model, device_ids=[rank]) - - optimizer = optim.AdamW( - model.parameters(), - lr=params.lr, - weight_decay=params.weight_decay, - ) - scheduler = StepLR(optimizer, step_size=8, gamma=0.1) - - if checkpoints: - optimizer.load_state_dict(checkpoints["optimizer"]) - scheduler.load_state_dict(checkpoints["scheduler"]) - - librispeech = LibriSpeechAsrDataModule(args) - - train_cuts = librispeech.train_clean_100_cuts() - if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() - train_dl = librispeech.train_dataloaders(train_cuts) - - valid_cuts = librispeech.dev_clean_cuts() - valid_cuts += librispeech.dev_other_cuts() - valid_dl = librispeech.valid_dataloaders(valid_cuts) - - for epoch in range(params.start_epoch, params.num_epochs): - train_dl.sampler.set_epoch(epoch) - - if epoch > params.start_epoch: - logging.info(f"epoch {epoch}, lr: {scheduler.get_last_lr()[0]}") - - if tb_writer is not None: - tb_writer.add_scalar( - "train/lr", - scheduler.get_last_lr()[0], - params.batch_idx_train, - ) - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - tb_writer=tb_writer, - world_size=world_size, - ) - - scheduler.step() - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - rank=rank, - ) - - logging.info("Done!") - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -if __name__ == "__main__": - main() From aad2b7940d16885ce707e97f839e1942f9524482 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 17 Jan 2022 23:02:52 +0000 Subject: [PATCH 08/12] Add basic descriptions and results --- egs/fisher_swbd/ASR/README.md | 4 +++ egs/fisher_swbd/ASR/RESULTS.md | 49 ++++++++++++++++++++++++++++++++ egs/fisher_swbd/ASR/prepare.sh | 52 ++++++++++++++++------------------ 3 files changed, 77 insertions(+), 28 deletions(-) create mode 100644 egs/fisher_swbd/ASR/README.md create mode 100644 egs/fisher_swbd/ASR/RESULTS.md diff --git a/egs/fisher_swbd/ASR/README.md b/egs/fisher_swbd/ASR/README.md new file mode 100644 index 000000000..3b7231a9e --- /dev/null +++ b/egs/fisher_swbd/ASR/README.md @@ -0,0 +1,4 @@ + +# Introduction + +This is an ASR recipe for Switchboard and Switchboard+Fisher corpora. \ No newline at end of file diff --git a/egs/fisher_swbd/ASR/RESULTS.md b/egs/fisher_swbd/ASR/RESULTS.md new file mode 100644 index 000000000..b4f1d0db2 --- /dev/null +++ b/egs/fisher_swbd/ASR/RESULTS.md @@ -0,0 +1,49 @@ +## Results + +### SWBD BPE training results (Conformer-CTC) + +#### 01-17-2022 + +This recipe is based on LibriSpeech. +Data preparation/normalization is a simplified version of the one found in Kaldi. +The data is resampled to 16kHz on-the-fly -- it's not needed, but makes it easier to combine with other corpora, +and likely doesn't affect the results too much. +The training set was only Switchboard, minus 20 held-out conversations (dev data, ~1h of speech). +This was tested only on the dev data. +We didn't tune the model, hparams, or language model in any special way vs. LibriSpeech recipe. +No rescoring was used (decoding method: "1best"). +The model was trained on a single A100 GPU (24GB RAM) for 2 days. + +WER (it includes `[LAUGHTER]`, `[NOISE]`, `[VOCALIZED-NOISE]` so the "real" WER is likely lower): + +10 epochs (avg 5) : 19.58% +20 epochs (avg 10): 12.61% +30 epochs (avg 20): 11.24% +35 epochs (avg 20): 10.96% +40 epochs (avg 20): 10.94% + +To reproduce the above result, use the following commands for training: + +``` +cd egs/librispeech/ASR/conformer_ctc +./prepare.sh --swbd-only true +export CUDA_VISIBLE_DEVICES="0" +./conformer_ctc/train.py \ + --lr-factor 1.25 \ + --max-duration 200 \ + --num-workers 14 \ + --lang-dir data/lang_bpe_500 \ + --num-epochs 40 +``` + +and the following command for decoding + +``` +python conformer_ctc/decode.py \ + --epoch 40 \ + --avg 20 \ + --method 1best +``` + +The tensorboard log for training is available at + diff --git a/egs/fisher_swbd/ASR/prepare.sh b/egs/fisher_swbd/ASR/prepare.sh index 9ef2d7363..4a23bff5c 100755 --- a/egs/fisher_swbd/ASR/prepare.sh +++ b/egs/fisher_swbd/ASR/prepare.sh @@ -5,25 +5,18 @@ set -eou pipefail nj=15 stage=-1 stop_stage=100 +swbd_only=false # We assume dl_dir (download dir) contains the following -# directories and files. If not, they will be downloaded -# by this script automatically. +# directories and files. Most of them can't be downloaded automatically +# as they are not publically available and require a license purchased +# from the LDC. # -# - $dl_dir/LibriSpeech -# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it. -# You can download them from https://www.openslr.org/12 +# - $dl_dir/{LDC2004S13,LDC2004T19,LDC2005S13,LDC2005T19} +# Fisher LDC packages. # -# - $dl_dir/lm -# This directory contains the following files downloaded from -# http://www.openslr.org/resources/11 -# -# - 3-gram.pruned.1e-7.arpa.gz -# - 3-gram.pruned.1e-7.arpa -# - 4-gram.arpa.gz -# - 4-gram.arpa -# - librispeech-vocab.txt -# - librispeech-lexicon.txt +# - $dl_dir/LDC97S62 +# Switchboard LDC audio package (transcripts are auto-downloaded) # # - $dl_dir/musan # This directory contains the following directories downloaded from @@ -81,18 +74,14 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then fi fi -if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] && ! $swbd_only; then log "Stage 1: Prepare Fisher manifests" - # We assume that you have downloaded the LibriSpeech corpus - # to $dl_dir/LibriSpeech mkdir -p data/manifests/fisher lhotse prepare fisher-english --absolute-paths 1 $dl_dir data/manifests/fisher fi if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Prepare SWBD manifests" - # We assume that you have downloaded the LibriSpeech corpus - # to $dl_dir/LibriSpeech mkdir -p data/manifests/swbd lhotse prepare switchboard --absolute-paths 1 --omit-silence $dl_dir/LDC97S62 data/manifests/swbd fi @@ -113,14 +102,21 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then set -x # Combine Fisher and SWBD recordings and supervisions - lhotse combine \ - data/manifests/fisher/recordings.jsonl.gz \ - data/manifests/swbd/swbd_recordings.jsonl \ - data/manifests/fisher-swbd_recordings.jsonl.gz - lhotse combine \ - data/manifests/fisher/supervisions.jsonl.gz \ - data/manifests/swbd/swbd_supervisions.jsonl \ - data/manifests/fisher-swbd_supervisions.jsonl.gz + if $swbd_only; then + cp data/manifests/swbd/swbd_recordings.jsonl \ + data/manifests/fisher-swbd_recordings.jsonl.gz + cp data/manifests/swbd/swbd_supervisions.jsonl \ + data/manifests/fisher-swbd_supervisions.jsonl.gz + else + lhotse combine \ + data/manifests/fisher/recordings.jsonl.gz \ + data/manifests/swbd/swbd_recordings.jsonl \ + data/manifests/fisher-swbd_recordings.jsonl.gz + lhotse combine \ + data/manifests/fisher/supervisions.jsonl.gz \ + data/manifests/swbd/swbd_supervisions.jsonl \ + data/manifests/fisher-swbd_supervisions.jsonl.gz + fi # Normalize text and remove supervisions that are not useful / hard to handle. python local/normalize_and_filter_supervisions.py \ From d394e88020c1b6eaf2dd999b7406417e866e760e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 17 Jan 2022 23:08:48 +0000 Subject: [PATCH 09/12] black --- .../ASR/conformer_ctc/asr_datamodule.py | 4 +-- egs/fisher_swbd/ASR/conformer_ctc/decode.py | 4 ++- .../normalize_and_filter_supervisions.py | 3 ++ .../ASR/local/prepare_lang_g2pen.py | 30 ++++++++++++++----- 4 files changed, 31 insertions(+), 10 deletions(-) diff --git a/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py b/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py index 7abe169d1..f9d25b7de 100644 --- a/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/fisher_swbd/ASR/conformer_ctc/asr_datamodule.py @@ -40,7 +40,7 @@ from icefall.utils import str2bool class Resample16kHz: def __call__(self, cuts: CutSet) -> CutSet: - return cuts.resample(16000).with_recording_path_prefix('download') + return cuts.resample(16000).with_recording_path_prefix("download") class AsrDataModule: @@ -282,5 +282,5 @@ def test(): break -if __name__ == '__main__': +if __name__ == "__main__": test() diff --git a/egs/fisher_swbd/ASR/conformer_ctc/decode.py b/egs/fisher_swbd/ASR/conformer_ctc/decode.py index 3185ff777..467ece3df 100755 --- a/egs/fisher_swbd/ASR/conformer_ctc/decode.py +++ b/egs/fisher_swbd/ASR/conformer_ctc/decode.py @@ -665,7 +665,9 @@ def main(): datamodule = AsrDataModule(args) fisher_swbd_dev_cuts = datamodule.dev_cuts() - fisher_swbd_dev_dataloader = datamodule.test_dataloaders(fisher_swbd_dev_cuts) + fisher_swbd_dev_dataloader = datamodule.test_dataloaders( + fisher_swbd_dev_cuts + ) test_sets = ["dev-fisher-swbd"] test_dl = [fisher_swbd_dev_dataloader] diff --git a/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py b/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py index 22f265c9d..cd65f1c86 100644 --- a/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py +++ b/egs/fisher_swbd/ASR/local/normalize_and_filter_supervisions.py @@ -17,6 +17,7 @@ def get_args(): return parser.parse_args() +# fmt: off class FisherSwbdNormalizer: """ Note: the functions "normalize" and "keep" implement the logic similar to @@ -118,6 +119,7 @@ class FisherSwbdNormalizer: text = self.whitespace_regexp.sub(" ", text).strip() return text +# fmt: on def keep(sup: SupervisionSegment) -> bool: @@ -181,6 +183,7 @@ def test(): print(normalizer.normalize(text)) print() + if __name__ == "__main__": # test() main() diff --git a/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py b/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py index c771dac4b..4768e1dc0 100755 --- a/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py +++ b/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py @@ -106,7 +106,6 @@ def get_g2p_sym2int(): return sym2int - def write_mapping(filename: str, sym2id: Dict[str, int]) -> None: """Write a symbol to ID mapping to a file. @@ -382,7 +381,17 @@ def main(): lexicon_filename = lang_dir / "lexicon.txt" sil_token = "SIL" sil_prob = 0.5 - special_symbols = ["[UNK]", "[BREATH]", "[COUGH]", "[LAUGHTER]", "[LIPSMACK]", "[NOISE]", "[SIGH]", "[SNEEZE]", "[VOCALIZED-NOISE]"] + special_symbols = [ + "[UNK]", + "[BREATH]", + "[COUGH]", + "[LAUGHTER]", + "[LIPSMACK]", + "[NOISE]", + "[SIGH]", + "[SNEEZE]", + "[VOCALIZED-NOISE]", + ] g2p = G2p() token2id = get_g2p_sym2int() @@ -407,8 +416,15 @@ def main(): ( word, [ - phn for phn in g2p(word) - if phn not in ("'", " ", "-", ",") # g2p_en has these symbols as phones + phn + for phn in g2p(word) + if phn + not in ( + "'", + " ", + "-", + ",", + ) # g2p_en has these symbols as phones ], ) for word in tqdm(vocab, desc="Processing vocab with G2P") @@ -437,9 +453,9 @@ def main(): token2id = dict(sorted(token2id.items(), key=lambda tpl: tpl[1])) print(token2id) word2id = {"": 0} - word2id.update({ - word: int(id_) for id_, (word, pron) in enumerate(lexicon, start=1) - }) + word2id.update( + {word: int(id_) for id_, (word, pron) in enumerate(lexicon, start=1)} + ) for symbol in ["", "", "#0"]: word2id[symbol] = len(word2id) From e76de3ba59df66b22fbdf93967c9cf4a2a16cd89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Mon, 17 Jan 2022 23:29:35 +0000 Subject: [PATCH 10/12] Fixes --- egs/fisher_swbd/ASR/local/prepare_lang_bpe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py b/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py index be92710d2..b3f012e84 100755 --- a/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py +++ b/egs/fisher_swbd/ASR/local/prepare_lang_bpe.py @@ -41,7 +41,7 @@ from typing import Dict, List, Tuple import k2 import sentencepiece as spm import torch -from prepare_lang import ( +from prepare_lang_g2pen import ( Lexicon, add_disambig_symbols, add_self_loops, From cb329d1342c67331b31042446db7d5e6ba0129b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 21 Jan 2022 19:20:57 +0000 Subject: [PATCH 11/12] fixes --- .../ASR/local/prepare_lang_g2pen.py | 19 ++++--------------- egs/fisher_swbd/ASR/prepare.sh | 8 ++++---- 2 files changed, 8 insertions(+), 19 deletions(-) diff --git a/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py b/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py index 4768e1dc0..0549d7306 100755 --- a/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py +++ b/egs/fisher_swbd/ASR/local/prepare_lang_g2pen.py @@ -17,21 +17,10 @@ """ -This script takes as input a lexicon file "data/lang_phone/lexicon.txt" -consisting of words and tokens (i.e., phones) and does the following: - -1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt - -2. Generate tokens.txt, the token table mapping a token to a unique integer. - -3. Generate words.txt, the word table mapping a word to a unique integer. - -4. Generate L.pt, in k2 format. It can be loaded by - - d = torch.load("L.pt") - lexicon = k2.Fsa.from_dict(d) - -5. Generate L_disambig.pt, in k2 format. +This script takes as input a wors.txt file "data/lang_phone/words.txt" +consisting of words and their IDs and creates a lexicon with g2p_en python package +(it's CMUdict based). It also creates rest of the files typically expected in a lang +dir, including L.pt and Linv.pt. """ import argparse import math diff --git a/egs/fisher_swbd/ASR/prepare.sh b/egs/fisher_swbd/ASR/prepare.sh index 4a23bff5c..0f2562507 100755 --- a/egs/fisher_swbd/ASR/prepare.sh +++ b/egs/fisher_swbd/ASR/prepare.sh @@ -103,10 +103,10 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then # Combine Fisher and SWBD recordings and supervisions if $swbd_only; then - cp data/manifests/swbd/swbd_recordings.jsonl \ - data/manifests/fisher-swbd_recordings.jsonl.gz - cp data/manifests/swbd/swbd_supervisions.jsonl \ - data/manifests/fisher-swbd_supervisions.jsonl.gz + gunzip -c data/manifests/swbd/swbd_recordings.jsonl \ + > data/manifests/fisher-swbd_recordings.jsonl.gz + gunzip -c data/manifests/swbd/swbd_supervisions.jsonl \ + > data/manifests/fisher-swbd_supervisions.jsonl.gz else lhotse combine \ data/manifests/fisher/recordings.jsonl.gz \ From ec8fa55bcf4894e63b8f8626d6f8a77aaceab82b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Piotr=20=C5=BBelasko?= Date: Fri, 21 Jan 2022 19:23:49 +0000 Subject: [PATCH 12/12] remove unneeded and unsupported files --- egs/fisher_swbd/ASR/conformer_ctc/README.md | 75 --- egs/fisher_swbd/ASR/conformer_ctc/ali.py | 399 ---------------- .../ASR/conformer_ctc/pretrained.py | 435 ------------------ .../ASR/local/generate_unique_lexicon.py | 100 ---- 4 files changed, 1009 deletions(-) delete mode 100644 egs/fisher_swbd/ASR/conformer_ctc/README.md delete mode 100755 egs/fisher_swbd/ASR/conformer_ctc/ali.py delete mode 100755 egs/fisher_swbd/ASR/conformer_ctc/pretrained.py delete mode 100755 egs/fisher_swbd/ASR/local/generate_unique_lexicon.py diff --git a/egs/fisher_swbd/ASR/conformer_ctc/README.md b/egs/fisher_swbd/ASR/conformer_ctc/README.md deleted file mode 100644 index 37ace4204..000000000 --- a/egs/fisher_swbd/ASR/conformer_ctc/README.md +++ /dev/null @@ -1,75 +0,0 @@ -## Introduction - -Please visit - -for how to run this recipe. - -## How to compute framewise alignment information - -### Step 1: Train a model - -Please use `conformer_ctc/train.py` to train a model. -See -for how to do it. - -### Step 2: Compute framewise alignment - -Run - -``` -# Choose a checkpoint and determine the number of checkpoints to average -epoch=30 -avg=15 -./conformer_ctc/ali.py \ - --epoch $epoch \ - --avg $avg \ - --max-duration 500 \ - --bucketing-sampler 0 \ - --full-libri 1 \ - --exp-dir conformer_ctc/exp \ - --lang-dir data/lang_bpe_500 \ - --ali-dir data/ali_500 -``` -and you will get four files inside the folder `data/ali_500`: - -``` -$ ls -lh data/ali_500 -total 546M --rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:06 test_clean.pt --rw-r--r-- 1 kuangfangjun root 1.1M Sep 28 08:07 test_other.pt --rw-r--r-- 1 kuangfangjun root 542M Sep 28 11:36 train-960.pt --rw-r--r-- 1 kuangfangjun root 2.1M Sep 28 11:38 valid.pt -``` - -**Note**: It can take more than 3 hours to compute the alignment -for the training dataset, which contains 960 * 3 = 2880 hours of data. - -**Caution**: The model parameters in `conformer_ctc/ali.py` have to match those -in `conformer_ctc/train.py`. - -**Caution**: You have to set the parameter `preserve_id` to `True` for `CutMix`. -Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`. - -### Step 3: Check your extracted alignments - -There is a file `test_ali.py` in `icefall/test` that can be used to test your -alignments. It uses pre-computed alignments to modify a randomly generated -`nnet_output` and it checks that we can decode the correct transcripts -from the resulting `nnet_output`. - -You should get something like the following if you run that script: - -``` -$ ./test/test_ali.py -['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO DWELL THAT THIS PASSAGE THAT LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"] -['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO GAMEWELL THAT THIS PASSAGE WAY LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"] -``` - -### Step 4: Use your alignments in training - -Please refer to `conformer_mmi/train.py` for usage. Some useful -functions are: - -- `load_alignments()`, it loads alignment saved by `conformer_ctc/ali.py` -- `convert_alignments_to_tensor()`, it converts alignments to PyTorch tensors -- `lookup_alignments()`, it returns the alignments of utterances by giving the cut ID of the utterances. diff --git a/egs/fisher_swbd/ASR/conformer_ctc/ali.py b/egs/fisher_swbd/ASR/conformer_ctc/ali.py deleted file mode 100755 index 42fa2308e..000000000 --- a/egs/fisher_swbd/ASR/conformer_ctc/ali.py +++ /dev/null @@ -1,399 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: - ./conformer_ctc/ali.py \ - --exp-dir ./conformer_ctc/exp \ - --lang-dir ./data/lang_bpe_500 \ - --epoch 20 \ - --avg 10 \ - --max-duration 300 \ - --dataset train-clean-100 \ - --out-dir data/ali -""" - -import argparse -import logging -from pathlib import Path - -import k2 -import numpy as np -import torch -from asr_datamodule import LibriSpeechAsrDataModule -from conformer import Conformer -from lhotse import CutSet -from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer - -from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import average_checkpoints, load_checkpoint -from icefall.decode import one_best_decoding -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - encode_supervisions, - get_alignments, - setup_logger, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=34, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - parser.add_argument( - "--avg", - type=int, - default=20, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_bpe_500", - help="The lang dir", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conformer_ctc/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--out-dir", - type=str, - required=True, - help="""Output directory. - It contains 3 generated files: - - - labels_xxx.h5 - - aux_labels_xxx.h5 - - cuts_xxx.json.gz - - where xxx is the value of `--dataset`. For instance, if - `--dataset` is `train-clean-100`, it will contain 3 files: - - - `labels_train-clean-100.h5` - - `aux_labels_train-clean-100.h5` - - `cuts_train-clean-100.json.gz` - - Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise - alignment. The difference is that labels_xxx.h5 contains repeats. - """, - ) - - parser.add_argument( - "--dataset", - type=str, - required=True, - help="""The name of the dataset to compute alignments for. - Possible values are: - - test-clean. - - test-other - - train-clean-100 - - train-clean-360 - - train-other-500 - - dev-clean - - dev-other - """, - ) - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "lm_dir": Path("data/lm"), - "feature_dim": 80, - "nhead": 8, - "attention_dim": 512, - "subsampling_factor": 4, - # Set it to 0 since attention decoder - # is not used for computing alignments - "num_decoder_layers": 0, - "vgg_frontend": False, - "use_feat_batchnorm": True, - "output_beam": 10, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - return params - - -def compute_alignments( - model: torch.nn.Module, - dl: torch.utils.data.DataLoader, - labels_writer: FeaturesWriter, - aux_labels_writer: FeaturesWriter, - params: AttributeDict, - graph_compiler: BpeCtcTrainingGraphCompiler, -) -> CutSet: - """Compute the framewise alignments of a dataset. - - Args: - model: - The neural network model. - dl: - Dataloader containing the dataset. - params: - Parameters for computing alignments. - graph_compiler: - It converts token IDs to decoding graphs. - Returns: - Return a CutSet. Each cut has two custom fields: labels_alignment - and aux_labels_alignment, containing framewise alignments information. - Both are of type `lhotse.array.TemporalArray`. The difference between - the two alignments is that `labels_alignment` contain repeats. - """ - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - num_cuts = 0 - - device = graph_compiler.device - cuts = [] - for batch_idx, batch in enumerate(dl): - feature = batch["inputs"] - - # at entry, feature is [N, T, C] - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - cut_list = supervisions["cut"] - - for cut in cut_list: - assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}" - - nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is [N, T, C] - supervision_segments, texts = encode_supervisions( - supervisions, subsampling_factor=params.subsampling_factor - ) - # we need also to sort cut_ids as encode_supervisions() - # reorders "texts". - # In general, new2old is an identity map since lhotse sorts the returned - # cuts by duration in descending order - new2old = supervision_segments[:, 0].tolist() - - cut_list = [cut_list[i] for i in new2old] - - token_ids = graph_compiler.texts_to_ids(texts) - decoding_graph = graph_compiler.compile(token_ids) - - dense_fsa_vec = k2.DenseFsaVec( - nnet_output, - supervision_segments, - allow_truncate=params.subsampling_factor - 1, - ) - - lattice = k2.intersect_dense( - decoding_graph, - dense_fsa_vec, - params.output_beam, - ) - - best_path = one_best_decoding( - lattice=lattice, - use_double_scores=params.use_double_scores, - ) - - labels_ali = get_alignments(best_path, kind="labels") - aux_labels_ali = get_alignments(best_path, kind="aux_labels") - assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip( - cut_list, labels_ali, aux_labels_ali - ): - cut.labels_alignment = labels_writer.store_array( - key=cut.id, - value=np.asarray(labels, dtype=np.int32), - # frame shift is 0.01s, subsampling_factor is 4 - frame_shift=0.04, - temporal_dim=0, - start=0, - ) - cut.aux_labels_alignment = aux_labels_writer.store_array( - key=cut.id, - value=np.asarray(aux_labels, dtype=np.int32), - # frame shift is 0.01s, subsampling_factor is 4 - frame_shift=0.04, - temporal_dim=0, - start=0, - ) - - cuts += cut_list - - num_cuts += len(cut_list) - - if batch_idx % 100 == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info( - f"batch {batch_str}, cuts processed until now is {num_cuts}" - ) - - return CutSet.from_cuts(cuts) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - - args.enable_spec_aug = False - args.enable_musan = False - args.return_cuts = True - args.concatenate_cuts = False - - params = get_params() - params.update(vars(args)) - - setup_logger(f"{params.exp_dir}/log-ali") - - logging.info(f"Computing alignments for {params.dataset} - started") - logging.info(params) - - out_dir = Path(params.out_dir) - out_dir.mkdir(exist_ok=True) - - out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" - out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz" - - for f in ( - out_labels_ali_filename, - out_aux_labels_ali_filename, - out_manifest_filename, - ): - if f.exists(): - logging.info(f"{f} exists - skipping") - return - - lexicon = Lexicon(params.lang_dir) - max_token_id = max(lexicon.tokens) - num_classes = max_token_id + 1 # +1 for the blank - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - logging.info(f"device: {device}") - - graph_compiler = BpeCtcTrainingGraphCompiler( - params.lang_dir, - device=device, - sos_token="", - eos_token="", - ) - - logging.info("About to create model") - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - model.to(device) - - if params.avg == 1: - load_checkpoint( - f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False - ) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - - model.eval() - - librispeech = LibriSpeechAsrDataModule(args) - if params.dataset == "test-clean": - test_clean_cuts = librispeech.test_clean_cuts() - dl = librispeech.test_dataloaders(test_clean_cuts) - elif params.dataset == "test-other": - test_other_cuts = librispeech.test_other_cuts() - dl = librispeech.test_dataloaders(test_other_cuts) - elif params.dataset == "train-clean-100": - train_clean_100_cuts = librispeech.train_clean_100_cuts() - dl = librispeech.train_dataloaders(train_clean_100_cuts) - elif params.dataset == "train-clean-360": - train_clean_360_cuts = librispeech.train_clean_360_cuts() - dl = librispeech.train_dataloaders(train_clean_360_cuts) - elif params.dataset == "train-other-500": - train_other_500_cuts = librispeech.train_other_500_cuts() - dl = librispeech.train_dataloaders(train_other_500_cuts) - elif params.dataset == "dev-clean": - dev_clean_cuts = librispeech.dev_clean_cuts() - dl = librispeech.valid_dataloaders(dev_clean_cuts) - else: - assert params.dataset == "dev-other", f"{params.dataset}" - dev_other_cuts = librispeech.dev_other_cuts() - dl = librispeech.valid_dataloaders(dev_other_cuts) - - logging.info(f"Processing {params.dataset}") - with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer: - with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer: - cut_set = compute_alignments( - model=model, - dl=dl, - labels_writer=labels_writer, - aux_labels_writer=aux_labels_writer, - params=params, - graph_compiler=graph_compiler, - ) - - cut_set.to_file(out_manifest_filename) - - logging.info( - f"For dataset {params.dataset}, its alignments with repeats are " - f"saved to {out_labels_ali_filename}, the alignments without repeats " - f"are saved to {out_aux_labels_ali_filename}, and the cut manifest " - f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}" - ) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/fisher_swbd/ASR/conformer_ctc/pretrained.py b/egs/fisher_swbd/ASR/conformer_ctc/pretrained.py deleted file mode 100755 index 28724e1eb..000000000 --- a/egs/fisher_swbd/ASR/conformer_ctc/pretrained.py +++ /dev/null @@ -1,435 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from conformer import Conformer -from torch.nn.utils.rnn import pad_sequence - -from icefall.decode import ( - get_lattice, - one_best_decoding, - rescore_with_attention_decoder, - rescore_with_whole_lattice, -) -from icefall.utils import AttributeDict, get_texts - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--words-file", - type=str, - help="""Path to words.txt. - Used only when method is not ctc-decoding. - """, - ) - - parser.add_argument( - "--HLG", - type=str, - help="""Path to HLG.pt. - Used only when method is not ctc-decoding. - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model. - Used only when method is ctc-decoding. - """, - ) - - parser.add_argument( - "--method", - type=str, - default="1best", - help="""Decoding method. - Possible values are: - (0) ctc-decoding - Use CTC decoding. It uses a sentence - piece model, i.e., lang_dir/bpe.model, to convert - word pieces to words. It needs neither a lexicon - nor an n-gram LM. - (1) 1best - Use the best path as decoding output. Only - the transformer encoder output is used for decoding. - We call it HLG decoding. - (2) whole-lattice-rescoring - Use an LM to rescore the - decoding lattice and then use 1best to decode the - rescored lattice. - We call it HLG decoding + n-gram LM rescoring. - (3) attention-decoder - Extract n paths from the rescored - lattice and use the transformer attention decoder for - rescoring. - We call it HLG decoding + n-gram LM rescoring + attention - decoder rescoring. - """, - ) - - parser.add_argument( - "--G", - type=str, - help="""An LM for rescoring. - Used only when method is - whole-lattice-rescoring or attention-decoder. - It's usually a 4-gram LM. - """, - ) - - parser.add_argument( - "--num-paths", - type=int, - default=100, - help=""" - Used only when method is attention-decoder. - It specifies the size of n-best list.""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=1.3, - help=""" - Used only when method is whole-lattice-rescoring and attention-decoder. - It specifies the scale for n-gram LM scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "--attention-decoder-scale", - type=float, - default=1.2, - help=""" - Used only when method is attention-decoder. - It specifies the scale for attention decoder scores. - (Note: You need to tune it on a dataset.) - """, - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help=""" - Used only when method is attention-decoder. - It specifies the scale for lattice.scores when - extracting n-best lists. A smaller value results in - more unique number of paths with the risk of missing - the best path. - """, - ) - - parser.add_argument( - "--sos-id", - type=int, - default=1, - help=""" - Used only when method is attention-decoder. - It specifies ID for the SOS token. - """, - ) - - parser.add_argument( - "--num-classes", - type=int, - default=500, - help=""" - Vocab size in the BPE model. - """, - ) - - parser.add_argument( - "--eos-id", - type=int, - default=1, - help=""" - Used only when method is attention-decoder. - It specifies ID for the EOS token. - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - return parser - - -def get_params() -> AttributeDict: - params = AttributeDict( - { - "sample_rate": 16000, - # parameters for conformer - "subsampling_factor": 4, - "vgg_frontend": False, - "use_feat_batchnorm": True, - "feature_dim": 80, - "nhead": 8, - "attention_dim": 512, - "num_decoder_layers": 6, - # parameters for decoding - "search_beam": 20, - "output_beam": 8, - "min_active_states": 30, - "max_active_states": 10000, - "use_double_scores": True, - } - ) - return params - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert sample_rate == expected_sample_rate, ( - f"expected sample rate: {expected_sample_rate}. " - f"Given: {sample_rate}" - ) - # We use only the first channel - ans.append(wave[0]) - return ans - - -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - if args.method != "attention-decoder": - # to save memory as the attention decoder - # will not be used - params.num_decoder_layers = 0 - - params.update(vars(args)) - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = Conformer( - num_features=params.feature_dim, - nhead=params.nhead, - d_model=params.attention_dim, - num_classes=params.num_classes, - subsampling_factor=params.subsampling_factor, - num_decoder_layers=params.num_decoder_layers, - vgg_frontend=params.vgg_frontend, - use_feat_batchnorm=params.use_feat_batchnorm, - ) - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - - features = pad_sequence( - features, batch_first=True, padding_value=math.log(1e-10) - ) - - # Note: We don't use key padding mask for attention during decoding - with torch.no_grad(): - nnet_output, memory, memory_key_padding_mask = model(features) - - batch_size = nnet_output.shape[0] - supervision_segments = torch.tensor( - [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], - dtype=torch.int32, - ) - - if params.method == "ctc-decoding": - logging.info("Use CTC decoding") - bpe_model = spm.SentencePieceProcessor() - bpe_model.load(params.bpe_model) - max_token_id = params.num_classes - 1 - - H = k2.ctc_topo( - max_token=max_token_id, - modified=False, - device=device, - ) - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=H, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - token_ids = get_texts(best_path) - hyps = bpe_model.decode(token_ids) - hyps = [s.split() for s in hyps] - elif params.method in [ - "1best", - "whole-lattice-rescoring", - "attention-decoder", - ]: - logging.info(f"Loading HLG from {params.HLG}") - HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) - HLG = HLG.to(device) - if not hasattr(HLG, "lm_scores"): - # For whole-lattice-rescoring and attention-decoder - HLG.lm_scores = HLG.scores.clone() - - if params.method in [ - "whole-lattice-rescoring", - "attention-decoder", - ]: - logging.info(f"Loading G from {params.G}") - G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) - # Add epsilon self-loops to G as we will compose - # it with the whole lattice later - G = G.to(device) - G = k2.add_epsilon_self_loops(G) - G = k2.arc_sort(G) - G.lm_scores = G.scores.clone() - - lattice = get_lattice( - nnet_output=nnet_output, - decoding_graph=HLG, - supervision_segments=supervision_segments, - search_beam=params.search_beam, - output_beam=params.output_beam, - min_active_states=params.min_active_states, - max_active_states=params.max_active_states, - subsampling_factor=params.subsampling_factor, - ) - - if params.method == "1best": - logging.info("Use HLG decoding") - best_path = one_best_decoding( - lattice=lattice, use_double_scores=params.use_double_scores - ) - elif params.method == "whole-lattice-rescoring": - logging.info("Use HLG decoding + LM rescoring") - best_path_dict = rescore_with_whole_lattice( - lattice=lattice, - G_with_epsilon_loops=G, - lm_scale_list=[params.ngram_lm_scale], - ) - best_path = next(iter(best_path_dict.values())) - elif params.method == "attention-decoder": - logging.info("Use HLG + LM rescoring + attention decoder rescoring") - rescored_lattice = rescore_with_whole_lattice( - lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None - ) - best_path_dict = rescore_with_attention_decoder( - lattice=rescored_lattice, - num_paths=params.num_paths, - model=model, - memory=memory, - memory_key_padding_mask=memory_key_padding_mask, - sos_id=params.sos_id, - eos_id=params.eos_id, - nbest_scale=params.nbest_scale, - ngram_lm_scale=params.ngram_lm_scale, - attention_scale=params.attention_decoder_scale, - ) - best_path = next(iter(best_path_dict.values())) - - hyps = get_texts(best_path) - word_sym_table = k2.SymbolTable.from_file(params.words_file) - hyps = [[word_sym_table[i] for i in ids] for ids in hyps] - else: - raise ValueError(f"Unsupported decoding method: {params.method}") - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/fisher_swbd/ASR/local/generate_unique_lexicon.py b/egs/fisher_swbd/ASR/local/generate_unique_lexicon.py deleted file mode 100755 index 566c0743d..000000000 --- a/egs/fisher_swbd/ASR/local/generate_unique_lexicon.py +++ /dev/null @@ -1,100 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file takes as input a lexicon.txt and output a new lexicon, -in which each word has a unique pronunciation. - -The way to do this is to keep only the first pronunciation of a word -in lexicon.txt. -""" - - -import argparse -import logging -from pathlib import Path -from typing import List, Tuple - -from icefall.lexicon import read_lexicon, write_lexicon - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--lang-dir", - type=str, - help="""Input and output directory. - It should contain a file lexicon.txt. - This file will generate a new file uniq_lexicon.txt - in it. - """, - ) - - return parser.parse_args() - - -def filter_multiple_pronunications( - lexicon: List[Tuple[str, List[str]]] -) -> List[Tuple[str, List[str]]]: - """Remove multiple pronunciations of words from a lexicon. - - If a word has more than one pronunciation in the lexicon, only - the first one is kept, while other pronunciations are removed - from the lexicon. - - Args: - lexicon: - The input lexicon, containing a list of (word, [p1, p2, ..., pn]), - where "p1, p2, ..., pn" are the pronunciations of the "word". - Returns: - Return a new lexicon where each word has a unique pronunciation. - """ - seen = set() - ans = [] - - for word, tokens in lexicon: - if word in seen: - continue - seen.add(word) - ans.append((word, tokens)) - return ans - - -def main(): - args = get_args() - lang_dir = Path(args.lang_dir) - - lexicon_filename = lang_dir / "lexicon.txt" - - in_lexicon = read_lexicon(lexicon_filename) - - out_lexicon = filter_multiple_pronunications(in_lexicon) - - write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon) - - logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}") - logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}") - - -if __name__ == "__main__": - formatter = ( - "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - ) - - logging.basicConfig(format=formatter, level=logging.INFO) - - main()