From b293db4baf1606cfe95066cf28ffde56173a7ddb Mon Sep 17 00:00:00 2001 From: Daniil Date: Tue, 13 Dec 2022 03:13:26 -0500 Subject: [PATCH] Tedlium3 conformer ctc2 (#696) * modify preparation * small refacor * add tedlium3 conformer_ctc2 * modify decode * filter unk in decode * add scaling converter * address comments * fix lambda function lhotse * add implicit manifest shuffle * refactor ctc_greedy_search * import model arguments from train.py * style fix * fix ci test and last style issues * update RESULTS * fix RESULTS numbers * fix label smoothing loss * update model parameters number in RESULTS --- .../ASR/conformer_ctc/label_smoothing.py | 3 +- .../ASR/conformer_ctc2/subsampling.py | 5 +- .../emformer2.py | 4 +- egs/librispeech/ASR/local/compile_hlg.py | 2 +- .../ASR/local/compute_fbank_musan.py | 8 +- egs/librispeech/ASR/local/prepare_lang_bpe.py | 23 +- .../pruned_transducer_stateless2/scaling.py | 20 +- .../scaling_converter.py | 2 +- egs/tedlium3/ASR/RESULTS.md | 83 ++ egs/tedlium3/ASR/conformer_ctc2/__init__.py | 0 .../ASR/conformer_ctc2/asr_datamodule.py | 1 + egs/tedlium3/ASR/conformer_ctc2/attention.py | 201 +++ egs/tedlium3/ASR/conformer_ctc2/combiner.py | 244 ++++ egs/tedlium3/ASR/conformer_ctc2/conformer.py | 1033 ++++++++++++++++ egs/tedlium3/ASR/conformer_ctc2/decode.py | 899 ++++++++++++++ egs/tedlium3/ASR/conformer_ctc2/export.py | 294 +++++ .../ASR/conformer_ctc2/label_smoothing.py | 1 + egs/tedlium3/ASR/conformer_ctc2/lstmp.py | 1 + egs/tedlium3/ASR/conformer_ctc2/optim.py | 1 + egs/tedlium3/ASR/conformer_ctc2/scaling.py | 1 + .../ASR/conformer_ctc2/scaling_converter.py | 1 + .../ASR/conformer_ctc2/subsampling.py | 1 + egs/tedlium3/ASR/conformer_ctc2/train.py | 1061 ++++++++++++++++ .../ASR/conformer_ctc2/transformer.py | 1093 +++++++++++++++++ .../convert_transcript_words_to_bpe_ids.py | 42 +- .../convert_transcript_words_to_tokens.py | 1 - .../ASR/local/generate_unique_lexicon.py | 1 - egs/tedlium3/ASR/local/prepare_lang.py | 1 - egs/tedlium3/ASR/local/prepare_lexicon.py | 94 -- egs/tedlium3/ASR/local/prepare_transcripts.py | 66 +- egs/tedlium3/ASR/local/prepare_words.py | 83 ++ egs/tedlium3/ASR/local/test_prepare_lang.py | 1 - egs/tedlium3/ASR/prepare.sh | 98 +- icefall/decode.py | 2 - test/test_lexicon.py | 2 +- 35 files changed, 5158 insertions(+), 215 deletions(-) create mode 100755 egs/tedlium3/ASR/conformer_ctc2/__init__.py create mode 120000 egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py create mode 100644 egs/tedlium3/ASR/conformer_ctc2/attention.py create mode 100644 egs/tedlium3/ASR/conformer_ctc2/combiner.py create mode 100644 egs/tedlium3/ASR/conformer_ctc2/conformer.py create mode 100755 egs/tedlium3/ASR/conformer_ctc2/decode.py create mode 100755 egs/tedlium3/ASR/conformer_ctc2/export.py create mode 120000 egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py create mode 120000 egs/tedlium3/ASR/conformer_ctc2/lstmp.py create mode 120000 egs/tedlium3/ASR/conformer_ctc2/optim.py create mode 120000 egs/tedlium3/ASR/conformer_ctc2/scaling.py create mode 120000 egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py create mode 120000 egs/tedlium3/ASR/conformer_ctc2/subsampling.py create mode 100755 egs/tedlium3/ASR/conformer_ctc2/train.py create mode 100644 egs/tedlium3/ASR/conformer_ctc2/transformer.py delete mode 120000 egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py delete mode 120000 egs/tedlium3/ASR/local/generate_unique_lexicon.py delete mode 120000 egs/tedlium3/ASR/local/prepare_lang.py delete mode 100755 
egs/tedlium3/ASR/local/prepare_lexicon.py create mode 100755 egs/tedlium3/ASR/local/prepare_words.py delete mode 120000 egs/tedlium3/ASR/local/test_prepare_lang.py diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index cb0d6e04d..52d2eda3b 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -44,7 +44,8 @@ class LabelSmoothingLoss(torch.nn.Module): mean of the output is taken. (3) "sum": the output will be summed. """ super().__init__() - assert 0.0 <= label_smoothing < 1.0 + assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}" + assert reduction in ("none", "sum", "mean"), reduction self.ignore_index = ignore_index self.label_smoothing = label_smoothing self.reduction = reduction diff --git a/egs/librispeech/ASR/conformer_ctc2/subsampling.py b/egs/librispeech/ASR/conformer_ctc2/subsampling.py index 3fcb4196f..85a4dc8df 100644 --- a/egs/librispeech/ASR/conformer_ctc2/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc2/subsampling.py @@ -24,10 +24,9 @@ from scaling import ( ScaledConv2d, ScaledLinear, ) -from torch import nn -class Conv2dSubsampling(nn.Module): +class Conv2dSubsampling(torch.nn.Module): """Convolutional 2D subsampling (to 1/4 length). Convert an input of shape (N, T, idim) to an output @@ -61,7 +60,7 @@ class Conv2dSubsampling(nn.Module): assert in_channels >= 7 super().__init__() - self.conv = nn.Sequential( + self.conv = torch.nn.Sequential( ScaledConv2d( in_channels=1, out_channels=layer1_channels, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py index 65a7efa77..188059044 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py @@ -1435,7 +1435,7 @@ class EmformerEncoder(nn.Module): self, x: torch.Tensor, states: List[torch.Tensor], - ) -> Tuple[torch.Tensor, List[torch.Tensor],]: + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: """Forward pass for streaming inference. B: batch size; @@ -1640,7 +1640,7 @@ class Emformer(EncoderInterface): self, x: torch.Tensor, states: List[torch.Tensor], - ) -> Tuple[torch.Tensor, List[torch.Tensor],]: + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: """Forward pass for streaming inference. 
B: batch size; diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index df6c609bb..08dac6a7b 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -24,7 +24,7 @@ This script takes as input lang_dir and generates HLG from Caution: We use a lexicon that contains disambiguation symbols - - G, the LM, built from data/lm/G_3_gram.fst.txt + - G, the LM, built from data/lm/G_n_gram.fst.txt The generated HLG is saved in $lang_dir/HLG.pt """ diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 4a4093ae4..62036467e 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -28,7 +28,7 @@ import os from pathlib import Path import torch -from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, combine +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter, MonoCut, combine from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor @@ -41,6 +41,10 @@ torch.set_num_threads(1) torch.set_num_interop_threads(1) +def is_cut_long(c: MonoCut) -> bool: + return c.duration > 5 + + def compute_fbank_musan(): src_dir = Path("data/manifests") output_dir = Path("data/fbank") @@ -86,7 +90,7 @@ def compute_fbank_musan(): recordings=combine(part["recordings"] for part in manifests.values()) ) .cut_into_windows(10.0) - .filter(lambda c: c.duration > 5) + .filter(is_cut_long) .compute_and_store_features( extractor=extractor, storage_path=f"{output_dir}/musan_feats", diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index e121aefa9..2a2d9c219 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -127,7 +127,7 @@ def lexicon_to_fst_no_sil( def generate_lexicon( - model_file: str, words: List[str] + model_file: str, words: List[str], oov: str ) -> Tuple[Lexicon, Dict[str, int]]: """Generate a lexicon from a BPE model. @@ -136,6 +136,8 @@ def generate_lexicon( Path to a sentencepiece model. words: A list of strings representing words. + oov: + The out of vocabulary word in lexicon. 
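The `is_cut_long` helper above replaces the inline `lambda c: c.duration > 5` (the "fix lambda function lhotse" item in the commit message), presumably because the filter predicate has to survive pickling when the cut set is handed to parallel feature-extraction workers; a module-level function does, a lambda does not. A minimal sketch of that difference (the helper name simply mirrors the one in the patch):

```
import pickle


def is_cut_long(c) -> bool:
    # A module-level function is pickled by reference to its importable name,
    # so it can be shipped to worker processes.
    return c.duration > 5


pickle.dumps(is_cut_long)  # works

try:
    pickle.dumps(lambda c: c.duration > 5)  # a lambda has no importable name
except (pickle.PicklingError, AttributeError) as exc:
    print(f"lambda is not picklable: {exc}")
```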
Returns: Return a tuple with two elements: - A dict whose keys are words and values are the corresponding @@ -156,12 +158,9 @@ def generate_lexicon( for word, pieces in zip(words, words_pieces): lexicon.append((word, pieces)) - # The OOV word is <UNK> - lexicon.append(("<UNK>", [sp.id_to_piece(sp.unk_id())])) + lexicon.append((oov, ["▁", sp.id_to_piece(sp.unk_id())])) - token2id: Dict[str, int] = dict() - for i in range(sp.vocab_size()): - token2id[sp.id_to_piece(i)] = i + token2id: Dict[str, int] = {sp.id_to_piece(i): i for i in range(sp.vocab_size())} return lexicon, token2id @@ -176,6 +175,13 @@ def get_args(): """, ) + parser.add_argument( + "--oov", + type=str, + default="<UNK>", + help="The out of vocabulary word in lexicon.", + ) + parser.add_argument( "--debug", type=str2bool, @@ -202,12 +208,13 @@ def main(): words = word_sym_table.symbols - excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", "<UNK>", "#0", "<s>", "</s>"] + excluded = ["<eps>", "!SIL", "<SPOKEN_NOISE>", args.oov, "#0", "<s>", "</s>"] + for w in excluded: if w in words: words.remove(w) - lexicon, token_sym_table = generate_lexicon(model_file, words) + lexicon, token_sym_table = generate_lexicon(model_file, words, args.oov) lexicon_disambig, max_disambig = add_disambig_symbols(lexicon) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index c802ecf89..963ebdc2d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -652,16 +652,16 @@ class ActivationBalancer(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: if random.random() >= self.balance_prob: return x - else: - return ActivationBalancerFunction.apply( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - self.max_factor / self.balance_prob, - self.min_abs, - self.max_abs, - ) + + return ActivationBalancerFunction.apply( + x, + self.channel_dim, + self.min_positive, + self.max_positive, + self.max_factor / self.balance_prob, + self.min_abs, + self.max_abs, + ) class DoubleSwishFunction(torch.autograd.Function): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index b712eeda0..a6540c584 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -282,7 +282,7 @@ def convert_scaled_to_non_scaled( if not inplace: model = copy.deepcopy(model) - excluded_patterns = r"self_attn\.(in|out)_proj" + excluded_patterns = r"(self|src)_attn\.(in|out)_proj" p = re.compile(excluded_patterns) d = {} diff --git a/egs/tedlium3/ASR/RESULTS.md b/egs/tedlium3/ASR/RESULTS.md index 511b19f73..38eaa8f44 100644 --- a/egs/tedlium3/ASR/RESULTS.md +++ b/egs/tedlium3/ASR/RESULTS.md @@ -1,5 +1,88 @@ ## Results +### TedLium3 BPE training results (Conformer-CTC 2) + +#### [conformer_ctc2](./conformer_ctc2) + +See <https://github.com/k2-fsa/icefall/pull/696> for more details.
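The widened `excluded_patterns` in the scaling_converter.py hunk above now skips the projection layers of both self-attention and source (cross) attention when converting `Scaled*` modules, matching the new `attention.py` below, which reads its projections via `get_weight()`/`get_bias()`. A quick check of what the new pattern matches (the module names here are made up purely for illustration):

```
import re

pattern = re.compile(r"(self|src)_attn\.(in|out)_proj")

for name in (
    "encoder.encoder.layers.0.self_attn.in_proj",  # excluded before and after
    "decoder.layers.0.src_attn.out_proj",          # newly excluded
    "encoder.encoder.layers.0.feed_forward.0",     # still converted
):
    print(name, "->", "excluded" if pattern.search(name) else "converted")
```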
+ +The tensorboard log can be found at + + +You can find a pretrained model and decoding results at: + + +Number of model parameters: 101141699, i.e., 101.14 M + +The WERs are + +| | dev | test | comment | +|--------------------------|------------|-------------|---------------------| +| ctc decoding | 6.45 | 5.96 | --epoch 38 --avg 26 | +| 1best | 5.92 | 5.51 | --epoch 38 --avg 26 | +| whole lattice rescoring | 5.96 | 5.47 | --epoch 38 --avg 26 | +| attention decoder | 5.60 | 5.33 | --epoch 38 --avg 26 | + +The training command for reproducing is given below: + +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./conformer_ctc2/train.py \ + --world-size 4 \ + --num-epochs 40 \ + --exp-dir conformer_ctc2/exp \ + --max-duration 350 \ + --use-fp16 true +``` + +The decoding command is: +``` +epoch=38 +avg=26 + +## ctc decoding +./conformer_ctc2/decode.py \ + --method ctc-decoding \ + --exp-dir conformer_ctc2/exp \ + --lang-dir data/lang_bpe_500 \ + --result-dir conformer_ctc2/exp \ + --max-duration 500 \ + --epoch $epoch \ + --avg $avg + +## 1best +./conformer_ctc2/decode.py \ + --method 1best \ + --exp-dir conformer_ctc2/exp \ + --lang-dir data/lang_bpe_500 \ + --result-dir conformer_ctc2/exp \ + --max-duration 500 \ + --epoch $epoch \ + --avg $avg + +## whole lattice rescoring +./conformer_ctc2/decode.py \ + --method whole-lattice-rescoring \ + --exp-dir conformer_ctc2/exp \ + --lm-path data/lm/G_4_gram_big.pt \ + --lang-dir data/lang_bpe_500 \ + --result-dir conformer_ctc2/exp \ + --max-duration 500 \ + --epoch $epoch \ + --avg $avg + +## attention decoder +./conformer_ctc2/decode.py \ + --method attention-decoder \ + --exp-dir conformer_ctc2/exp \ + --lang-dir data/lang_bpe_500 \ + --result-dir conformer_ctc2/exp \ + --max-duration 500 \ + --epoch $epoch \ + --avg $avg +``` + ### TedLium3 BPE training results (Pruned Transducer) #### 2022-03-21 diff --git a/egs/tedlium3/ASR/conformer_ctc2/__init__.py b/egs/tedlium3/ASR/conformer_ctc2/__init__.py new file mode 100755 index 000000000..e69de29bb diff --git a/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py b/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py new file mode 120000 index 000000000..49b2ee483 --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/asr_datamodule.py @@ -0,0 +1 @@ +../transducer_stateless/asr_datamodule.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/attention.py b/egs/tedlium3/ASR/conformer_ctc2/attention.py new file mode 100644 index 000000000..178cd7e62 --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/attention.py @@ -0,0 +1,201 @@ +# Copyright 2022 Behavox LLC. (author: Daniil Kulko) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Tuple, Union + +import torch +from scaling import ScaledLinear + + +class MultiheadAttention(torch.nn.Module): + """Allows the model to jointly attend to information + from different representation subspaces. 
This is a modified + version of the original version of multihead attention + (see Attention Is All You Need ) + with replacement of input / output projection layers + with newly introduced ScaleLinear layer + (see https://github.com/k2-fsa/icefall/blob/master/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py). + + Args: + embed_dim: + total dimension of the model. + num_heads: + number of parallel attention heads. Note that embed_dim will be split + across num_heads, i.e. each head will have dimension (embed_dim // num_heads). + dropout: + dropout probability on attn_output_weights. (default=0.0). + bias: + if specified, adds bias to input / output projection layers (default=True). + add_bias_kv: + if specified, adds bias to the key and value sequences at dim=0 (default=False). + add_zero_attn: + if specified, adds a new batch of zeros to the key and value sequences + at dim=1 (default=False). + batch_first: + if True, then the input and output tensors are provided as + (batch, seq, feature), otherwise (seq, batch, feature) (default=False). + + Examples:: + >>> multihead_attn = MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + bias: bool = True, + add_bias_kv: bool = False, + add_zero_attn: bool = False, + batch_first: bool = False, + device: Union[torch.device, str, None] = None, + dtype: Union[torch.dtype, str, None] = None, + ) -> None: + + super().__init__() + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + + if embed_dim % num_heads != 0: + raise ValueError( + f"embed_dim must be divisible by num_heads. " + "Got embedding dim vs number 0f heads: " + f"{embed_dim} vs {num_heads}" + ) + + self.head_dim = embed_dim // num_heads + + self.in_proj = ScaledLinear( + embed_dim, + 3 * embed_dim, + bias=bias, + device=device, + dtype=dtype, + ) + self.out_proj = ScaledLinear( + embed_dim, + embed_dim, + bias=bias, + initial_scale=0.25, + device=device, + dtype=dtype, + ) + + if add_bias_kv: + self.bias_k = torch.nn.Parameter( + torch.empty((1, 1, embed_dim), device=device, dtype=dtype) + ) + self.bias_v = torch.nn.Parameter( + torch.empty((1, 1, embed_dim), device=device, dtype=dtype) + ) + else: + self.register_parameter("bias_k", None) + self.register_parameter("bias_v", None) + + self.add_zero_attn = add_zero_attn + + self._reset_parameters() + + def _reset_parameters(self) -> None: + if self.bias_k is not None: + torch.nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + torch.nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + query: + Query embeddings of shape (L, N, E_q) when batch_first=False or (N, L, E_q) + when batch_first=True, where L is the target sequence length, N is the batch size, + and E_q is the query embedding dimension embed_dim. Queries are compared against + key-value pairs to produce the output. See "Attention Is All You Need" for more details. + key: + Key embeddings of shape (S, N, E_k) when batch_first=False or (N, S, E_k) when + batch_first=True, where S is the source sequence length, N is the batch size, and + E_k is the key embedding dimension kdim. 
See "Attention Is All You Need" for more details. + value: + Value embeddings of shape (S, N, E_v) when batch_first=False or (N, S, E_v) when + batch_first=True, where S is the source sequence length, N is the batch size, and + E_v is the value embedding dimension vdim. See "Attention Is All You Need" for more details. + key_padding_mask: + If specified, a mask of shape (N, S) indicating which elements within key + to ignore for the purpose of attention (i.e. treat as "padding"). + Binary and byte masks are supported. For a binary mask, a True value indicates + that the corresponding key value will be ignored for the purpose of attention. + For a byte mask, a non-zero value indicates that the corresponding key value will be ignored. + need_weights: + If specifid, returns attn_output_weights in addition to attn_outputs (default=True). + attn_mask: + If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + (L, S) or (N * num_heads, L, S), where N is the batch size, L is the target sequence length, + and S is the source sequence length. A 2D mask will be broadcasted across the batch while + a 3D mask allows for a different mask for each entry in the batch. + Binary, byte, and float masks are supported. For a binary mask, a True value indicates + that the corresponding position is not allowed to attend. For a byte mask, a non-zero + value indicates that the corresponding position is not allowed to attend. For a float mask, + the mask values will be added to the attention weight. + + Returns: + attn_output: + Attention outputs of shape (L, N, E) when batch_first=False or (N, L, E) when batch_first=True, + where L is the target sequence length, N is the batch size, and E is the embedding dimension + embed_dim. + attn_output_weights: + Attention output weights of shape (N, L, S), where N is the batch size, L is the target sequence + length, and S is the source sequence length. Only returned when need_weights=True. + """ + if self.batch_first: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + ( + attn_output, + attn_output_weights, + ) = torch.nn.functional.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + in_proj_weight=self.in_proj.get_weight(), + in_proj_bias=self.in_proj.get_bias(), + bias_k=self.bias_k, + bias_v=self.bias_v, + add_zero_attn=self.add_zero_attn, + dropout_p=self.dropout, + out_proj_weight=self.out_proj.get_weight(), + out_proj_bias=self.out_proj.get_bias(), + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + if self.batch_first: + return attn_output.transpose(1, 0), attn_output_weights + return attn_output, attn_output_weights diff --git a/egs/tedlium3/ASR/conformer_ctc2/combiner.py b/egs/tedlium3/ASR/conformer_ctc2/combiner.py new file mode 100644 index 000000000..ff526029d --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/combiner.py @@ -0,0 +1,244 @@ +# Copyright 2022 Behavox LLC. (author: Daniil Kulko) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch + + +class RandomCombine(torch.nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + + def __init__( + self, + num_inputs: int, + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0, + ) -> None: + """ + Args: + num_inputs: + The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + final_weight: + The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: + The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: + A standard deviation that we add to log-probs for computing + randomized weights. + The method of choosing which layers, or combinations of layers, to use, + is conceptually as follows:: + With probability `pure_prob`:: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. + Else:: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with variance `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). + """ + super().__init__() + assert 0 <= pure_prob <= 1, pure_prob + assert 0 < final_weight < 1, final_weight + assert num_inputs >= 1, num_inputs + + self.num_inputs = num_inputs + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev = stddev + + self.final_log_weight = ( + torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) + .log() + .item() + ) + + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: + """Forward function. + Args: + inputs: + A list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + A Tensor of shape (*, num_channels). In test mode + this is just the final input. 
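A small usage sketch of the `RandomCombine` module introduced above, assuming the `conformer_ctc2` directory is the working directory so `combiner.py` is importable; it just checks the behaviour the docstring describes (a random per-frame combination in training, exactly the last input in eval):

```
import torch

from combiner import RandomCombine  # assumes conformer_ctc2/ is on the path

m = RandomCombine(num_inputs=3, final_weight=0.5, pure_prob=0.333, stddev=2.0)
xs = [torch.full((2, 5, 8), float(i)) for i in range(3)]  # constant layers 0, 1, 2

m.train()
y_train = m(xs)  # per-frame random combination of the three inputs
m.eval()
y_eval = m(xs)   # exactly the final input

print(torch.equal(y_eval, xs[-1]))                 # True
print(float(y_train.min()), float(y_train.max()))  # stays within [0, 2]
```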
+ """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs, f"{len(inputs)}, {num_inputs}" + if not self.training or torch.jit.is_scripting() or len(inputs) == 1: + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape( + (num_frames, num_channels, num_inputs) + ) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights( + inputs[0].dtype, inputs[0].device, num_frames + ) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + + ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) + + return ans + + def _get_random_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ) -> torch.Tensor: + """Return a tensor of random weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired + Returns: + A tensor of shape (num_frames, self.num_inputs), such that + `ans.sum(dim=1)` is all ones. + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where( + torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m + ) + + def _get_random_pure_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ) -> torch.Tensor: + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A one-hot tensor of shape `(num_frames, self.num_inputs)`, with + exactly one weight equal to 1.0 on each frame. + """ + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. + nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + + indexes = torch.where( + torch.rand(num_frames, device=device) < final_prob, final, nonfinal + ) + ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( + dtype=dtype + ) + return ans + + def _get_random_mixed_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ) -> torch.Tensor: + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A tensor of shape (num_frames, self.num_inputs), which elements + in [0..1] that sum to one over the second axis, i.e. + `ans.sum(dim=1)` is all ones. 
+ """ + logprobs = ( + torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) + * self.stddev + ) + logprobs[:, -1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine( + final_weight: float, + pure_prob: float, + stddev: float, +) -> None: + print( + f"_test_random_combine: final_weight={final_weight}, " + f"pure_prob={pure_prob}, stddev={stddev}" + ) + num_inputs = 3 + num_channels = 50 + m = RandomCombine( + num_inputs=num_inputs, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev, + ) + + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] + + y = m(x) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + +def _test_random_combine_main() -> None: + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + + +if __name__ == "__main__": + _test_random_combine_main() diff --git a/egs/tedlium3/ASR/conformer_ctc2/conformer.py b/egs/tedlium3/ASR/conformer_ctc2/conformer.py new file mode 100644 index 000000000..fad2f371f --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/conformer.py @@ -0,0 +1,1033 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# 2022 Xiaomi Corp. (author: Quandong Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import warnings +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from combiner import RandomCombine +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv1d, + ScaledLinear, +) +from subsampling import Conv2dSubsampling +from transformer import Supervisions, Transformer, encoder_padding_mask + + +class Conformer(Transformer): + def __init__( + self, + num_features: int, + num_classes: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + num_decoder_layers: int = 6, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + aux_layer_period: int = 3, + ) -> None: + """ + Args: + num_features (int): + number of input features. + num_classes (int): + number of output classes. + subsampling_factor (int): + subsampling factor of encoder; + currently, subsampling_factor MUST be 4. + d_model (int): + attention dimension, also the output dimension. + nhead (int): + number of heads in multi-head attention; + must satisfy d_model // nhead == 0. + dim_feedforward (int): + feedforward dimention. + num_encoder_layers (int): + number of encoder layers. + num_decoder_layers (int): + number of decoder layers. + dropout (float): + dropout rate. + layer_dropout (float): + layer-dropout rate. + cnn_module_kernel (int): + kernel size of convolution module. 
+ aux_layer_period (int): + determines the auxiliary encoder layers. + """ + + super().__init__( + num_features=num_features, + num_classes=num_classes, + subsampling_factor=subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + num_decoder_layers=num_decoder_layers, + dropout=dropout, + layer_dropout=layer_dropout, + ) + + self.num_features = num_features + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = RelPositionalEncoding(d_model, dropout) + + encoder_layer = ConformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + layer_dropout=layer_dropout, + cnn_module_kernel=cnn_module_kernel, + ) + + # aux_layers from 1/3 + self.encoder = ConformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + aux_layers=list( + range( + num_encoder_layers // 3, + num_encoder_layers - 1, + aux_layer_period, + ) + ), + ) + + def run_encoder( + self, + x: torch.Tensor, + supervisions: Optional[Supervisions] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + the input tensor. Its shape is (batch_size, seq_len, feature_dim). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. + warmup: + a floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + torch.Tensor: Predictor tensor of dimension (S, N, C). + torch.Tensor: Mask tensor of dimension (N, S) + """ + x = self.encoder_embed(x) + x, pos_emb = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, S, C) -> (S, N, C) + mask = encoder_padding_mask(x.size(0), supervisions) + mask = mask.to(x.device) if mask is not None else None + + x = self.encoder( + x, pos_emb, src_key_padding_mask=mask, warmup=warmup + ) # (S, N, C) + + return x, mask + + +class ConformerEncoderLayer(nn.Module): + """ + ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. + See: "Conformer: Convolution-augmented Transformer for Speech Recognition" + + Examples: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + bypass_scale: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + ) -> None: + """ + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). 
+ dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + bypass_scale: + a scale on the layer's output, used in bypass (resnet-type) skip-connection; + when the layer is bypassed the final output will be a + weighted sum of the layer's input and layer's output with weights + (1.0-bypass_scale) and bypass_scale correspondingly (default=0.1). + layer_dropout: + the probability to bypass the layer (default=0.075). + cnn_module_kernel (int): + kernel size of convolution module (default=31). + """ + super().__init__() + + if bypass_scale < 0.0 or bypass_scale > 1.0: + raise ValueError("bypass_scale should be between 0.0 and 1.0") + + if layer_dropout < 0.0 or layer_dropout > 1.0: + raise ValueError("layer_dropout should be between 0.0 and 1.0") + + self.bypass_scale = bypass_scale + self.layer_dropout = layer_dropout + + self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.feed_forward_macaron = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: torch.Tensor, + pos_emb: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: + the sequence to the encoder layer of shape (S, N, C) (required). + pos_emb: + positional embedding tensor of shape (N, 2*S-1, C) (required). + src_mask: + the mask for the src sequence of shape (S, S) (optional). + src_key_padding_mask: + the mask for the src keys per batch of shape (N, S) (optional). + warmup: + controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + + Returns: + Output tensor of the shape (S, N, C), where + S is the source sequence length, + N is the batch size, + C is the feature number + """ + src_orig = src + + warmup_scale = min(self.bypass_scale + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. 
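The `bypass_scale` / `layer_dropout` / `warmup` interaction described above reduces to a per-call blending factor `alpha`; a standalone sketch of just that arithmetic, not tied to the module itself:

```
import random


def effective_alpha(
    warmup: float,
    bypass_scale: float = 0.1,
    layer_dropout: float = 0.075,
    training: bool = True,
) -> float:
    # alpha = 1.0 uses the layer fully; alpha = bypass_scale mostly bypasses it.
    if not training:
        return 1.0
    warmup_scale = min(bypass_scale + warmup, 1.0)
    return warmup_scale if random.random() <= (1.0 - layer_dropout) else bypass_scale


for w in (0.0, 0.5, 1.0):
    print(w, effective_alpha(w))
# The layer output is then alpha * layer(src) + (1 - alpha) * src.
```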
+ if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else self.bypass_scale + ) + else: + alpha = 1.0 + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # multi-headed self-attention module + src_att = self.self_attn( + src, + src, + src, + pos_emb=pos_emb, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + + src = src + self.dropout(src_att) + + # convolution module + src = src + self.dropout(self.conv_module(src)) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src + + +class ConformerEncoder(nn.Module): + """ + ConformerEncoder is a stack of N encoder layers + + Examples: + >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) + >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = conformer_encoder(src, pos_emb) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + aux_layers: List[int], + ) -> None: + + """ + Args: + encoder_layer: + an instance of the ConformerEncoderLayer() class (required). + num_layers: + the number of sub-encoder-layers in the encoder (required). + aux_layers: + list of indexes of sub-encoder-layers outputs to be combined (required). + """ + + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert len(set(aux_layers)) == len(aux_layers) + + assert num_layers - 1 not in aux_layers + self.aux_layers = aux_layers + [num_layers - 1] + + self.combiner = RandomCombine( + num_inputs=len(self.aux_layers), + final_weight=0.5, + pure_prob=0.333, + stddev=2.0, + ) + + def forward( + self, + src: torch.Tensor, + pos_emb: torch.Tensor, + mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> torch.Tensor: + """ + Pass the input through the encoder layers in turn. + + Args: + src: + the sequence to the encoder of shape (S, N, C) (required). + pos_emb: + positional embedding tensor of shape (N, 2*S-1, C) (required). + mask: + the mask for the src sequence of shape (S, S) (optional). + src_key_padding_mask: + the mask for the src keys per batch of shape (N, S) (optional). + warmup: + controls selective bypass of layer; if < 1.0, we will + bypass the layer more frequently (default=1.0). + + Returns: + Output tensor of the shape (S, N, C), where + S is the source sequence length, + N is the batch size, + C is the feature number. + + """ + output = src + + outputs = [] + for i, mod in enumerate(self.layers): + output = mod( + output, + pos_emb, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) + + if i in self.aux_layers: + outputs.append(output) + + output = self.combiner(outputs) + + return output + + +class RelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. + + See: Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py + + """ + + def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + """ + Construct an PositionalEncoding object. 
+ + Args: + d_model: Embedding dimension. + dropout_rate: Dropout rate. + max_len: Maximum input length. + + """ + super().__init__() + self.d_model = d_model + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x: torch.Tensor) -> None: + """ + Reset the positional encodings. + + Args: + x: + input tensor (N, T, C), where + T is the source sequence length, + N is the batch size. + C is the feature number. + + """ + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` means to the position of query vecotr and `j` means the + # position of key vector. We use position relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i Tuple[torch.Tensor, torch.Tensor]: + """ + Add positional encoding. + + Args: + x: + input tensor (N, T, C). + + Returns: + torch.Tensor: Encoded tensor (N, T, C). + torch.Tensor: Encoded tensor (N, 2*T-1, C), where + T is the source sequence length, + N is the batch size. + C is the feature number. + + """ + self.extend_pe(x) + pos_emb = self.pe[ + :, + self.pe.size(1) // 2 + - x.size(1) + + 1 : self.pe.size(1) // 2 # noqa E203 + + x.size(1), + ] + return self.dropout(x), self.dropout(pos_emb) + + +class RelPositionMultiheadAttention(nn.Module): + """ + Multi-Head Attention layer with relative position encoding + See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context". + + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + ) -> None: + """ + Args: + embed_dim: + total dimension of the model. + num_heads: + parallel attention heads. + dropout: + a Dropout layer on attn_output_weights. Default: 0.0. + """ + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = ScaledLinear( + embed_dim, embed_dim, bias=True, initial_scale=0.25 + ) + + # linear transformation for positional encoding. 
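The relative positional encoding above keeps encodings for positive offsets (key to the left of the query, i>j) and negative offsets (i<j) in a single table of length 2*T-1. A minimal standalone sketch of the standard ESPnet-style construction of that table, not the verbatim lines of this file:

```
import math

import torch


def build_rel_pos_table(seq_len: int, d_model: int) -> torch.Tensor:
    """Return a (1, 2*seq_len - 1, d_model) table: positive offsets first
    (reversed so the largest offset comes first), then negative offsets."""
    pe_positive = torch.zeros(seq_len, d_model)
    pe_negative = torch.zeros(seq_len, d_model)
    position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32)
        * -(math.log(10000.0) / d_model)
    )
    pe_positive[:, 0::2] = torch.sin(position * div_term)
    pe_positive[:, 1::2] = torch.cos(position * div_term)
    pe_negative[:, 0::2] = torch.sin(-position * div_term)
    pe_negative[:, 1::2] = torch.cos(-position * div_term)
    pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
    pe_negative = pe_negative[1:].unsqueeze(0)
    return torch.cat([pe_positive, pe_negative], dim=1)


print(build_rel_pos_table(seq_len=6, d_model=8).shape)  # torch.Size([1, 11, 8])
```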
+ self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) + self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) + self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) + self._reset_parameters() + + def _pos_bias_u(self): + return self.pos_bias_u * self.pos_bias_u_scale.exp() + + def _pos_bias_v(self): + return self.pos_bias_v * self.pos_bias_v_scale.exp() + + def _reset_parameters(self) -> None: + nn.init.normal_(self.pos_bias_u, std=0.01) + nn.init.normal_(self.pos_bias_v, std=0.01) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + pos_emb: torch.Tensor, + key_padding_mask: Optional[torch.Tensor] = None, + need_weights: bool = False, + attn_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. When given a binary mask + and a value is True, the corresponding value on the attention + layer will be ignored. When given a byte mask and a value is + non-zero, the corresponding value on the attention layer will be ignored. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. + A 2D mask will be broadcasted for all the batches while a 3D + mask allows to specify a different mask for the entries of each batch. + + Shape: + - Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the position + with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + - Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. 
+ - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + return self.multi_head_attention_forward( + query, + key, + value, + pos_emb, + self.embed_dim, + self.num_heads, + self.in_proj.get_weight(), + self.in_proj.get_bias(), + self.dropout, + self.out_proj.get_weight(), + self.out_proj.get_bias(), + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + ) + + def rel_shift(self, x: torch.Tensor) -> torch.Tensor: + """ + Compute relative positional encoding. + + Args: + x: + input tensor (batch, head, time1, 2*time1-1). + time1 means the length of query vector. + + Returns: + torch.Tensor: tensor of shape (batch, head, time1, time2) + (note: time2 has the same value as time1, but it is for + the key, while time1 is for the query). + """ + (batch_size, num_heads, time1, n) = x.shape + assert n == 2 * time1 - 1 + # Note: TorchScript requires explicit arg for stride() + batch_stride = x.stride(0) + head_stride = x.stride(1) + time1_stride = x.stride(2) + n_stride = x.stride(3) + return x.as_strided( + (batch_size, num_heads, time1, time1), + (batch_stride, head_stride, time1_stride - n_stride, n_stride), + storage_offset=n_stride * (time1 - 1), + ) + + def multi_head_attention_forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + pos_emb: torch.Tensor, + embed_dim_to_check: int, + num_heads: int, + in_proj_weight: torch.Tensor, + in_proj_bias: torch.Tensor, + dropout_p: float, + out_proj_weight: torch.Tensor, + out_proj_bias: torch.Tensor, + training: bool = True, + key_padding_mask: Optional[torch.Tensor] = None, + need_weights: bool = False, + attn_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + query, key, value: map a query and a set of key-value pairs to an output. + pos_emb: Positional embedding tensor + embed_dim_to_check: total dimension of the model. + num_heads: parallel attention heads. + in_proj_weight, in_proj_bias: input projection weight and bias. + dropout_p: probability of an element to be zeroed. + out_proj_weight, out_proj_bias: the output projection weight and bias. + training: apply dropout if is ``True``. + key_padding_mask: if provided, specified padding elements in the key will + be ignored by the attention. This is an binary mask. + When the value is True, the corresponding value on the + attention layer will be filled with -inf. + need_weights: output attn_output_weights. + attn_mask: 2D or 3D mask that prevents attention to certain positions. + A 2D mask will be broadcasted for all the batches while a 3D + mask allows to specify a different mask for the entries of each batch. + + Shape: + Inputs: + - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is + the embedding dimension. + - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is + the embedding dimension. + - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence + length, N is the batch size, E is the embedding dimension. + - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. + If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions + will be unchanged. 
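The `as_strided` call in `rel_shift` above selects, for each query position i and key position j, the entry for relative offset j - i, i.e. column `time1 - 1 + j - i` of the padded last axis. A small check against a naive slicing version, using the shapes from the docstring:

```
import torch


def rel_shift_naive(x: torch.Tensor) -> torch.Tensor:
    # x: (batch, head, time1, 2*time1 - 1) -> (batch, head, time1, time1),
    # where output[b, h, i, j] = x[b, h, i, time1 - 1 + j - i].
    batch, head, time1, n = x.shape
    assert n == 2 * time1 - 1
    rows = [x[:, :, i, time1 - 1 - i : 2 * time1 - 1 - i] for i in range(time1)]
    return torch.stack(rows, dim=2)


x = torch.randn(2, 4, 5, 9)
time1 = 5
strided = x.as_strided(
    (2, 4, time1, time1),
    (x.stride(0), x.stride(1), x.stride(2) - x.stride(3), x.stride(3)),
    storage_offset=x.stride(3) * (time1 - 1),
)
print(torch.equal(strided, rel_shift_naive(x)))  # True
```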
If a BoolTensor is provided, the positions with the + value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. + - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. + 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, + S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked + positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend + while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` + are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor + is provided, it will be added to the attention weight. + + Outputs: + - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, + E is the embedding dimension. + - attn_output_weights: :math:`(N, L, S)` where N is the batch size, + L is the target sequence length, S is the source sequence length. + """ + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == embed_dim_to_check + assert key.size(0) == value.size(0) and key.size(1) == value.size(1) + + head_dim = embed_dim // num_heads + assert ( + head_dim * num_heads == embed_dim + ), "embed_dim must be divisible by num_heads" + + scaling = float(head_dim) ** -0.5 + + if torch.equal(query, key) and torch.equal(key, value): + # self-attention + q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( + 3, dim=-1 + ) + + elif torch.equal(key, value): + # encoder-decoder attention + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) + + else: + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = 0 + _end = embed_dim + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + q = nn.functional.linear(query, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim + _end = embed_dim * 2 + _w = in_proj_weight[_start:_end, :] + if _b is not None: + _b = _b[_start:_end] + k = nn.functional.linear(key, _w, _b) + + # This is inline in_proj function with in_proj_weight and in_proj_bias + _b = in_proj_bias + _start = embed_dim * 2 + _end = None + _w = in_proj_weight[_start:, :] + if _b is not None: + _b = _b[_start:] + v = nn.functional.linear(value, _w, _b) + + if attn_mask is not None: + assert ( + attn_mask.dtype == torch.float32 + or attn_mask.dtype == torch.float64 + or attn_mask.dtype == torch.float16 + or attn_mask.dtype == torch.uint8 + or attn_mask.dtype == torch.bool + ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( + attn_mask.dtype + ) + if attn_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for attn_mask is deprecated. Use bool tensor instead." 
+ ) + attn_mask = attn_mask.to(torch.bool) + + if attn_mask.dim() == 2: + attn_mask = attn_mask.unsqueeze(0) + if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: + raise RuntimeError("The size of the 2D attn_mask is not correct.") + elif attn_mask.dim() == 3: + if list(attn_mask.size()) != [ + bsz * num_heads, + query.size(0), + key.size(0), + ]: + raise RuntimeError("The size of the 3D attn_mask is not correct.") + else: + raise RuntimeError( + f"attn_mask's dimension {attn_mask.dim()} is not supported" + ) + # attn_mask's dim is 3 now. + + # convert ByteTensor key_padding_mask to bool + if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + warnings.warn( + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." + ) + key_padding_mask = key_padding_mask.to(torch.bool) + + q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) + k = k.contiguous().view(-1, bsz, num_heads, head_dim) + v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) + + src_len = k.size(0) + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz, "{} == {}".format( + key_padding_mask.size(0), bsz + ) + assert key_padding_mask.size(1) == src_len, "{} == {}".format( + key_padding_mask.size(1), src_len + ) + + q = q.transpose(0, 1) # (batch, time1, head, d_k) + + pos_emb_bsz = pos_emb.size(0) + assert pos_emb_bsz in (1, bsz) # actually it is 1 + p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) + p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k) + + q_with_bias_u = (q + self._pos_bias_u()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + q_with_bias_v = (q + self._pos_bias_v()).transpose( + 1, 2 + ) # (batch, head, time1, d_k) + + # compute attention score + # first compute matrix a and matrix c + # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 + k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + + # compute matrix b and matrix d + matrix_bd = torch.matmul( + q_with_bias_v, p.transpose(-2, -1) + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift(matrix_bd) + + attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + + assert list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_output_weights.masked_fill_(attn_mask, float("-inf")) + else: + attn_output_weights += attn_mask + + if key_padding_mask is not None: + attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + attn_output_weights = attn_output_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + float("-inf"), + ) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, src_len + ) + + attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) + attn_output_weights = nn.functional.dropout( + attn_output_weights, p=dropout_p, training=training + ) + + attn_output = torch.bmm(attn_output_weights, v) + assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] + attn_output = ( + attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + + if need_weights: + # average attention weights over heads + 
attn_output_weights = attn_output_weights.view( + bsz, num_heads, tgt_len, src_len + ) + return attn_output, attn_output_weights.sum(dim=1) / num_heads + else: + return attn_output, None + + +class ConvolutionModule(nn.Module): + def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + """ + ConvolutionModule in Conformer model. + Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py + Construct a ConvolutionModule object. + + Args: + channels (int): + the number of channels of conv layers. + kernel_size (int): + kernerl size of conv layers. + bias (bool): + whether to use bias in conv layers (default=True). + """ + super().__init__() + # kernerl_size should be a odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0 + + self.pointwise_conv1 = ScaledConv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + + # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, + # but sometimes, for some reason, for layer 0 the rms ends up being very large, + # between 50 and 100 for different channels. This will cause very peaky and + # sparse derivatives for the sigmoid gating function, which will tend to make + # the loss function not learn effectively. (for most layers the average absolute values + # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different + # layers, which likely breaks down as 0.5 for the "linear" half and + # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we + # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, + # it will be in a better position to start learning something, i.e. to latch onto + # the correct range. + self.deriv_balancer1 = ActivationBalancer( + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + ) + + self.depthwise_conv = ScaledConv1d( + channels, + channels, + kernel_size, + stride=1, + padding=(kernel_size - 1) // 2, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + channel_dim=1, min_positive=0.05, max_positive=1.0 + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.25, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Compute convolution module. + + Args: + x: + input tensor of shape (T, N, C). + + Returns: + torch.Tensor: Output tensor (T, N, C), where + T is the source sequence length, + N is the batch size, + C is the feature number. + + """ + # exchange the temporal dimension and the feature dimension + x = x.permute(1, 2, 0) # (#batch, channels, time). 
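+        # Overall shape flow through the module (C = channels):
+        #   (T, N, C) --permute--> (N, C, T)
+        #   pointwise_conv1: (N, C, T) -> (N, 2C, T); GLU halves it back to (N, C, T)
+        #   depthwise_conv ('SAME' padding) and pointwise_conv2 keep (N, C, T)
+        #   the final permute restores (T, N, C)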
+ + # GLU mechanism + x = self.pointwise_conv1(x) # (batch, 2*channels, time) + + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (batch, channels, time) + + # 1D Depthwise Conv + x = self.depthwise_conv(x) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + x = self.pointwise_conv2(x) # (batch, channel, time) + + return x.permute(2, 0, 1) diff --git a/egs/tedlium3/ASR/conformer_ctc2/decode.py b/egs/tedlium3/ASR/conformer_ctc2/decode.py new file mode 100755 index 000000000..ce4dcd142 --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/decode.py @@ -0,0 +1,899 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, +# Fangjun Kuang, +# Quandong Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import shutil +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import TedLiumAsrDataModule +from conformer import Conformer +from train import add_model_arguments + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + load_averaged_model, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) ctc-greedy-search. 
It only use CTC output and a sentence piece + model for decoding. It produces the same results with ctc-decoding. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (6) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (7) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--lm-path", + type=str, + default="data/lm/G_4_gram.pt", + help="""The n-gram LM dir for rescoring. + It should contain either lm_fname.pt or lm_fname.fst.txt + """, + ) + + parser.add_argument( + "--result-dir", + type=str, + default="conformer_ctc2/exp", + help="Directory to store results.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. 
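+
+    - search_beam, output_beam, min_active_states, max_active_states:
+      Pruning options passed to `get_lattice` when intersecting the
+      network output with the decoding graph.
+
+    - use_double_scores: True to use double precision for lattice scores
+      when extracting the best path.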
+ """ + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 4, + "feature_dim": 80, + # parameters for decoding + "search_beam": 15, + "output_beam": 8, + "min_active_states": 10, + "max_active_states": 7000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def ctc_greedy_search( + ctc_probs: torch.Tensor, + mask: torch.Tensor, +) -> List[List[int]]: + """Apply CTC greedy search + Args: + ctc_probs (torch.Tensor): (batch, max_len, num_bpe) + mask (torch.Tensor): (batch, max_len) + Returns: + best path result + """ + + _, max_index = ctc_probs.max(2) # (B, maxlen) + max_index = max_index.masked_fill_(mask, 0) # (B, maxlen) + + ret_hyps = [] + for hyp in max_index: + hyp = torch.unique_consecutive(hyp) + hyp = hyp[hyp > 0].tolist() + ret_hyps.append(hyp) + return ret_hyps + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. + - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. 
+ """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + torch.div( + supervisions["start_frame"], + params.subsampling_factor, + rounding_mode="floor", + ), + torch.div( + supervisions["num_frames"], + params.subsampling_factor, + rounding_mode="floor", + ), + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + unk = bpe_model.decode(bpe_model.unk_id()).strip() + hyps = [[w for w in s.split() if w != unk] for s in hyps] + key = "ctc-decoding" + + return {key: hyps} + + if params.method == "ctc-greedy-search": + hyps = ctc_greedy_search(nnet_output, memory_key_padding_mask) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(hyps) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + unk = bpe_model.decode(bpe_model.unk_id()).strip() + hyps = [[w for w in s.split() if w != unk] for s in hyps] + key = "ctc-greedy-search" + + return {key: hyps} + + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [ + [word_table[i] for i in ids if word_table[i] != ""] for ids in hyps + ] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.method == "nbest": + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [ + [word_table[i] for i in ids if word_table[i] != ""] for ids in hyps + ] + return {key: hyps} + + assert params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "1best": + best_path_dict = one_best_decoding( + lattice=lattice, + lm_scale_list=lm_scale_list, + ) + elif params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + best_path_dict = rescore_with_attention_decoder( + lattice=lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [ + [word_table[i] for i in ids if word_table[i] != ""] for ids in hyps + ] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. 
Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +) -> None: + if params.method == "attention-decoder": + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.result_dir / f"recogs-{test_set_name}-{key}.txt" + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
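+        # Each entry in `results` is a (cut_id, ref_words, hyp_words) tuple,
+        # e.g. ("some-cut-id", ["hello", "world"], ["hello", "word"])
+        # (the values here are made up for illustration).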
+ errs_filename = params.result_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.result_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main() -> None: + parser = get_parser() + TedLiumAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_path = Path(args.lm_path) + args.result_dir = Path(args.result_dir) + + if args.result_dir.is_dir(): + shutil.rmtree(args.result_dir) + args.result_dir.mkdir() + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + if params.method in ("ctc-decoding", "ctc-greedy-search"): + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ("nbest-rescoring", "whole-lattice-rescoring"): + assert params.lm_path.suffix in (".pt", ".txt") + + if params.lm_path.is_file() and params.lm_path.suffix == ".pt": + logging.info(f"Loading pre-compiled {params.lm_path.name}") + d = torch.load(params.lm_path, map_location=device) + G = k2.Fsa.from_dict(d) + elif not params.lm_path.is_file() and params.lm_path.suffix == ".txt": + raise FileNotFoundError(f"No such language model file: '{params.lm_path}'") + else: + # here we pass only if LM filename ends with '.pt' and doesn't exist + # or if LM filename ends '.txt' and exists. + if ( + not params.lm_path.is_file() + and params.lm_path.suffix == ".pt" + and not ( + params.lm_path.parent / f"{params.lm_path.stem}.fst.txt" + ).is_file() + ): + raise FileNotFoundError( + f"No such language model file: '{params.lm_path}'\n" + "'.fst.txt' representation of the language model was " + "not found either." 
+ ) + else: + # whatever params.lm_path.name we got lm_name.pt or lm_name.fst.txt + # we are going to load lm_name.fst.txt here + params.lm_path = params.lm_path.parent / params.lm_path.name.replace( + ".pt", ".fst.txt" + ) + logging.info(f"Loading {params.lm_path.name}") + logging.warning("It may take 8 minutes.") + with open(params.lm_path) as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save( + G.as_dict(), + params.lm_path.parent + / params.lm_path.name.replace(".fst.txt", ".pt"), + ) + + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + model = Conformer( + num_features=params.feature_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + d_model=params.dim_model, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + num_decoder_layers=params.num_decoder_layers, + ) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + 
filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + tedlium = TedLiumAsrDataModule(args) + + valid_cuts = tedlium.dev_cuts() + test_cuts = tedlium.test_cuts() + + valid_dl = tedlium.valid_dataloaders(valid_cuts) + test_dl = tedlium.test_dataloaders(test_cuts) + + test_sets = ["dev", "test"] + test_dls = [valid_dl, test_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +# when we import add_model_arguments from train.py +# we enforce torch.set_num_interop_threads(1) in it, +# so we ended up with setting num_interop_threads to one +# two times: in train.py and decode.py which cause an error, +# that is why added an additional if statement. +if torch.get_num_interop_threads() != 1: + torch.set_num_interop_threads(1) + +# The flag below controls whether to allow TF32 on matmul. This flag defaults to False +# in PyTorch 1.12 and later. +torch.backends.cuda.matmul.allow_tf32 = True + +if __name__ == "__main__": + main() diff --git a/egs/tedlium3/ASR/conformer_ctc2/export.py b/egs/tedlium3/ASR/conformer_ctc2/export.py new file mode 100755 index 000000000..009bea230 --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/export.py @@ -0,0 +1,294 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Behavox LLC (Author: Daniil Kulko) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. 
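+#
+# Roughly speaking, averaging N checkpoints means taking the element-wise
+# mean of their parameter tensors, e.g.
+#
+#     avg[name] = (ckpt_1[name] + ... + ckpt_N[name]) / N
+#
+# (a simplified sketch; the `--use-averaged-model` path instead reconstructs
+# the average over an epoch/iteration range from the running `model_avg`
+# stored in the two end-point checkpoints).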
+""" +Usage: +./conformer_ctc2/export.py \ + --exp-dir ./conformer_ctc2/exp \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `conformer_ctc2/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/tedlium3/ASR + ./conformer_ctc2/decode.py \ + --exp-dir ./conformer_ctc2/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 100 +""" + +import argparse +import logging +from pathlib import Path + +import torch +from conformer import Conformer +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, str2bool + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help=( + "Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'" + ), + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help=( + "Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. " + ), + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. 
+ """ + # parameters for conformer + params = AttributeDict({"subsampling_factor": 4, "feature_dim": 80}) + return params + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info(params) + + logging.info("About to create model") + + model = Conformer( + num_features=params.feature_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + d_model=params.dim_model, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + num_decoder_layers=params.num_decoder_layers, + ) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + "Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: 
+ logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py b/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/lstmp.py b/egs/tedlium3/ASR/conformer_ctc2/lstmp.py new file mode 120000 index 000000000..b82e115fc --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/optim.py b/egs/tedlium3/ASR/conformer_ctc2/optim.py new file mode 120000 index 000000000..0a2f285aa --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/scaling.py b/egs/tedlium3/ASR/conformer_ctc2/scaling.py new file mode 120000 index 000000000..c10cdfe12 --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py b/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/subsampling.py b/egs/tedlium3/ASR/conformer_ctc2/subsampling.py new file mode 120000 index 000000000..8c91f2336 --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc2/subsampling.py \ No newline at end of file diff --git a/egs/tedlium3/ASR/conformer_ctc2/train.py b/egs/tedlium3/ASR/conformer_ctc2/train.py new file mode 100755 index 000000000..42e4c010a --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/train.py @@ -0,0 +1,1061 @@ +#!/usr/bin/env python3 +# Copyright 2022 Behavox LLC. (authors: Daniil Kulko) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./conformer_ctc/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir conformer_ctc/exp \ + --max-duration 300 + +# For mix precision training: + +./conformer_ctc/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir conformer_ctc/exp \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +from asr_datamodule import TedLiumAsrDataModule +from conformer import Conformer +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from local.convert_transcript_words_to_bpe_ids import convert_texts_into_ids +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + display_and_save_batch, + encode_supervisions, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_model_arguments(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--num-encoder-layers", + type=int, + default=24, + help="Number of conformer encoder layers..", + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. 
+ The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=1536, + help="Feedforward module dimension of the conformer model.", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads in the conformer multiheadattention modules.", + ) + + parser.add_argument( + "--dim-model", + type=int, + default=384, + help="Attention dimension in the conformer model.", + ) + + +def get_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" and "bpe.model" + """, + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="The initial learning rate. This value should not need to be changed.", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="Number of epochs that affects how rapidly the learning rate decreases.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. 
+ It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 10, + "reset_interval": 200, + "valid_interval": 1000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + # parameters for ctc loss + "beam_size": 10, + "reduction": "none", + "use_double_scores": True, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: torch.nn.Module, + model_avg: torch.nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. 
+ model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that is used for training. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[torch.nn.Module, DDP], + model_avg: Optional[torch.nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used for training. + scheduler: + The learning rate scheduler used for training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[torch.nn.Module, DDP], + graph_compiler: BpeCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
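+
+    Returns:
+      A tuple of (loss, info): `loss` is a scalar tensor (the summed CTC loss,
+      optionally combined with the attention loss according to `att_rate`),
+      and `info` is a MetricsTracker with statistics (number of frames,
+      ctc/att loss values, etc.) used for logging.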
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model( + feature, supervisions, warmup=warmup + ) + + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + token_ids = convert_texts_into_ids(texts, graph_compiler.sp) + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction=params.reduction, + use_double_scores=params.use_double_scores, + ) + + if params.att_rate > 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + warmup=warmup, + ) + else: + att_loss = torch.tensor([0]) + + ctc_loss_is_finite = torch.isfinite(ctc_loss) + att_loss_is_finite = torch.isfinite(att_loss) + if torch.any(~ctc_loss_is_finite) or torch.any(~att_loss_is_finite): + logging.info( + "Not all losses are finite!\n" + f"ctc_loss: {ctc_loss}\n" + f"att_loss: {att_loss}" + ) + display_and_save_batch(batch, params=params, sp=graph_compiler.sp) + ctc_loss = ctc_loss[ctc_loss_is_finite] + att_loss = att_loss[att_loss_is_finite] + + # If the batch contains more than 10 utterances AND + # if either all ctc_loss or att_loss is inf or nan, + # we stop the training process by raising an exception + if torch.all(~ctc_loss_is_finite) or torch.all(~att_loss_is_finite): + raise ValueError( + "There are too many utterances in this batch " + "leading to inf or nan losses." + ) + + ctc_loss = ctc_loss.sum() + att_loss = att_loss.sum() + + if params.att_rate > 0.0: + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + # info["frames"] is an approximate number for two reasons: + # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 + # (2) If some utterances in the batch lead to inf/nan loss, they + # are filtered out. 
+ info["frames"] = ( + torch.div(feature_lens, params.subsampling_factor, rounding_mode="floor") + .sum() + .item() + ) + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate > 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[torch.nn.Module, DDP], + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch in valid_dl: + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[torch.nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[torch.nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. 
+ """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=graph_compiler.sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    logging.info(params)
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    if "lang_bpe" not in str(params.lang_dir):
+        raise ValueError(
+            f"Unsupported type of lang dir (we expected it to have "
+            f"'lang_bpe' in its name): {params.lang_dir}"
+        )
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        d_model=params.dim_model,
+        nhead=params.nhead,
+        dim_feedforward=params.dim_feedforward,
+        num_encoder_layers=params.num_encoder_layers,
+        num_decoder_layers=params.num_decoder_layers,
+    )
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[torch.nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = optim.Eve(model.parameters(), lr=params.initial_lr)
+    scheduler = optim.Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+    if checkpoints and checkpoints.get("optimizer") is not None:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if checkpoints and checkpoints.get("scheduler") is not None:
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2**22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    tedlium = TedLiumAsrDataModule(args)
+
+    train_cuts = tedlium.train_cuts()
+
+    if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We only load the sampler's state dict when it loads a checkpoint
+        # saved in the middle of an epoch
+        sampler_state_dict = checkpoints["sampler"]
+    else:
+        sampler_state_dict = None
+
+    train_dl = tedlium.train_dataloaders(
+        train_cuts, sampler_state_dict=sampler_state_dict
+    )
+
+    valid_cuts = tedlium.dev_cuts()
+    valid_dl = tedlium.valid_dataloaders(valid_cuts)
+
+    if (
+        params.start_epoch <= 1
+        and params.start_batch <= 0
+        and not params.print_diagnostics
+    ):
+        scan_pessimistic_batches_for_oom(
+            model=model,
+            train_dl=train_dl,
+            optimizer=optimizer,
+            graph_compiler=graph_compiler,
+            params=params,
+            warmup=0.0 if params.start_epoch == 1 else 1.0,
+        )
+
+    scaler = GradScaler(enabled=params.use_fp16)
+    if checkpoints and "grad_scaler" in
checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + train_dl.dataset.epoch = epoch - 1 + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[torch.nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, + warmup: float, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + warmup=warmup, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=graph_compiler.sp) + raise + + +def main(): + parser = get_parser() + TedLiumAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +# The flag below controls whether to allow TF32 on matmul. This flag defaults to False +# in PyTorch 1.12 and later. +torch.backends.cuda.matmul.allow_tf32 = True + +if __name__ == "__main__": + main() diff --git a/egs/tedlium3/ASR/conformer_ctc2/transformer.py b/egs/tedlium3/ASR/conformer_ctc2/transformer.py new file mode 100644 index 000000000..9dbf32e48 --- /dev/null +++ b/egs/tedlium3/ASR/conformer_ctc2/transformer.py @@ -0,0 +1,1093 @@ +# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Copyright 2022 Xiaomi Corp. (author: Quandong Wang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import math
+from typing import Dict, List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from attention import MultiheadAttention
+from combiner import RandomCombine
+from label_smoothing import LabelSmoothingLoss
+from scaling import (
+    ActivationBalancer,
+    BasicNorm,
+    DoubleSwish,
+    ScaledEmbedding,
+    ScaledLinear,
+)
+from subsampling import Conv2dSubsampling
+from torch.nn.utils.rnn import pad_sequence
+
+# Note: TorchScript requires Dict/List/etc. to be fully typed.
+Supervisions = Dict[str, torch.Tensor]
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        num_features: int,
+        num_classes: int,
+        subsampling_factor: int = 4,
+        d_model: int = 256,
+        nhead: int = 4,
+        dim_feedforward: int = 2048,
+        num_encoder_layers: int = 12,
+        num_decoder_layers: int = 6,
+        dropout: float = 0.1,
+        layer_dropout: float = 0.075,
+        aux_layer_period: int = 3,
+    ) -> None:
+        """
+        Args:
+          num_features:
+            the input dimension of the model.
+          num_classes:
+            the output dimension of the model.
+          subsampling_factor:
+            number of output frames is num_in_frames // subsampling_factor;
+            currently, subsampling_factor MUST be 4.
+          d_model:
+            attention dimension.
+          nhead:
+            number of heads in multi-head attention;
+            must satisfy d_model % nhead == 0.
+          dim_feedforward:
+            the output dimension of the feedforward layers in encoder/decoder.
+          num_encoder_layers:
+            number of encoder layers.
+          num_decoder_layers:
+            number of decoder layers.
+          dropout:
+            dropout in encoder/decoder.
+          layer_dropout:
+            layer-dropout rate.
+          aux_layer_period:
+            determines the period of the auxiliary encoder layers whose
+            outputs are combined by the RandomCombine module.
+        """
+        super().__init__()
+
+        self.num_features = num_features
+        self.num_classes = num_classes
+        self.subsampling_factor = subsampling_factor
+        if subsampling_factor != 4:
+            raise NotImplementedError("Support only 'subsampling_factor=4'.")
+
+        # self.encoder_embed converts the input of shape (N, T, num_features)
+        # to the shape (N, T//subsampling_factor, d_model).
+ # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + + self.encoder_pos = PositionalEncoding(d_model, dropout) + + encoder_layer = TransformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + layer_dropout=layer_dropout, + ) + # aux_layers from 1/3 + self.encoder = TransformerEncoder( + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + aux_layers=list( + range( + num_encoder_layers // 3, + num_encoder_layers - 1, + aux_layer_period, + ) + ), + ) + + # TODO(fangjun): remove dropout + self.encoder_output_layer = nn.Sequential( + nn.Dropout(p=dropout), ScaledLinear(d_model, num_classes, bias=True) + ) + + if num_decoder_layers > 0: + self.decoder_num_class = ( + self.num_classes + ) # bpe model already has sos/eos symbol + + self.decoder_embed = ScaledEmbedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) + self.decoder_pos = PositionalEncoding(d_model, dropout) + + decoder_layer = TransformerDecoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, + ) + + self.decoder = TransformerDecoder( + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + aux_layers=[], + ) + + self.decoder_output_layer = ScaledLinear( + d_model, self.decoder_num_class, bias=True + ) + + self.decoder_criterion = LabelSmoothingLoss(reduction="none") + else: + self.decoder_criterion = None + + def forward( + self, + x: torch.Tensor, + supervision: Optional[Supervisions] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, S, C). + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) + warmup: + a floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is (N, S, C) + - Encoder output with shape (S, N, C). It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is (N, S). + It is None if `supervision` is None. + """ + + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision, warmup + ) + + x = self.ctc_output(encoder_memory) + return x, encoder_memory, memory_key_padding_mask + + def run_encoder( + self, + x: torch.Tensor, + supervisions: Optional[Supervisions] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run the transformer encoder. + + Args: + x: + The model input. Its shape is (N, S, C). + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute the encoder padding mask, which is used as memory key + padding mask for the decoder. 
+ warmup: + a floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + Return a tuple with two tensors: + - The encoder output, with shape (S, N, C) + - encoder padding mask, with shape (N, S). + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. + """ + x = self.encoder_embed(x) + x = self.encoder_pos(x) + x = x.permute(1, 0, 2) # (N, S, C) -> (S, N, C) + mask = encoder_padding_mask(x.size(0), supervisions) + mask = mask.to(x.device) if mask is not None else None + x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (S, N, C) + + return x, mask + + def ctc_output(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: + the output tensor from the transformer encoder; + its shape is (S, N, C) + + Returns: + Return a tensor that can be used for CTC decoding. + Its shape is (N, S, C) + """ + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (S, N, C) -> (N, S, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, S, C) + return x + + @torch.jit.export + def decoder_forward( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + warmup: float = 1.0, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder of shape (S, N, C) + memory_key_padding_mask: + The padding mask from the encoder of shape (N, S). + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id + warmup: + a floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. + """ + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. 
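# --- Illustrative sketch, not part of the patch above: how the decoder padding mask
# built a few lines below behaves when sos_id == eos_id (in this recipe the BPE model
# uses a single <sos/eos> symbol). ys_in_pad is padded with eos_id, so
# decoder_padding_mask() marks every eos position as padding; the first column is then
# reset to False because it holds the (identical) sos token. The token ids are made up.
import torch
from torch.nn.utils.rnn import pad_sequence

sos_id = eos_id = 1
token_ids = [[5, 7, 9], [4, 6]]
ys_in = [torch.tensor([sos_id] + utt) for utt in token_ids]
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id))
mask = ys_in_pad == eos_id  # what decoder_padding_mask(ys_in_pad, ignore_id=eos_id) computes
mask[:, 0] = False          # keep the sos position visible to the decoder
print(ys_in_pad.tolist())   # [[1, 5, 7, 9], [1, 4, 6, 1]]
print(mask.tolist())        # [[False, False, False, False], [False, False, False, True]]
# --- end of sketch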
+ tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + warmup=warmup, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + + decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) + + return decoder_loss + + @torch.jit.export + def decoder_nll( + self, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[torch.Tensor], + sos_id: int, + eos_id: int, + warmup: float = 1.0, + ) -> torch.Tensor: + """ + Args: + memory: + It's the output of the encoder of shape (S, N, C). + memory_key_padding_mask: + The padding mask from the encoder of shape (N, S). + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + warmup: + a floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). + """ + # The common part between this function and decoder_forward could be + # extracted as a separate function. + if isinstance(token_ids[0], torch.Tensor): + # This branch is executed by torchscript in C++. + # See https://github.com/k2-fsa/k2/pull/870 + # https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286 + token_ids = [tolist(t) for t in token_ids] + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) + + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + + tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) + # TODO: Use length information to create the decoder padding mask + # We set the first column to False since the first column in ys_in_pad + # contains sos_id, which is the same as eos_id in our current setting. 
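# --- Illustrative sketch, not part of the patch above: the shape bookkeeping of the
# nll computation further down in decoder_nll(). Flattening the logits to (N*T, C)
# with ignore_index=-1 and reduction="none", then reshaping back to (N, T), gives one
# negative log-likelihood per target token; padded positions (marked -1) contribute 0.
# The sizes and targets are made up.
import torch

N, T, C = 2, 3, 5
pred = torch.randn(N, T, C)                         # stand-in for the decoder logits
ys_out_pad = torch.tensor([[2, 3, 1], [4, 1, -1]])  # -1 marks padding
nll = torch.nn.functional.cross_entropy(
    pred.view(-1, C), ys_out_pad.view(-1), ignore_index=-1, reduction="none"
)
nll = nll.view(N, T)  # per-utterance, per-token scores, e.g. for attention rescoring
print(nll.shape)      # torch.Size([2, 3])
# --- end of sketch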
+ tgt_key_padding_mask[:, 0] = False + + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) + tgt = self.decoder_pos(tgt) + tgt = tgt.permute(1, 0, 2) # (N, T, ะก) -> (T, N, C) + pred_pad = self.decoder( + tgt=tgt, + memory=memory, + tgt_mask=tgt_mask, + tgt_key_padding_mask=tgt_key_padding_mask, + memory_key_padding_mask=memory_key_padding_mask, + warmup=warmup, + ) # (T, B, F) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) + # nll: negative log-likelihood + nll = torch.nn.functional.cross_entropy( + pred_pad.view(-1, self.decoder_num_class), + ys_out_pad.view(-1), + ignore_index=-1, + reduction="none", + ) + + nll = nll.view(pred_pad.shape[0], -1) + + return nll + + +class TransformerEncoderLayer(nn.Module): + """ + Modified from torch.nn.TransformerEncoderLayer. + + Example: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> out = encoder_layer(src) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + bypass_scale: float = 0.1, + layer_dropout: float = 0.075, + ) -> None: + """ + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + bypass_scale: + a scale on the layer's output, used in bypass (resnet-type) skip-connection; + when the layer is bypassed the final output will be a + weighted sum of the layer's input and layer's output with weights + (1.0-bypass_scale) and bypass_scale correspondingly (default=0.1). + layer_dropout: + the probability to bypass the layer (default=0.075). + """ + + super().__init__() + + if bypass_scale < 0.0 or bypass_scale > 1.0: + raise ValueError("bypass_scale should be between 0.0 and 1.0") + + if layer_dropout < 0.0 or layer_dropout > 1.0: + raise ValueError("layer_dropout should be between 0.0 and 1.0") + + self.bypass_scale = bypass_scale + self.layer_dropout = layer_dropout + + self.self_attn = MultiheadAttention(d_model, nhead) + # Implementation of Feedforward model + + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> torch.Tensor: + """ + Pass the input through the encoder layer. + + Args: + src: + the sequence to the encoder layer of shape (S, N, C) (required). + src_mask: + the mask for the src sequence of shape (S, S) (optional). + src_key_padding_mask: + the mask for the src keys per batch of shape (N, S) (optional) + warmup: + controls selective bypass of layers; if < 1.0, we will + bypass the layer more frequently (default=1.0). + + Returns: + Output tensor of the shape (S, N, C), where + S is the source sequence length, + N is the batch size, + C is the feature number. 
+ + """ + src_orig = src + + warmup_scale = min(self.bypass_scale + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else self.bypass_scale + ) + else: + alpha = 1.0 + + src_att = self.self_attn( + src, + src, + src, + attn_mask=src_mask, + key_padding_mask=src_key_padding_mask, + )[0] + src = src + self.dropout(src_att) + + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1.0 - alpha) * src_orig + + return src + + +class TransformerDecoderLayer(nn.Module): + """Modified from torch.nn.TransformerDecoderLayer. + + Example: + >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = decoder_layer(tgt, memory) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + bypass_scale: float = 0.1, + layer_dropout: float = 0.075, + ) -> None: + + """ + Args: + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + bypass_scale: + a scale on the layer's output, used in bypass (resnet-type) skip-connection; + when the layer is bypassed, the final output will be a + weighted sum of the layer's input and layer's output with weights + (1.0-bypass_scale) and bypass_scale correspondingly (default=0.1). + layer_dropout: + the probability to bypass the layer (default=0.075). + """ + + super().__init__() + + if bypass_scale < 0.0 or bypass_scale > 1.0: + raise ValueError("bypass_scale should be between 0.0 and 1.0") + + if layer_dropout < 0.0 or layer_dropout > 1.0: + raise ValueError("layer_dropout should be between 0.0 and 1.0") + + self.bypass_scale = bypass_scale + self.layer_dropout = layer_dropout + + self.self_attn = MultiheadAttention(d_model, nhead) + self.src_attn = MultiheadAttention(d_model, nhead) + + # Implementation of Feedforward model + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> torch.Tensor: + """Pass the inputs (and mask) through the decoder layer. + + Args: + tgt: + the sequence to the decoder layer of shape (T, N, C) (required). + memory: + the sequence from the last layer of the encoder of shape (S, N, C) (required). + tgt_mask: + the mask for the tgt sequence of shape (T, T) (optional). + memory_mask: + the mask for the memory sequence of shape (T, S) (optional). 
+ tgt_key_padding_mask: + the mask for the tgt keys per batch of shape (N, T) (optional). + memory_key_padding_mask: + the mask for the memory keys per batch of shape (N, S) (optional). + warmup: controls selective bypass of layers; if < 1.0, we will + bypass the layer more frequently (default=1.0). + + Returns: + Output tensor of the shape (T, N, C), where + S is the source sequence length, + T is the target sequence length, + N is the batch size, + C is the feature number. + + """ + tgt_orig = tgt + + warmup_scale = min(self.bypass_scale + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else self.bypass_scale + ) + else: + alpha = 1.0 + + tgt_att = self.self_attn( + tgt, + tgt, + tgt, + attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask, + )[0] + tgt = tgt + self.dropout(tgt_att) + + src_att = self.src_attn( + tgt, + memory, + memory, + attn_mask=memory_mask, + key_padding_mask=memory_key_padding_mask, + )[0] + tgt = tgt + self.dropout(src_att) + + tgt = tgt + self.dropout(self.feed_forward(tgt)) + + tgt = self.norm_final(self.balancer(tgt)) + + if alpha != 1.0: + tgt = alpha * tgt + (1.0 - alpha) * tgt_orig + + return tgt + + +class TransformerEncoder(nn.Module): + """TransformerEncoder is a stack of N encoder layers + + Examples: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + aux_layers: List[int], + ) -> None: + """ + Args: + encoder_layer: + an instance of the TransformerEncoderLayer() class (required). + num_layers: + the number of sub-encoder-layers in the encoder (required). + aux_layers: + list of indexes of sub-encoder-layers outputs to be combined (required). + """ + + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert len(set(aux_layers)) == len(aux_layers) + + assert num_layers - 1 not in aux_layers + self.aux_layers = aux_layers + [num_layers - 1] + + self.combiner = RandomCombine( + num_inputs=len(self.aux_layers), + final_weight=0.5, + pure_prob=0.333, + stddev=2.0, + ) + + def forward( + self, + src: torch.Tensor, + mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> torch.Tensor: + """Pass the input through the encoder layers in turn. + + Args: + src: + the input to the encoder of shape (S, N, C) (required). + mask: + the mask for the src sequence of shape (S, S) (optional). + src_key_padding_mask: + the mask for the src keys per batch of shape (N, S) (optional). + warmup: + controls selective bypass of layer; if < 1.0, we will + bypass the layer more frequently (default=1.0). + + Returns: + Output tensor of the shape (S, N, C), where + S is the source sequence length, + N is the batch size, + C is the feature number. 
+ + """ + output = src + + outputs = [] + for i, mod in enumerate(self.layers): + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + warmup=warmup, + ) + + if i in self.aux_layers: + outputs.append(output) + + output = self.combiner(outputs) + + return output + + +class TransformerDecoder(nn.Module): + """TransformerDecoder is a stack of N decoder layers + + Examples: + >>> decoder_layer = TransformerDecoderLayer(d_model=512, nhead=8) + >>> transformer_decoder = TransformerDecoder(decoder_layer, num_layers=6) + >>> memory = torch.rand(10, 32, 512) + >>> tgt = torch.rand(20, 32, 512) + >>> out = transformer_decoder(tgt, memory) + """ + + def __init__( + self, + decoder_layer: nn.Module, + num_layers: int, + aux_layers: List[int], + ) -> None: + """ + Args: + decoder_layer: + an instance of the TransformerDecoderLayer() class (required). + num_layers: + the number of decoder layers in the decoder (required). + aux_layers: + list of indexes of decoder layer outputs to be combined (required). + """ + + super().__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(decoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + assert len(set(aux_layers)) == len(aux_layers) + + assert num_layers - 1 not in aux_layers + self.aux_layers = aux_layers + [num_layers - 1] + + self.combiner = RandomCombine( + num_inputs=len(self.aux_layers), + final_weight=0.5, + pure_prob=0.333, + stddev=2.0, + ) + + def forward( + self, + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> torch.Tensor: + """Pass the input (and mask) through the decoder layers in turn. + + Args: + tgt: + the sequence to the decoder of shape (T, N, C) (required). + memory: + the sequence from the last layer of the encoder of shape (S, N, C) (required). + tgt_mask: + the mask for the tgt sequence of shape (T, T) (optional). + memory_mask: + the mask for the memory sequence of shape (T, S) (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch of shape (N, T) (optional). + memory_key_padding_mask: + the mask for the memory keys per batch of shape (N, S) (optional). + warmup: + controls selective bypass of layer; if < 1.0, we will + bypass the layer more frequently (default=1.0). + + Returns: + Output tensor of the shape (T, N, C), where + S is the source sequence length, + T is the target sequence length, + N is the batch size, + C is the feature number. 
+
+        """
+        output = tgt
+
+        outputs = []
+        for i, mod in enumerate(self.layers):
+            output = mod(
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                warmup=warmup,
+            )
+
+            if i in self.aux_layers:
+                outputs.append(output)
+
+        output = self.combiner(outputs)
+
+        return output
+
+
+class PositionalEncoding(nn.Module):
+    """This class implements the positional encoding
+    proposed in the following paper:
+
+    - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
+
+        PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
+        PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
+
+    Note:
+
+      1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
+                               = exp(-1 * 2i / d_model * log(10000))
+                               = exp(2i * -(log(10000) / d_model))
+    """
+
+    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
+        """
+        Args:
+          d_model: Embedding dimension.
+          dropout: Dropout probability to be applied to the output of this module.
+        """
+        super().__init__()
+        self.d_model = d_model
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = nn.Dropout(p=dropout)
+        # not doing: self.pe = None because of errors thrown by torchscript
+        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
+
+    def extend_pe(self, x: torch.Tensor) -> None:
+        """Extend the time t in the positional encoding if required.
+        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
+        is (N, T, d_model). If T > T1, then we change the shape of self.pe
+        to (1, T, d_model). Otherwise, nothing is done.
+
+        Args:
+          x:
+            It is a tensor of shape (N, T, C).
+            T is the target sequence length,
+            N is the batch size,
+            C is the feature number.
+        """
+        if self.pe is not None:
+            if self.pe.size(1) >= x.size(1):
+                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
+        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        # Now pe is of shape (1, T, d_model), where T is x.size(1)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Add positional encoding.
+
+        Args:
+          x: Input of shape (N, T, C)
+
+        Returns:
+          A tensor of the same shape (N, T, C),
+          T is the target sequence length,
+          N is the batch size,
+          C is the feature number.
+
+        """
+        self.extend_pe(x)
+        x = x + self.pe[:, : x.size(1), :]
+        return self.dropout(x)
+
+
+def encoder_padding_mask(
+    max_len: int, supervisions: Optional[Supervisions] = None
+) -> Optional[torch.Tensor]:
+    """Make mask tensor containing indexes of padded part.
+
+    TODO:
+      This function **assumes** that the model uses
+      a subsampling factor of 4. We should remove that
+      assumption later.
+
+    Args:
+      max_len:
+        Maximum length of input features.
+        CAUTION: It is the length after subsampling.
+      supervisions:
+        Supervision in lhotse format.
+        See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32  # noqa
+        (CAUTION: It contains length information, i.e., start and number of
+        frames, before subsampling)
+
+    Returns:
+      Mask tensor of dimension (batch_size, input_length),
+      True denotes the masked indices.
+ """ + if supervisions is None: + return None + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"], + supervisions["num_frames"], + ), + 1, + ).to(torch.int32) + + lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + for idx in range(supervision_segments.size(0)): + # Note: TorchScript doesn't allow to unpack tensors as tuples + sequence_idx = supervision_segments[idx, 0].item() + start_frame = supervision_segments[idx, 1].item() + num_frames = supervision_segments[idx, 2].item() + lengths[sequence_idx] = start_frame + num_frames + + lengths = [((i - 1) // 2 - 1) // 2 for i in lengths] + bs = int(len(lengths)) + seq_range = torch.arange(0, max_len, dtype=torch.int64) + seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len) + # Note: TorchScript doesn't implement Tensor.new() + seq_length_expand = torch.tensor( + lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype + ).unsqueeze(-1) + mask = seq_range_expand >= seq_length_expand + + return mask + + +def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with True, + Unmasked positions are filled with False. + + Args: + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad + + Returns: + A bool tensor of the same shape as the input tensor. + """ + ys_mask = ys_pad == ignore_id + return ys_mask + + +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + + For instance, if sz is 3, it returns:: + + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) + + Args: + sz: mask size + + Returns: + A square mask tensor of dimension (sz, sz) + """ + mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) + mask = ( + mask.float() + .masked_fill(mask == 0, float("-inf")) + .masked_fill(mask == 1, float(0.0)) + ) + return mask + + +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. + + Args: + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. + """ + return [[sos_id] + utt for utt in token_ids] + + +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. + + Args: + token_ids: + A list-of-lists of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-lists, where each sublist ends + with EOS ID. + """ + return [utt + [eos_id] for utt in token_ids] + + +def tolist(t: torch.Tensor) -> List[int]: + """Used by jit""" + return torch.jit.annotate(List[int], t.tolist()) diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py index 9dbcc9d9e..19ba8d24b 100644 --- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py +++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py @@ -4,16 +4,18 @@ """ Convert a transcript based on words to a list of BPE ids. 
-For example, if we use 2 as the encoding id of <unk>:
+For example, if we use 2 as the encoding id of <unk>
+Note: it inserts a space token before each <unk>
 
 texts = ['this is a <unk> day']
-spm_ids = [[38, 33, 6, 2, 316]]
+spm_ids = [[38, 33, 6, 15, 2, 316]]
 
 texts = ['<unk> this is a sunny day']
-spm_ids = [[2, 38, 33, 6, 118, 11, 11, 21, 316]]
+spm_ids = [[15, 2, 38, 33, 6, 118, 11, 11, 21, 316]]
 
 texts = ['<unk>']
-spm_ids = [[2]]
+spm_ids = [[15, 2]]
+
 """
 
 import argparse
@@ -38,29 +40,27 @@ def get_args():
 
 def convert_texts_into_ids(
     texts: List[str],
-    unk_id: int,
     sp: spm.SentencePieceProcessor,
 ) -> List[List[int]]:
     """
     Args:
       texts:
        A string list of transcripts, such as ['Today is Monday', 'It's sunny'].
-      unk_id:
-        A number id for the token '<unk>'.
+      sp:
+        A sentencepiece BPE model.
     Returns:
       Return an integer list of bpe ids.
     """
     y = []
     for text in texts:
-        y_ids = []
         if "<unk>" in text:
-            text_segments = text.split("<unk>")
-            id_segments = sp.encode(text_segments, out_type=int)
+            id_segments = sp.encode(text.split("<unk>"), out_type=int)
+
+            y_ids = []
             for i in range(len(id_segments)):
-                if i != len(id_segments) - 1:
-                    y_ids.extend(id_segments[i] + [unk_id])
-                else:
-                    y_ids.extend(id_segments[i])
+                y_ids += id_segments[i]
+                if i < len(id_segments) - 1:
+                    y_ids += [sp.piece_to_id("▁"), sp.unk_id()]
         else:
            y_ids = sp.encode(text, out_type=int)
         y.append(y_ids)
@@ -70,19 +70,13 @@ def convert_texts_into_ids(
 
 def main():
     args = get_args()
-    texts = args.texts
-    bpe_model = args.bpe_model
 
     sp = spm.SentencePieceProcessor()
-    sp.load(bpe_model)
-    unk_id = sp.piece_to_id("<unk>")
+    sp.load(args.bpe_model)
 
-    y = convert_texts_into_ids(
-        texts=texts,
-        unk_id=unk_id,
-        sp=sp,
-    )
-    logging.info(f"The input texts: {texts}")
+    y = convert_texts_into_ids(texts=args.texts, sp=sp)
+
+    logging.info(f"The input texts: {args.texts}")
     logging.info(f"The encoding ids: {y}")
diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py
deleted file mode 120000
index 2ce13fd69..000000000
--- a/egs/tedlium3/ASR/local/convert_transcript_words_to_tokens.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/local/generate_unique_lexicon.py b/egs/tedlium3/ASR/local/generate_unique_lexicon.py
deleted file mode 120000
index c0aea1403..000000000
--- a/egs/tedlium3/ASR/local/generate_unique_lexicon.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/generate_unique_lexicon.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/local/prepare_lang.py b/egs/tedlium3/ASR/local/prepare_lang.py
deleted file mode 120000
index 747f2ab39..000000000
--- a/egs/tedlium3/ASR/local/prepare_lang.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/prepare_lang.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py
deleted file mode 100755
index b9160b6d4..000000000
--- a/egs/tedlium3/ASR/local/prepare_lexicon.py
+++ /dev/null
@@ -1,94 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2022 Xiaomi Corp. (authors: Mingshuang Luo)
-#
-# See ../../../../LICENSE for clarification regarding multiple authors
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This script takes as input supervisions json dir "data/manifests" -consisting of supervisions_train.json and does the following: - -1. Generate lexicon_words.txt. - -""" -import argparse -import logging -from pathlib import Path - -import lhotse - - -def get_args(): - parser = argparse.ArgumentParser() - parser.add_argument( - "--manifests-dir", - type=str, - help="""Input directory. - """, - ) - parser.add_argument( - "--lang-dir", - type=str, - help="""Output directory. - """, - ) - - return parser.parse_args() - - -def prepare_lexicon(manifests_dir: str, lang_dir: str): - """ - Args: - manifests_dir: - The manifests directory, e.g., data/manifests. - lang_dir: - The language directory, e.g., data/lang_phone. - - Return: - The lexicon_words.txt file. - """ - words = set() - - lexicon = Path(lang_dir) / "lexicon_words.txt" - sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz") - for s in sups: - # list the words units and filter the empty item - words_list = list(filter(None, s.text.split())) - - for word in words_list: - if word not in words and word != "": - words.add(word) - - with open(lexicon, "w") as f: - for word in sorted(words): - f.write(word + " " + word) - f.write("\n") - - -def main(): - args = get_args() - manifests_dir = Path(args.manifests_dir) - lang_dir = Path(args.lang_dir) - - logging.info("Generating lexicon_words.txt") - prepare_lexicon(manifests_dir, lang_dir) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - - main() diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py index 7ea4e89a4..d4ccdd1e3 100755 --- a/egs/tedlium3/ASR/local/prepare_transcripts.py +++ b/egs/tedlium3/ASR/local/prepare_transcripts.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Mingshuang Luo) +# Copyright 2021 Xiaomi Corp. (author: Mingshuang Luo) +# Copyright 2022 Behavox LLC. (author: Daniil Kulko) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -17,68 +18,67 @@ """ -This script takes as input supervisions json dir "data/manifests" -consisting of supervisions_train.json and does the following: - -1. Generate train.text. +This script takes input text file and removes all words +that iclude any character out of English alphabet. """ import argparse import logging +import re from pathlib import Path -import lhotse - def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "--manifests-dir", + "--input-text-path", type=str, - help="""Input directory. - """, + help="Input text file path.", ) parser.add_argument( - "--lang-dir", + "--output-text-path", type=str, - help="""Output directory. - """, + help="Output text file path.", ) return parser.parse_args() -def prepare_transcripts(manifests_dir: str, lang_dir: str): +def prepare_transcripts(input_text_path: Path, output_text_path: Path) -> None: """ Args: - manifests_dir: - The manifests directory, e.g., data/manifests. 
- lang_dir: - The language directory, e.g., data/lang_phone. + input_text_path: + The input data text file path, e.g., data/lang/train_orig.txt. + output_text_path: + The output data text file path, e.g., data/lang/train.txt. Return: - The train.text in lang_dir. + Saved text file in output_text_path. """ - texts = [] - train_text = Path(lang_dir) / "train.text" - sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz") - for s in sups: - texts.append(s.text) + foreign_chr_check = re.compile(r"[^a-z']") - with open(train_text, "w") as f: - for text in texts: - f.write(text) - f.write("\n") + logging.info(f"Loading {input_text_path.name}") + with open(input_text_path, "r", encoding="utf8") as f: + texts = {t.rstrip("\n") for t in f} + + texts = { + " ".join([w for w in t.split() if foreign_chr_check.search(w) is None]) + for t in texts + } + + with open(output_text_path, "w+", encoding="utf8") as f: + for t in texts: + f.write(f"{t}\n") -def main(): +def main() -> None: args = get_args() - manifests_dir = Path(args.manifests_dir) - lang_dir = Path(args.lang_dir) + input_text_path = Path(args.input_text_path) + output_text_path = Path(args.output_text_path) - logging.info("Generating train.text") - prepare_transcripts(manifests_dir, lang_dir) + logging.info(f"Generating {output_text_path.name}") + prepare_transcripts(input_text_path, output_text_path) if __name__ == "__main__": diff --git a/egs/tedlium3/ASR/local/prepare_words.py b/egs/tedlium3/ASR/local/prepare_words.py new file mode 100755 index 000000000..a37d0f08f --- /dev/null +++ b/egs/tedlium3/ASR/local/prepare_words.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +# Copyright 2022 Behavox LLC. (authors: Daniil Kulko) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script takes as input supervisions json dir "data/manifests" +consisting of tedlium_supervisions_train.json and does the following: + +1. Generate words.txt. + +""" +import argparse +import logging +import re +from pathlib import Path + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="Output directory.", + ) + + return parser.parse_args() + + +def prepare_words(lang_dir: str) -> None: + """ + Args: + lang_dir: + The language directory, e.g., data/lang. + + Return: + The words.txt file. 
+    """
+
+    words_orig_path = Path(lang_dir) / "words_orig.txt"
+    words_path = Path(lang_dir) / "words.txt"
+
+    foreign_chr_check = re.compile(r"[^a-z']")
+
+    logging.info(f"Loading {words_orig_path.name}")
+    with open(words_orig_path, "r", encoding="utf8") as f:
+        words = {w for w_compl in f for w in w_compl.strip("-\n").split("_")}
+    words = {w for w in words if foreign_chr_check.search(w) is None and w != ""}
+    words.add("<unk>")
+    words = ["<eps>", "!SIL"] + sorted(words) + ["#0", "<s>", "</s>"]
+
+    with open(words_path, "w+", encoding="utf8") as f:
+        for idx, word in enumerate(words):
+            f.write(f"{word} {idx}\n")
+
+
+def main() -> None:
+    args = get_args()
+    lang_dir = Path(args.lang_dir)
+
+    logging.info("Generating words.txt")
+    prepare_words(lang_dir)
+
+
+if __name__ == "__main__":
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+    logging.basicConfig(format=formatter, level=logging.INFO)
+
+    main()
diff --git a/egs/tedlium3/ASR/local/test_prepare_lang.py b/egs/tedlium3/ASR/local/test_prepare_lang.py
deleted file mode 120000
index f0f864998..000000000
--- a/egs/tedlium3/ASR/local/test_prepare_lang.py
+++ /dev/null
@@ -1 +0,0 @@
-../../../librispeech/ASR/local/test_prepare_lang.py
\ No newline at end of file
diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh
index 272cf7aed..3d90436ff 100755
--- a/egs/tedlium3/ASR/prepare.sh
+++ b/egs/tedlium3/ASR/prepare.sh
@@ -5,7 +5,6 @@ export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 set -eou pipefail
 
-nj=15
 stage=0
 stop_stage=100
@@ -63,6 +62,13 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
     mv $dl_dir/TEDLIUM_release-3 $dl_dir/tedlium3
   fi
 
+  # Download big and small 4 gram language models
+  if [ ! -d $dl_dir/lm ]; then
+    wget --continue http://kaldi-asr.org/models/5/4gram_small.arpa.gz -P $dl_dir/lm
+    wget --continue http://kaldi-asr.org/models/5/4gram_big.arpa.gz -P $dl_dir/lm
+    gzip -d $dl_dir/lm/4gram_small.arpa.gz $dl_dir/lm/4gram_big.arpa.gz
+  fi
+
   # If you have pre-downloaded it to /path/to/musan,
   # you can create a symlink
   #
@@ -100,7 +106,14 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   if [ ! -e data/fbank/.tedlium3.done ]; then
     mkdir -p data/fbank
+
     python3 ./local/compute_fbank_tedlium.py
+
+    gunzip -c data/fbank/tedlium_cuts_train.jsonl.gz | shuf | \
+      gzip -c > data/fbank/tedlium_cuts_train-shuf.jsonl.gz
+    mv data/fbank/tedlium_cuts_train-shuf.jsonl.gz \
+      data/fbank/tedlium_cuts_train.jsonl.gz
+
     touch data/fbank/.tedlium3.done
   fi
 fi
@@ -115,28 +128,24 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
 fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
-  log "Stage 5: Prepare phone based lang"
-  lang_dir=data/lang_phone
+  log "Stage 5: Prepare BPE train data and set of words"
+  lang_dir=data/lang
   mkdir -p $lang_dir
-  if [ ! -f $lang_dir/train.text ]; then
+  if [ ! -f $lang_dir/train.txt ]; then
+    gunzip -c $dl_dir/tedlium3/LM/*.en.gz | sed 's: <\/s>::g' > $lang_dir/train_orig.txt
+
     ./local/prepare_transcripts.py \
-      --lang-dir $lang_dir \
-      --manifests-dir data/manifests
+      --input-text-path $lang_dir/train_orig.txt \
+      --output-text-path $lang_dir/train.txt
   fi
-  if [ ! -f $lang_dir/lexicon_words.txt ]; then
-    ./local/prepare_lexicon.py \
-      --lang-dir $lang_dir \
-      --manifests-dir data/manifests
-  fi
+
   if [ ! -f $lang_dir/words.txt ]; then
-    (echo '!SIL SIL'; echo '<unk> <unk>'; ) |
-      cat - $lang_dir/lexicon_words.txt |
-      sort | uniq > $lang_dir/lexicon.txt
+    awk '{print $1}' $dl_dir/tedlium3/TEDLIUM.152k.dic |
+      sed 's:([0-9])::g' | sort | uniq > $lang_dir/words_orig.txt
 
-    if [ ! -f $lang_dir/L_disambig.pt ]; then
-      ./local/prepare_lang.py --lang-dir $lang_dir
+    ./local/prepare_words.py --lang-dir $lang_dir
   fi
 fi
 
@@ -148,25 +157,56 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
   mkdir -p $lang_dir
   # We reuse words.txt from phone based lexicon
   # so that the two can share G.pt later.
-  cp data/lang_phone/words.txt $lang_dir
-
-  if [ ! -f $lang_dir/transcript_words.txt ]; then
-    log "Generate data for BPE training"
-    cat data/lang_phone/train.text |
-      cut -d " " -f 2- > $lang_dir/transcript_words.txt
-    # remove the <unk> for transcript_words.txt
-    sed -i 's/ <unk>//g' $lang_dir/transcript_words.txt
-    sed -i 's/<unk> //g' $lang_dir/transcript_words.txt
-    sed -i 's/<unk>//g' $lang_dir/transcript_words.txt
-  fi
+  cp data/lang/words.txt $lang_dir
 
   ./local/train_bpe_model.py \
     --lang-dir $lang_dir \
     --vocab-size $vocab_size \
-    --transcript $lang_dir/transcript_words.txt
+    --transcript data/lang/train.txt
 
   if [ ! -f $lang_dir/L_disambig.pt ]; then
-    ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+    ./local/prepare_lang_bpe.py --lang-dir $lang_dir --oov "<unk>"
+  fi
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare G"
+  # We assume you have installed kaldilm; if not, please install
+  # it using: pip install kaldilm
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/G_4_gram_small.fst.txt ]; then
+    # It is used in building HLG
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      --max-arpa-warnings=-1 \
+      $dl_dir/lm/4gram_small.arpa > data/lm/G_4_gram_small.fst.txt
+  fi
+
+  if [ ! -f data/lm/G_4_gram_big.fst.txt ]; then
+    # It is used for LM rescoring
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      --max-arpa-warnings=-1 \
+      $dl_dir/lm/4gram_big.arpa > data/lm/G_4_gram_big.fst.txt
+  fi
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Compile HLG"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    if [ ! -f $lang_dir/HLG.pt ]; then
+      ./local/compile_hlg.py \
+        --lang-dir $lang_dir \
+        --lm G_4_gram_small
     fi
   done
 fi
diff --git a/icefall/decode.py b/icefall/decode.py
index 68e490c5e..23f9fb9b3 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -466,9 +466,7 @@ def one_best_decoding(
     Return:
       An FsaVec containing linear paths.
     """
-    if lm_scale_list is not None:
-        ans = dict()
     saved_am_scores = lattice.scores - lattice.lm_scores
     for lm_scale in lm_scale_list:
diff --git a/test/test_lexicon.py b/test/test_lexicon.py
index 69867efc7..b1beab3f6 100755
--- a/test/test_lexicon.py
+++ b/test/test_lexicon.py
@@ -112,7 +112,7 @@ def uniq_lexicon_test():
     # But there is no word "ca" in the lexicon, so our
     # implementation returns the id of "<unk>"
     print(token_ids, expected_token_ids)
-    assert token_ids.tolist() == [[sp.unk_id()]]
+    assert token_ids.tolist() == [[sp.piece_to_id("▁"), sp.unk_id()]]
 
     # case 3: With OOV
     texts = ["foo"]
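# --- Illustrative sketch, not part of the patch above: what the reworked
# convert_texts_into_ids() in local/convert_transcript_words_to_bpe_ids.py does with
# "<unk>", and why the lexicon test now expects [[sp.piece_to_id("▁"), sp.unk_id()]].
# Each "<unk>" in the transcript is replaced by the id of the SentencePiece
# word-boundary piece "▁" followed by unk_id. The bpe.model path is hypothetical.
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load("data/lang_bpe_500/bpe.model")  # hypothetical path to a trained BPE model

text = "this is a <unk> day"
segments = sp.encode(text.split("<unk>"), out_type=int)  # encode the pieces around <unk>
ids = []
for i, seg in enumerate(segments):
    ids += seg
    if i < len(segments) - 1:
        ids += [sp.piece_to_id("▁"), sp.unk_id()]
print(ids)  # e.g. [38, 33, 6, 15, 2, 316] with the BPE model used in the docstring example
# --- end of sketch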