From 53454701cb69ec23be8a37c6ab69f1cf5104585d Mon Sep 17 00:00:00 2001 From: marcoyang Date: Tue, 22 Nov 2022 11:39:21 +0800 Subject: [PATCH 01/27] fix segmentation fault --- egs/aidatatang_200zh/ASR/prepare.sh | 3 +++ egs/aishell/ASR/prepare.sh | 3 +++ egs/aishell2/ASR/prepare.sh | 3 +++ egs/aishell4/ASR/prepare.sh | 3 +++ egs/alimeeting/ASR/prepare.sh | 3 +++ egs/csj/ASR/prepare.sh | 3 +++ egs/gigaspeech/ASR/prepare.sh | 3 +++ egs/librispeech/ASR/prepare.sh | 3 +++ egs/ptb/LM/prepare.sh | 3 +++ egs/spgispeech/ASR/prepare.sh | 3 +++ egs/tal_csasr/ASR/prepare.sh | 3 +++ egs/tedlium3/ASR/prepare.sh | 3 +++ egs/timit/ASR/prepare.sh | 3 +++ egs/wenetspeech/ASR/prepare.sh | 3 +++ egs/yesno/ASR/prepare.sh | 3 +++ 15 files changed, 45 insertions(+) diff --git a/egs/aidatatang_200zh/ASR/prepare.sh b/egs/aidatatang_200zh/ASR/prepare.sh index 4749e1b7f..46ecd5769 100755 --- a/egs/aidatatang_200zh/ASR/prepare.sh +++ b/egs/aidatatang_200zh/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail stage=-1 diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index eaeecfc4a..5917668a1 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=15 diff --git a/egs/aishell2/ASR/prepare.sh b/egs/aishell2/ASR/prepare.sh index 06810bfdd..3e8e840ab 100755 --- a/egs/aishell2/ASR/prepare.sh +++ b/egs/aishell2/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=30 diff --git a/egs/aishell4/ASR/prepare.sh b/egs/aishell4/ASR/prepare.sh index c351e3964..cb2b73a3e 100755 --- a/egs/aishell4/ASR/prepare.sh +++ b/egs/aishell4/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail stage=-1 diff --git a/egs/alimeeting/ASR/prepare.sh b/egs/alimeeting/ASR/prepare.sh index 17224bb68..604cc92c6 100755 --- a/egs/alimeeting/ASR/prepare.sh +++ b/egs/alimeeting/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail stage=-1 diff --git a/egs/csj/ASR/prepare.sh b/egs/csj/ASR/prepare.sh index 052748ca6..c4ce91984 100755 --- a/egs/csj/ASR/prepare.sh +++ b/egs/csj/ASR/prepare.sh @@ -35,6 +35,9 @@ # can generate other transcript formats by supplying your own config files. A few examples of these # config files can be found in local/conf. 
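# The export added below forces the protobuf Python package to use its pure-Python backend
# instead of the compiled extension, which is how the segmentation fault from icefall
# issue #674 is worked around here. A quick way to check which backend is active
# (this relies on an internal protobuf module, so treat it as a sketch rather than a
# stable API):
#
#   python -c "from google.protobuf.internal import api_implementation; print(api_implementation.Type())"
#
# It should print "python" once PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python is exported.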
+# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=8 diff --git a/egs/gigaspeech/ASR/prepare.sh b/egs/gigaspeech/ASR/prepare.sh index fd2532741..bd255dc6a 100755 --- a/egs/gigaspeech/ASR/prepare.sh +++ b/egs/gigaspeech/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=15 diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 94e003036..8668af0e4 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=15 diff --git a/egs/ptb/LM/prepare.sh b/egs/ptb/LM/prepare.sh index 70586785d..91c3c667a 100755 --- a/egs/ptb/LM/prepare.sh +++ b/egs/ptb/LM/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=15 diff --git a/egs/spgispeech/ASR/prepare.sh b/egs/spgispeech/ASR/prepare.sh index 231ebd742..4842f52d0 100755 --- a/egs/spgispeech/ASR/prepare.sh +++ b/egs/spgispeech/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=20 diff --git a/egs/tal_csasr/ASR/prepare.sh b/egs/tal_csasr/ASR/prepare.sh index 340521ad8..d9938fa63 100755 --- a/egs/tal_csasr/ASR/prepare.sh +++ b/egs/tal_csasr/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail stage=-1 diff --git a/egs/tedlium3/ASR/prepare.sh b/egs/tedlium3/ASR/prepare.sh index ccb307a52..272cf7aed 100755 --- a/egs/tedlium3/ASR/prepare.sh +++ b/egs/tedlium3/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=15 diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh index d11cd3a05..148a9f51b 100644 --- a/egs/timit/ASR/prepare.sh +++ b/egs/timit/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail num_phones=39 diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index da7d7e061..50a00253d 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail nj=15 diff --git a/egs/yesno/ASR/prepare.sh b/egs/yesno/ASR/prepare.sh index 8fcee0290..d4ef8d601 100755 --- a/egs/yesno/ASR/prepare.sh +++ b/egs/yesno/ASR/prepare.sh @@ -1,5 +1,8 @@ #!/usr/bin/env bash +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + set -eou pipefail stage=-1 From 
4c636c2cfffd853a4dc1f618dd8a6fede78a3bea Mon Sep 17 00:00:00 2001 From: Senyan Li <1149593720@qq.com> Date: Fri, 25 Nov 2022 14:39:56 +0800 Subject: [PATCH 02/27] fix librispeech ASR pruned_transducer_stateless5 export (#704) --- egs/librispeech/ASR/pruned_transducer_stateless5/export.py | 2 ++ egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py | 1 + .../ASR/pruned_transducer_stateless5/scaling_converter.py | 1 + 3 files changed, 4 insertions(+) create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index a4fad1e59..54f656859 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -50,6 +50,7 @@ from pathlib import Path import sentencepiece as spm import torch +from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import ( @@ -263,6 +264,7 @@ def main(): # it here. # Otherwise, one of its arguments is a ragged tensor and is not # torch scriptabe. + convert_scaled_to_non_scaled(model, inplace=True) model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py b/egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py new file mode 120000 index 000000000..4f377cd01 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file From 89c3982a0760f135740556ae67c11d0af434303c Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 26 Nov 2022 00:50:21 +0800 Subject: [PATCH 03/27] show dominant parameters --- .../ASR/pruned_transducer_stateless7/optim.py | 79 ++++++++++++++++--- .../ASR/pruned_transducer_stateless7/train.py | 13 ++- 2 files changed, 79 insertions(+), 13 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 8b90c9a0d..ab55381d7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -42,7 +42,7 @@ class BatchedOptimizer(Optimizer): super(BatchedOptimizer, self).__init__(params, defaults) @contextlib.contextmanager - def batched_params(self, param_group): + def batched_params(self, param_group, group_params_names=None): """ This function returns (technically, yields) a list of of tuples (p, state), where @@ -75,20 +75,28 @@ class BatchedOptimizer(Optimizer): batches = defaultdict( list ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches_names = defaultdict( + list + ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str - for p in param_group: + for p, named_p in zip(param_group, group_params_names): key = 
(str(p.dtype), *p.shape) batches[key].append(p) + batches_names[key].append(named_p) + + batches_names_keys = list(batches_names.keys()) + sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i]) + batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] + batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [batches[key] for key in sorted(batches.keys())] # pairs will contain pairs of (stacked_param, state), one for each batch # in `batches`. pairs = [] - for batch in batches: + for batch, batch_names in zip(batches, batches_names): p = batch[0] # we arbitrarily store the state in the # state corresponding to the 1st parameter in the @@ -100,11 +108,11 @@ class BatchedOptimizer(Optimizer): ) p_stacked.grad = grad stacked_params_dict[key] = p_stacked - pairs.append((p_stacked, state)) + pairs.append((p_stacked, state, batch_names)) yield pairs # <-- calling code will do the actual optimization here! - for ((stacked_params, _state), batch) in zip(pairs, batches): + for ((stacked_params, _state, _names), batch) in zip(pairs, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -165,6 +173,8 @@ class ScaledAdam(BatchedOptimizer): scalar_max=10.0, size_update_period=4, clipping_update_period=100, + parameters_names=None, + show_dominant_parameters=False, ): defaults = dict( @@ -181,6 +191,8 @@ class ScaledAdam(BatchedOptimizer): ) super(ScaledAdam, self).__init__(params, defaults) + self.parameters_names = parameters_names + self.show_dominant_parameters = show_dominant_parameters def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) @@ -199,9 +211,11 @@ class ScaledAdam(BatchedOptimizer): loss = closure() batch = True - for group in self.param_groups: + assert len(self.param_groups) == len(self.parameters_names) - with self.batched_params(group["params"]) as batches: + for group, group_params_names in zip(self.param_groups, self.parameters_names): + + with self.batched_params(group["params"], group_params_names) as batches: # batches is list of pairs (stacked_param, state). stacked_param is like # a regular parameter, and will have a .grad, but the 1st dim corresponds to @@ -214,7 +228,7 @@ class ScaledAdam(BatchedOptimizer): else: clipping_scale = self._get_clipping_scale(group, batches) - for p, state in batches: + for p, state, _ in batches: # Perform optimization step. # grad is not going to be None, we handled that when creating the batches. grad = p.grad @@ -276,7 +290,7 @@ class ScaledAdam(BatchedOptimizer): state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) def _get_clipping_scale( - self, group: dict, pairs: List[Tuple[Tensor, dict]] + self, group: dict, pairs: List[Tuple[Tensor, dict, List[str]]] ) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients @@ -289,7 +303,7 @@ class ScaledAdam(BatchedOptimizer): """ assert len(pairs) >= 1 clipping_scale = group["clipping_scale"] - (first_p, first_state) = pairs[0] + (first_p, first_state, _) = pairs[0] step = first_state["step"] if clipping_scale is None or step == 0: # no clipping. 
return early on step == 0 because the other @@ -298,7 +312,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state) in pairs: + for (p, state, param_names) in pairs: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -361,8 +375,49 @@ class ScaledAdam(BatchedOptimizer): logging.warn( f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" ) + if self.show_dominant_parameters: + assert p.shape[0] == len(param_names) + self._show_gradient_dominating_parameter(pairs, tot_sumsq) return ans + def _show_gradient_dominating_parameter(self, pairs, tot_sumsq): + # ori means calculated with state["param_rms"] + # cur means calculated with "param_rms" of current param. + # bt is short batch + # all_sumsq_ori_rms + all_sumsq_ori = {} + all_sumsq_cur = {} + for (p, state, batch_param_names) in pairs: + # p is a stacked batch parameters. + grad = p.grad + if p.numel() == p.shape[0]: # a batch of scalars + batch_sumsq_ori = grad**2 # sum() to change shape [1] to [] + batch_sumsq_cur = batch_sumsq_ori # sum() to change shape [1] to [] + # Dummpy values used by following `zip` statement. + batch_rms_ori = torch.zeros(p.shape[0]) + batch_rms_cur = batch_rms_ori + else: + batch_rms_ori = state["param_rms"] + batch_sumsq_ori = ((grad * batch_rms_ori) ** 2).sum(dim=list(range(1, grad.ndim))) + + batch_rms_cur = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + batch_sumsq_cur = ((grad * batch_rms_cur) ** 2).sum(dim=list(range(1, grad.ndim))) + + for name, sumsq_ori, sumsq_cur in zip( + batch_param_names, batch_sumsq_ori, batch_sumsq_cur): + + proportion_ori = sumsq_ori / tot_sumsq + proportion_cur = sumsq_cur / tot_sumsq + + all_sumsq_ori[name] = (proportion_ori, sumsq_ori) + all_sumsq_cur[name] = (proportion_cur, sumsq_cur) + + for rms_type, all_sumsq in zip(("ori", "cur"), (all_sumsq_ori, all_sumsq_cur)): + sorted_by_proportion = {k: v for k, v in sorted(all_sumsq.items(), key=lambda item: item[1][0], reverse=True)} + dominant_param_name = next(iter(sorted_by_proportion)) + dominant_proportion, dominant_sumsq = sorted_by_proportion[dominant_param_name] + logging.info(f"Dominant sumsq with {rms_type}_rms: {dominant_param_name} {dominant_proportion} {dominant_sumsq} {tot_sumsq}") + def _step_one_batch( self, group: dict, p: Tensor, state: dict, clipping_scale: float ): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index b27c573ab..8375b1a18 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -368,6 +368,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--show-dominant-parameters", + type=str2bool, + default=False, + help="Whether to show dominant parameters.", + ) + add_model_arguments(parser) return parser @@ -988,7 +995,11 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + parameters_names = [] + parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()]) + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, + clipping_scale=2.0, parameters_names=parameters_names, + show_dominant_parameters=params.show_dominant_parameters) scheduler = 
Eden(optimizer, params.lr_batches, params.lr_epochs) From db75627e92155c16fd6e74d640ece4f6563f96f2 Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Fri, 25 Nov 2022 21:00:45 -0500 Subject: [PATCH 04/27] [recipe] AMI Zipformer transducer (#698) * remove unnecessary changes * add AMI prepare scripts * add zipformer scripts for AMI * added logs and pretrained model * minor fix * remove unwanted changes * fix missing link * make suggested changes * update results --- egs/ami/ASR/README.md | 48 + egs/ami/ASR/RESULTS.md | 92 ++ egs/ami/ASR/local/__init__.py | 0 egs/ami/ASR/local/compute_fbank_ami.py | 194 +++ egs/ami/ASR/local/compute_fbank_musan.py | 114 ++ egs/ami/ASR/local/prepare_ami_enhanced.py | 158 +++ egs/ami/ASR/local/prepare_ami_gss.sh | 98 ++ egs/ami/ASR/local/prepare_lang_bpe.py | 1 + egs/ami/ASR/local/train_bpe_model.py | 1 + egs/ami/ASR/prepare.sh | 144 ++ .../pruned_transducer_stateless7/__init__.py | 0 .../asr_datamodule.py | 430 ++++++ .../beam_search.py | 1 + .../pruned_transducer_stateless7/decode.py | 747 +++++++++++ .../pruned_transducer_stateless7/decoder.py | 1 + .../encoder_interface.py | 1 + .../pruned_transducer_stateless7/export.py | 1 + .../pruned_transducer_stateless7/joiner.py | 1 + .../ASR/pruned_transducer_stateless7/model.py | 1 + .../ASR/pruned_transducer_stateless7/optim.py | 1 + .../pruned_transducer_stateless7/scaling.py | 1 + .../scaling_converter.py | 1 + .../ASR/pruned_transducer_stateless7/train.py | 1184 +++++++++++++++++ .../pruned_transducer_stateless7/zipformer.py | 1 + egs/ami/ASR/shared | 1 + 25 files changed, 3222 insertions(+) create mode 100644 egs/ami/ASR/README.md create mode 100644 egs/ami/ASR/RESULTS.md create mode 100644 egs/ami/ASR/local/__init__.py create mode 100755 egs/ami/ASR/local/compute_fbank_ami.py create mode 100755 egs/ami/ASR/local/compute_fbank_musan.py create mode 100644 egs/ami/ASR/local/prepare_ami_enhanced.py create mode 100755 egs/ami/ASR/local/prepare_ami_gss.sh create mode 120000 egs/ami/ASR/local/prepare_lang_bpe.py create mode 120000 egs/ami/ASR/local/train_bpe_model.py create mode 100755 egs/ami/ASR/prepare.sh create mode 100644 egs/ami/ASR/pruned_transducer_stateless7/__init__.py create mode 100644 egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/beam_search.py create mode 100755 egs/ami/ASR/pruned_transducer_stateless7/decode.py create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/decoder.py create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/export.py create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/joiner.py create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/model.py create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/optim.py create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/scaling.py create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py create mode 100755 egs/ami/ASR/pruned_transducer_stateless7/train.py create mode 120000 egs/ami/ASR/pruned_transducer_stateless7/zipformer.py create mode 120000 egs/ami/ASR/shared diff --git a/egs/ami/ASR/README.md b/egs/ami/ASR/README.md new file mode 100644 index 000000000..1c9714bd4 --- /dev/null +++ b/egs/ami/ASR/README.md @@ -0,0 +1,48 @@ +# AMI + +This is an ASR recipe for the AMI corpus. AMI provides recordings from the speaker's +headset and lapel microphones, and also 2 array microphones containing 8 channels each. 
+We pool data in the following 4 ways and train a single model on the pooled data: + +(i) individual headset microphone (IHM) +(ii) IHM with simulated reverb +(iii) Single distant microphone (SDM) +(iv) GSS-enhanced array microphones + +Speed perturbation and MUSAN noise augmentation are additionally performed on the pooled +data. Here are the statistics of the combined training data: + +```python +>>> cuts_train.describe() +Cuts count: 1222053 +Total duration (hh:mm:ss): 905:00:28 +Speech duration (hh:mm:ss): 905:00:28 (99.9%) +Duration statistics (seconds): +mean 2.7 +std 2.8 +min 0.0 +25% 0.6 +50% 1.6 +75% 3.8 +99% 12.3 +99.5% 13.9 +99.9% 18.4 +max 36.8 +``` + +**Note:** This recipe additionally uses [GSS](https://github.com/desh2608/gss) for enhancement +of far-field array microphones, but this is optional (see `prepare.sh` for details). + +## Performance Record + +### pruned_transducer_stateless7 + +The following are decoded using `modified_beam_search`: + +| Evaluation set | dev WER | test WER | +|--------------------------|------------|---------| +| IHM | 18.92 | 17.40 | +| SDM | 31.25 | 32.21 | +| MDM (GSS-enhanced) | 21.67 | 22.43 | + +See [RESULTS](/egs/ami/ASR/RESULTS.md) for details. diff --git a/egs/ami/ASR/RESULTS.md b/egs/ami/ASR/RESULTS.md new file mode 100644 index 000000000..163986021 --- /dev/null +++ b/egs/ami/ASR/RESULTS.md @@ -0,0 +1,92 @@ +## Results + +### AMI training results (Pruned Transducer) + +#### 2022-11-20 + +#### Zipformer (pruned_transducer_stateless7) + +Zipformer encoder + non-current decoder. The decoder +contains only an embedding layer, a Conv1d (with kernel size 2) and a linear +layer (to transform tensor dim). + +All the results below are using a single model that is trained by combining the following +data: IHM, IHM+reverb, SDM, and GSS-enhanced MDM. Speed perturbation and MUSAN noise +augmentation are applied on top of the pooled data. 
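A rough sketch (not the recipe's actual training code) of how the pooled cut set is put
together with lhotse; the real steps live in `prepare.sh`, `local/compute_fbank_ami.py`,
and `pruned_transducer_stateless7/asr_datamodule.py`:

```python
from lhotse import combine, load_manifest, load_manifest_lazy
from lhotse.dataset import CutMix

ihm = load_manifest_lazy("data/manifests/cuts_train_ihm.jsonl.gz")
ihm_rvb = load_manifest_lazy("data/manifests/cuts_train_ihm_rvb.jsonl.gz")
sdm = load_manifest_lazy("data/manifests/cuts_train_sdm.jsonl.gz")
gss = load_manifest_lazy("data/manifests/cuts_train_gss.jsonl.gz")

# Pool the four conditions; speed perturbation (0.9x/1.1x) was already applied
# when the fbank features were computed in local/compute_fbank_ami.py.
train_cuts = combine([ihm, ihm_rvb, sdm, gss])

# MUSAN noise is not baked into the features; the data module mixes it in on the fly
# via a CutMix transform passed to K2SpeechRecognitionDataset.
musan = load_manifest("data/manifests/musan_cuts.jsonl.gz")
noise_transform = CutMix(cuts=musan, prob=0.5, snr=(10, 20), preserve_id=True)
```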
+ +**WERs for IHM:** + +| | dev | test | comment | +|---------------------------|------------|------------|------------------------------------------| +| greedy search | 19.25 | 17.83 | --epoch 14 --avg 8 --max-duration 500 | +| modified beam search | 18.92 | 17.40 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 | +| fast beam search | 19.44 | 18.04 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | + +**WERs for SDM:** + +| | dev | test | comment | +|---------------------------|------------|------------|------------------------------------------| +| greedy search | 31.32 | 32.38 | --epoch 14 --avg 8 --max-duration 500 | +| modified beam search | 31.25 | 32.21 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 | +| fast beam search | 31.11 | 32.10 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | + +**WERs for GSS-enhanced MDM:** + +| | dev | test | comment | +|---------------------------|------------|------------|------------------------------------------| +| greedy search | 22.05 | 22.93 | --epoch 14 --avg 8 --max-duration 500 | +| modified beam search | 21.67 | 22.43 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 | +| fast beam search | 22.21 | 22.83 | --epoch 14 --avg 8 --max-duration 500 --beam-size 4 --max-contexts 4 --max-states 8 | + +The training command for reproducing is given below: + +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 15 \ + --exp-dir pruned_transducer_stateless7/exp \ + --max-duration 150 \ + --max-cuts 150 \ + --prune-range 5 \ + --lr-factor 5 \ + --lm-scale 0.25 \ + --use-fp16 True +``` + +The decoding command is: +``` +# greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 14 \ + --avg 8 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method greedy_search + +# modified beam search +./pruned_transducer_stateless7/decode.py \ + --iter 105000 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +# fast beam search +./pruned_transducer_stateless7/decode.py \ + --iter 105000 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +``` + +Pretrained model is available at + +The tensorboard training log can be found at + diff --git a/egs/ami/ASR/local/__init__.py b/egs/ami/ASR/local/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/ami/ASR/local/compute_fbank_ami.py b/egs/ami/ASR/local/compute_fbank_ami.py new file mode 100755 index 000000000..4892b40e3 --- /dev/null +++ b/egs/ami/ASR/local/compute_fbank_ami.py @@ -0,0 +1,194 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (authors: Desh Raj) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the AMI dataset. +For the training data, we pool together IHM, reverberated IHM, and GSS-enhanced +audios. For the test data, we separately prepare IHM, SDM, and GSS-enhanced +parts (which are the 3 evaluation settings). +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" +import logging +import math +from pathlib import Path + +import torch +import torch.multiprocessing +from lhotse import CutSet, LilcomChunkyWriter +from lhotse.features.kaldifeat import ( + KaldifeatFbank, + KaldifeatFbankConfig, + KaldifeatFrameOptions, + KaldifeatMelOptions, +) +from lhotse.recipes.utils import read_manifests_if_cached + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") + + +def compute_fbank_ami(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + sampling_rate = 16000 + num_mel_bins = 80 + + extractor = KaldifeatFbank( + KaldifeatFbankConfig( + frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), + mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), + device="cuda", + ) + ) + + logging.info("Reading manifests") + manifests_ihm = read_manifests_if_cached( + dataset_parts=["train", "dev", "test"], + output_dir=src_dir, + prefix="ami-ihm", + suffix="jsonl.gz", + ) + manifests_sdm = read_manifests_if_cached( + dataset_parts=["train", "dev", "test"], + output_dir=src_dir, + prefix="ami-sdm", + suffix="jsonl.gz", + ) + # For GSS we already have cuts so we read them directly. 
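    # The "ami-gss" manifests are written by local/prepare_ami_enhanced.py (invoked from
    # local/prepare_ami_gss.sh in the optional GSS stage of prepare.sh), so they exist
    # only when the GSS enhancement has been run.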
+ manifests_gss = read_manifests_if_cached( + dataset_parts=["train", "dev", "test"], + output_dir=src_dir, + prefix="ami-gss", + suffix="jsonl.gz", + ) + + def _extract_feats(cuts: CutSet, storage_path: Path, manifest_path: Path) -> None: + cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1) + _ = cuts.compute_and_store_features_batch( + extractor=extractor, + storage_path=storage_path, + manifest_path=manifest_path, + batch_duration=5000, + num_workers=8, + storage_type=LilcomChunkyWriter, + ) + + logging.info( + "Preparing training cuts: IHM + reverberated IHM + SDM + GSS (optional)" + ) + + logging.info("Processing train split IHM") + cuts_ihm = ( + CutSet.from_manifests(**manifests_ihm["train"]) + .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) + .modify_ids(lambda x: x + "-ihm") + ) + _extract_feats( + cuts_ihm, + output_dir / "feats_train_ihm", + src_dir / "cuts_train_ihm.jsonl.gz", + ) + + logging.info("Processing train split IHM + reverberated IHM") + cuts_ihm_rvb = cuts_ihm.reverb_rir() + _extract_feats( + cuts_ihm_rvb, + output_dir / "feats_train_ihm_rvb", + src_dir / "cuts_train_ihm_rvb.jsonl.gz", + ) + + logging.info("Processing train split SDM") + cuts_sdm = ( + CutSet.from_manifests(**manifests_sdm["train"]) + .trim_to_supervisions(keep_overlapping=False) + .modify_ids(lambda x: x + "-sdm") + ) + _extract_feats( + cuts_sdm, + output_dir / "feats_train_sdm", + src_dir / "cuts_train_sdm.jsonl.gz", + ) + + logging.info("Processing train split GSS") + cuts_gss = ( + CutSet.from_manifests(**manifests_gss["train"]) + .trim_to_supervisions(keep_overlapping=False) + .modify_ids(lambda x: x + "-gss") + ) + _extract_feats( + cuts_gss, + output_dir / "feats_train_gss", + src_dir / "cuts_train_gss.jsonl.gz", + ) + + logging.info("Preparing test cuts: IHM, SDM, GSS (optional)") + for split in ["dev", "test"]: + logging.info(f"Processing {split} IHM") + cuts_ihm = ( + CutSet.from_manifests(**manifests_ihm[split]) + .trim_to_supervisions(keep_overlapping=False, keep_all_channels=False) + .compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / f"feats_{split}_ihm", + manifest_path=src_dir / f"cuts_{split}_ihm.jsonl.gz", + batch_duration=5000, + num_workers=8, + storage_type=LilcomChunkyWriter, + ) + ) + logging.info(f"Processing {split} SDM") + cuts_sdm = ( + CutSet.from_manifests(**manifests_sdm[split]) + .trim_to_supervisions(keep_overlapping=False) + .compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / f"feats_{split}_sdm", + manifest_path=src_dir / f"cuts_{split}_sdm.jsonl.gz", + batch_duration=500, + num_workers=4, + storage_type=LilcomChunkyWriter, + ) + ) + logging.info(f"Processing {split} GSS") + cuts_gss = ( + CutSet.from_manifests(**manifests_gss[split]) + .trim_to_supervisions(keep_overlapping=False) + .compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / f"feats_{split}_gss", + manifest_path=src_dir / f"cuts_{split}_gss.jsonl.gz", + batch_duration=500, + num_workers=4, + storage_type=LilcomChunkyWriter, + ) + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_ami() diff --git a/egs/ami/ASR/local/compute_fbank_musan.py b/egs/ami/ASR/local/compute_fbank_musan.py new file mode 100755 index 000000000..1fcf951f9 --- /dev/null +++ b/egs/ami/ASR/local/compute_fbank_musan.py @@ -0,0 +1,114 @@ +#!/usr/bin/env python3 +# 
Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the musan dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import logging +from pathlib import Path + +import torch +from lhotse import CutSet, LilcomChunkyWriter, combine +from lhotse.features.kaldifeat import ( + KaldifeatFbank, + KaldifeatFbankConfig, + KaldifeatFrameOptions, + KaldifeatMelOptions, +) +from lhotse.recipes.utils import read_manifests_if_cached + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_musan(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + sampling_rate = 16000 + num_mel_bins = 80 + + dataset_parts = ( + "music", + "speech", + "noise", + ) + prefix = "musan" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + musan_cuts_path = src_dir / "musan_cuts.jsonl.gz" + + if musan_cuts_path.is_file(): + logging.info(f"{musan_cuts_path} already exists - skipping") + return + + logging.info("Extracting features for Musan") + + extractor = KaldifeatFbank( + KaldifeatFbankConfig( + frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), + mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), + device="cuda", + ) + ) + + # create chunks of Musan with duration 5 - 10 seconds + _ = ( + CutSet.from_manifests( + recordings=combine(part["recordings"] for part in manifests.values()) + ) + .cut_into_windows(10.0) + .filter(lambda c: c.duration > 5) + .compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir / "musan_feats", + manifest_path=musan_cuts_path, + batch_duration=500, + num_workers=4, + storage_type=LilcomChunkyWriter, + ) + ) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_fbank_musan() diff --git a/egs/ami/ASR/local/prepare_ami_enhanced.py b/egs/ami/ASR/local/prepare_ami_enhanced.py new file mode 100644 index 000000000..bed220eb3 --- /dev/null +++ b/egs/ami/ASR/local/prepare_ami_enhanced.py @@ -0,0 +1,158 @@ +#!/usr/local/bin/python +# -*- coding: utf-8 -*- +# Data preparation for AMI GSS-enhanced dataset. 
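# This script expects the directory layout written by the GSS enhancement stage: one FLAC
# file per supervision segment, stored as
#   <enhanced_dir>/<recording_id>/<recording_id>-<speaker>-<start>_<end>.flac
# with <start>/<end> given in centiseconds. Each original ami-sdm supervision is matched to
# its enhanced segment, and new ami-gss_recordings_*.jsonl.gz / ami-gss_supervisions_*.jsonl.gz
# manifests are written next to the existing AMI manifests.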
+ +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +from lhotse import Recording, RecordingSet, SupervisionSet +from lhotse.qa import fix_manifests +from lhotse.recipes.utils import read_manifests_if_cached +from lhotse.utils import fastcopy +from tqdm import tqdm + +logging.basicConfig( + format="%(asctime)s %(levelname)-8s %(message)s", + level=logging.INFO, + datefmt="%Y-%m-%d %H:%M:%S", +) + + +def get_args(): + import argparse + + parser = argparse.ArgumentParser(description="AMI enhanced dataset preparation.") + parser.add_argument( + "manifests_dir", + type=Path, + help="Path to directory containing AMI manifests.", + ) + parser.add_argument( + "enhanced_dir", + type=Path, + help="Path to enhanced data directory.", + ) + parser.add_argument( + "--num-jobs", + "-j", + type=int, + default=1, + help="Number of parallel jobs to run.", + ) + parser.add_argument( + "--min-segment-duration", + "-d", + type=float, + default=0.0, + help="Minimum duration of a segment in seconds.", + ) + return parser.parse_args() + + +def find_recording_and_create_new_supervision(enhanced_dir, supervision): + """ + Given a supervision (corresponding to original AMI recording), this function finds the + enhanced recording correspoding to the supervision, and returns this recording and + a new supervision whose start and end times are adjusted to match the enhanced recording. + """ + file_name = Path( + f"{supervision.recording_id}-{supervision.speaker}-{int(100*supervision.start):06d}_{int(100*supervision.end):06d}.flac" + ) + save_path = enhanced_dir / f"{supervision.recording_id}" / file_name + if save_path.exists(): + recording = Recording.from_file(save_path) + if recording.duration == 0: + logging.warning(f"Skipping {save_path} which has duration 0 seconds.") + return None + + # Old supervision is wrt to the original recording, we create new supervision + # wrt to the enhanced segment + new_supervision = fastcopy( + supervision, + recording_id=recording.id, + start=0, + duration=recording.duration, + ) + return recording, new_supervision + else: + logging.warning(f"{save_path} does not exist.") + return None + + +def main(args): + # Get arguments + manifests_dir = args.manifests_dir + enhanced_dir = args.enhanced_dir + + # Load manifests from cache if they exist (saves time) + manifests = read_manifests_if_cached( + dataset_parts=["train", "dev", "test"], + output_dir=manifests_dir, + prefix="ami-sdm", + suffix="jsonl.gz", + ) + if not manifests: + raise ValueError("AMI SDM manifests not found in {}".format(manifests_dir)) + + with ThreadPoolExecutor(args.num_jobs) as ex: + for part in ["train", "dev", "test"]: + logging.info(f"Processing {part}...") + supervisions_orig = manifests[part]["supervisions"].filter( + lambda s: s.duration >= args.min_segment_duration + ) + # Remove TS3009d supervisions since they are not present in the enhanced data + supervisions_orig = supervisions_orig.filter( + lambda s: s.recording_id != "TS3009d" + ) + futures = [] + + for supervision in tqdm( + supervisions_orig, + desc="Distributing tasks", + ): + futures.append( + ex.submit( + find_recording_and_create_new_supervision, + enhanced_dir, + supervision, + ) + ) + + recordings = [] + supervisions = [] + for future in tqdm( + futures, + total=len(futures), + desc="Processing tasks", + ): + result = future.result() + if result is not None: + recording, new_supervision = result + recordings.append(recording) + supervisions.append(new_supervision) + + # Remove duplicates from the 
recordings + recordings_nodup = {} + for recording in recordings: + if recording.id not in recordings_nodup: + recordings_nodup[recording.id] = recording + else: + logging.warning("Recording {} is duplicated.".format(recording.id)) + recordings = RecordingSet.from_recordings(recordings_nodup.values()) + supervisions = SupervisionSet.from_segments(supervisions) + + recordings, supervisions = fix_manifests( + recordings=recordings, supervisions=supervisions + ) + + logging.info(f"Writing {part} enhanced manifests") + recordings.to_file(manifests_dir / f"ami-gss_recordings_{part}.jsonl.gz") + supervisions.to_file( + manifests_dir / f"ami-gss_supervisions_{part}.jsonl.gz" + ) + + +if __name__ == "__main__": + args = get_args() + main(args) diff --git a/egs/ami/ASR/local/prepare_ami_gss.sh b/egs/ami/ASR/local/prepare_ami_gss.sh new file mode 100755 index 000000000..d5422458b --- /dev/null +++ b/egs/ami/ASR/local/prepare_ami_gss.sh @@ -0,0 +1,98 @@ +#!/bin/bash +# This script is used to run GSS-based enhancement on AMI data. +set -euo pipefail +nj=4 +stage=0 + +. shared/parse_options.sh || exit 1 + +if [ $# != 2 ]; then + echo "Wrong #arguments ($#, expected 2)" + echo "Usage: local/prepare_ami_gss.sh [options] " + echo "e.g. local/prepare_ami_gss.sh data/manifests exp/ami_gss" + echo "main options (for others, see top of script file)" + echo " --nj # number of parallel jobs" + echo " --stage # stage to start running from" + exit 1; +fi + +DATA_DIR=$1 +EXP_DIR=$2 + +mkdir -p $EXP_DIR + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 1 ]; then + log "Stage 1: Prepare cut sets" + for part in train dev test; do + lhotse cut simple \ + -r $DATA_DIR/ami-mdm_recordings_${part}.jsonl.gz \ + -s $DATA_DIR/ami-mdm_supervisions_${part}.jsonl.gz \ + $EXP_DIR/cuts_${part}.jsonl.gz + done +fi + +if [ $stage -le 2 ]; then + log "Stage 2: Trim cuts to supervisions (1 cut per supervision segment)" + for part in train dev test; do + lhotse cut trim-to-supervisions --discard-overlapping \ + $EXP_DIR/cuts_${part}.jsonl.gz $EXP_DIR/cuts_per_segment_${part}.jsonl.gz + done +fi + +if [ $stage -le 3 ]; then + log "Stage 3: Split manifests for multi-GPU processing (optional)" + for part in train; do + gss utils split $nj $EXP_DIR/cuts_per_segment_${part}.jsonl.gz \ + $EXP_DIR/cuts_per_segment_${part}_split$nj + done +fi + +if [ $stage -le 4 ]; then + log "Stage 4: Enhance train segments using GSS (requires GPU)" + # for train, we use smaller context and larger batches to speed-up processing + for JOB in $(seq $nj); do + gss enhance cuts $EXP_DIR/cuts_train.jsonl.gz \ + $EXP_DIR/cuts_per_segment_train_split$nj/cuts_per_segment_train.JOB.jsonl.gz $EXP_DIR/enhanced \ + --bss-iterations 10 \ + --context-duration 5.0 \ + --use-garbage-class \ + --channels 0,1,2,3,4,5,6,7 \ + --min-segment-length 0.05 \ + --max-segment-length 35.0 \ + --max-batch-duration 60.0 \ + --num-buckets 3 \ + --num-workers 2 + done +fi + +if [ $stage -le 5 ]; then + log "Stage 5: Enhance dev/test segments using GSS (using GPU)" + # for dev/test, we use larger context and smaller batches to get better quality + for part in dev test; do + for JOB in $(seq $nj); do + gss enhance cuts $EXP_DIR/cuts_${part}.jsonl.gz \ + $EXP_DIR/cuts_per_segment_${part}_split$nj/cuts_per_segment_${part}.JOB.jsonl.gz \ + $EXP_DIR/enhanced \ + --bss-iterations 10 \ + --context-duration 15.0 \ + --use-garbage-class \ + --channels 
0,1,2,3,4,5,6,7 \ + --min-segment-length 0.05 \ + --max-segment-length 30.0 \ + --max-batch-duration 45.0 \ + --num-buckets 3 \ + --num-workers 2 + done + done +fi + +if [ $stage -le 6 ]; then + log "Stage 6: Prepare manifests for GSS-enhanced data" + python local/prepare_ami_enhanced.py $DATA_DIR $EXP_DIR/enhanced -j $nj --min-segment-duration 0.05 +fi diff --git a/egs/ami/ASR/local/prepare_lang_bpe.py b/egs/ami/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/ami/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/ami/ASR/local/train_bpe_model.py b/egs/ami/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/ami/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/ami/ASR/prepare.sh b/egs/ami/ASR/prepare.sh new file mode 100755 index 000000000..fb21a8ec6 --- /dev/null +++ b/egs/ami/ASR/prepare.sh @@ -0,0 +1,144 @@ +#!/usr/bin/env bash + +set -eou pipefail + +stage=-1 +stop_stage=100 +use_gss=true # Use GSS-based enhancement with MDM setting + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/amicorpus +# You can find audio and transcripts in this path. +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech +# +# - $dl_dir/{LDC2004S13,LDC2005S13,LDC2004T19,LDC2005T19} +# These contain the Fisher English audio and transcripts. We will +# only use the transcripts as extra LM training data (similar to Kaldi). +# +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +vocab_size=500 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/amicorpus, + # you can create a symlink + # + # ln -sfv /path/to/amicorpus $dl_dir/amicorpus + # + if [ ! -d $dl_dir/amicorpus ]; then + lhotse download ami --mic ihm $dl_dir/amicorpus + lhotse download ami --mic mdm $dl_dir/amicorpus + fi + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare AMI manifests" + # We assume that you have downloaded the AMI corpus + # to $dl_dir/amicorpus. We perform text normalization for the transcripts. 
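  # The ihm and sdm manifests are used directly for feature extraction, while the mdm
  # manifests are consumed by the optional GSS enhancement stage below
  # (local/prepare_ami_gss.sh reads the ami-mdm recordings/supervisions).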
+ mkdir -p data/manifests + for mic in ihm sdm mdm; do + lhotse prepare ami --mic $mic --partition full-corpus-asr --normalize-text kaldi \ + --max-words-per-segment 30 $dl_dir/amicorpus data/manifests/ + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to $dl_dir/musan + mkdir -p data/manifests + lhotse prepare musan $dl_dir/musan data/manifests +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ] && [ $use_gss = true ]; then + log "Stage 3: Apply GSS enhancement on MDM data (this stage requires a GPU)" + # We assume that you have installed the GSS package: https://github.com/desh2608/gss + local/prepare_ami_gss.sh data/manifests exp/ami_gss +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank features for AMI" + mkdir -p data/fbank + python local/compute_fbank_ami.py + log "Combine features from train splits" + lhotse combine data/manifests/cuts_train_{ihm,ihm_rvb,sdm,gss}.jsonl.gz - | shuf |\ + gzip -c > data/manifests/cuts_train_all.jsonl.gz +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Compute fbank features for musan" + mkdir -p data/fbank + python local/compute_fbank_musan.py +fi + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Dump transcripts for BPE model training." + mkdir -p data/lm + cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g')> data/lm/transcript_words.txt +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare BPE based lang" + + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + + # Add special words to words.txt + echo " 0" > $lang_dir/words.txt + echo "!SIL 1" >> $lang_dir/words.txt + echo " 2" >> $lang_dir/words.txt + + # Add regular words to words.txt + cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt + + # Add remaining special word symbols expected by LM scripts. + num_words=$(cat $lang_dir/words.txt | wc -l) + echo " ${num_words}" >> $lang_dir/words.txt + num_words=$(cat $lang_dir/words.txt | wc -l) + echo " ${num_words}" >> $lang_dir/words.txt + num_words=$(cat $lang_dir/words.txt | wc -l) + echo "#0 ${num_words}" >> $lang_dir/words.txt + + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript data/lm/transcript_words.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + fi +fi diff --git a/egs/ami/ASR/pruned_transducer_stateless7/__init__.py b/egs/ami/ASR/pruned_transducer_stateless7/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py new file mode 100644 index 000000000..f7ee9c962 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -0,0 +1,430 @@ +# Copyright 2021 Piotr Żelasko +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import re +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.cut import Cut +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader +from tqdm import tqdm + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class AmiAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description=( + "These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc." + ), + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help=( + "When enabled, select noise from MUSAN and mix it " + "with training dataset. " + ), + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help=( + "When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding." + ), + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help=( + "Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch." + ), + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help=( + "The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used." + ), + ) + group.add_argument( + "--max-duration", + type=int, + default=100.0, + help=( + "Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM." + ), + ) + group.add_argument( + "--max-cuts", type=int, default=None, help="Maximum cuts in a single batch." 
+ ) + group.add_argument( + "--num-buckets", + type=int, + default=50, + help=( + "The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets)." + ), + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help=( + "When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available." + ), + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help=( + "When enabled (=default), the examples will be " + "shuffled for each epoch." + ), + ) + + group.add_argument( + "--num-workers", + type=int, + default=8, + help=( + "The number of training dataloader workers that " "collect the batches." + ), + ) + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help=( + "Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp." + ), + ) + group.add_argument( + "--ihm-only", + type=str2bool, + default=False, + help="When enabled, only use IHM data for training.", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to get Musan cuts") + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + "Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. 
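            # The resulting order is [CutConcatenate, CutMix], i.e. cuts are concatenated
            # first and MUSAN noise is mixed in afterwards, so the gaps inserted between
            # utterances end up filled with noise rather than silence.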
+ transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=2, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + if self.args.on_the_fly_feats: + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + ) + else: + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + ) + + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + max_cuts=self.args.max_cuts, + shuffle=False, + num_buckets=self.args.num_buckets, + drop_last=True, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=True, + ) + sampler = DynamicBucketingSampler( + cuts, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + def remove_short_cuts(self, cut: Cut) -> bool: + """ + See: https://github.com/k2-fsa/icefall/issues/500 + Basically, the zipformer model subsamples the input using the following formula: + num_out_frames = (num_in_frames - 7)//2 + For num_out_frames to be at least 1, num_in_frames must be at least 9. 
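        With the default 10 ms frame shift used when the fbank features were extracted,
        9 frames correspond to 0.09 s of audio, which is where the threshold below comes from.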
+ """ + return cut.duration >= 0.09 + + @lru_cache() + def train_cuts(self, sp: Optional[Any] = None) -> CutSet: + logging.info("About to get AMI train cuts") + + def _remove_short_and_long_utt(c: Cut): + if c.duration < 0.2 or c.duration > 25.0: + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + return T >= len(tokens) + + if self.args.ihm_only: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / "cuts_train_ihm.jsonl.gz" + ) + else: + cuts_train = load_manifest_lazy( + self.args.manifest_dir / "cuts_train_all.jsonl.gz" + ) + + return cuts_train.filter(_remove_short_and_long_utt) + + @lru_cache() + def dev_ihm_cuts(self) -> CutSet: + logging.info("About to get AMI IHM dev cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_ihm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def dev_sdm_cuts(self) -> CutSet: + logging.info("About to get AMI SDM dev cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_sdm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def dev_gss_cuts(self) -> CutSet: + if not (self.args.manifest_dir / "cuts_dev_gss.jsonl.gz").exists(): + logging.info("No GSS dev cuts found") + return None + logging.info("About to get AMI GSS-enhanced dev cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_dev_gss.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def test_ihm_cuts(self) -> CutSet: + logging.info("About to get AMI IHM test cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_ihm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def test_sdm_cuts(self) -> CutSet: + logging.info("About to get AMI SDM test cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_sdm.jsonl.gz") + return cs.filter(self.remove_short_cuts) + + @lru_cache() + def test_gss_cuts(self) -> CutSet: + if not (self.args.manifest_dir / "cuts_test_gss.jsonl.gz").exists(): + logging.info("No GSS test cuts found") + return None + logging.info("About to get AMI GSS-enhanced test cuts") + cs = load_manifest_lazy(self.args.manifest_dir / "cuts_test_gss.jsonl.gz") + return cs.filter(self.remove_short_cuts) diff --git a/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py b/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py new file mode 120000 index 000000000..37516affc --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/beam_search.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decode.py b/egs/ami/ASR/pruned_transducer_stateless7/decode.py new file mode 100755 index 000000000..f47228fbe --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/decode.py @@ -0,0 +1,747 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --iter 105000 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 100 \ + --decoding-method greedy_search + +(2) beam search +./pruned_transducer_stateless7/decode.py \ + --iter 105000 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7/decode.py \ + --iter 105000 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 500 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search +./pruned_transducer_stateless7/decode.py \ + --iter 105000 \ + --avg 10 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 500 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import AmiAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest_LG, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall import NgramLm +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=10, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, + word_table: Optional[k2.SymbolTable] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. 
+ Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + word_table: + The word symbol table. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = model.device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": 
hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, + word_table: Optional[k2.SymbolTable] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 100 + else: + log_interval = 2 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + test_set_cers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + wers_filename = ( + params.res_dir / f"wers-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(wers_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + # we also compute CER for AMI dataset. 
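+        # For illustration (hypothetical values): a result tuple
+        # ("cut-1", ["hello", "world"], ["hello", "word"]) is converted to
+        # ("cut-1", list("helloworld"), list("helloword")) below, so that
+        # errors are counted per character rather than per word.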
+ results_char = [] + for res in results: + results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + cers_filename = ( + params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(cers_filename, "w") as f: + cer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=True + ) + test_set_cers[key] = cer + + logging.info("Wrote detailed error stats to {}".format(wers_filename)) + + test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} + test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER\tCER", file=f) + for key in test_set_wers: + print( + "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]), + file=f, + ) + + s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key in test_set_wers: + s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AmiAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest_LG", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(f"{params.lang_dir}/bpe.model") + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) 
+ model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + model.device = device + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + ami = AmiAsrDataModule(args) + + dev_ihm_cuts = ami.dev_ihm_cuts() + test_ihm_cuts = ami.test_ihm_cuts() + dev_sdm_cuts = ami.dev_sdm_cuts() + test_sdm_cuts = ami.test_sdm_cuts() + dev_gss_cuts = ami.dev_gss_cuts() + test_gss_cuts = ami.test_gss_cuts() + + dev_ihm_dl = ami.test_dataloaders(dev_ihm_cuts) + test_ihm_dl = ami.test_dataloaders(test_ihm_cuts) + dev_sdm_dl = ami.test_dataloaders(dev_sdm_cuts) + test_sdm_dl = ami.test_dataloaders(test_sdm_cuts) + if dev_gss_cuts is not None: + dev_gss_dl = ami.test_dataloaders(dev_gss_cuts) + if test_gss_cuts is not None: + test_gss_dl = ami.test_dataloaders(test_gss_cuts) + + test_sets = { + "dev_ihm": (dev_ihm_dl, dev_ihm_cuts), + "test_ihm": (test_ihm_dl, test_ihm_cuts), + "dev_sdm": (dev_sdm_dl, dev_sdm_cuts), + "test_sdm": (test_sdm_dl, test_sdm_cuts), + } + if dev_gss_cuts is not None: + test_sets["dev_gss"] = (dev_gss_dl, dev_gss_cuts) + if test_gss_cuts is not None: + test_sets["test_gss"] = (test_gss_dl, test_gss_cuts) + + for test_set in test_sets: + 
logging.info(f"Decoding {test_set}") + dl, cuts = test_sets[test_set] + results_dict = decode_dataset( + dl=dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/ami/ASR/pruned_transducer_stateless7/decoder.py b/egs/ami/ASR/pruned_transducer_stateless7/decoder.py new file mode 120000 index 000000000..8283d8c5a --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py new file mode 120000 index 000000000..0c2673d46 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/export.py b/egs/ami/ASR/pruned_transducer_stateless7/export.py new file mode 120000 index 000000000..2713792e6 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/export.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/joiner.py b/egs/ami/ASR/pruned_transducer_stateless7/joiner.py new file mode 120000 index 000000000..0f0c3c90a --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/model.py b/egs/ami/ASR/pruned_transducer_stateless7/model.py new file mode 120000 index 000000000..0d8bc665b --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/optim.py b/egs/ami/ASR/pruned_transducer_stateless7/optim.py new file mode 120000 index 000000000..8a05abb5f --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/scaling.py b/egs/ami/ASR/pruned_transducer_stateless7/scaling.py new file mode 120000 index 000000000..5f9be9fe0 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py new file mode 120000 index 000000000..f9960e5c6 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py new file mode 100755 index 000000000..b5efb3405 --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py @@ -0,0 +1,1184 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 15 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --max-duration 150 \ + --use-fp16 True + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import AmiAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + 
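+    # These comma-separated strings are parsed into one value per encoder
+    # stack, e.g. "192,192,192,192,192" -> (192, 192, 192, 192, 192)
+    # (see to_int_tuple in get_encoder_model below).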
parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=11, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=5000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=10, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. 
+ + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 100, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. 
Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"] + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = supervisions["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ((feature_lens - 7) // 2).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. 
+ model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
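+                # Roughly: if the scale has collapsed below 1.0 (or is still
+                # below 8.0 after a while), double it manually via
+                # scaler.update(new_scale) rather than waiting for GradScaler's
+                # own growth schedule.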
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + if params.inf_check: + register_inf_check_hooks(model) + + ami = AmiAsrDataModule(args) + + # Here is the duration statistics of the training set. 
+ # Cuts count: 1230033 + # Total duration (hh:mm:ss): 904:25:34 + # Speech duration (hh:mm:ss): 904:25:34 (100.0%) + # Duration statistics (seconds): + # mean 2.6 + # std 2.8 + # min 0.0 + # 25% 0.6 + # 50% 1.6 + # 75% 3.8 + # 99% 12.3 + # 99.5% 13.9 + # 99.9% 18.3 + # max 36.8 + + train_cuts = ami.train_cuts(sp=sp) + train_dl = ami.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) + + valid_cuts = ami.dev_ihm_cuts() + valid_dl = ami.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. 
We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + AmiAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py b/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py new file mode 120000 index 000000000..f2f66041e --- /dev/null +++ b/egs/ami/ASR/pruned_transducer_stateless7/zipformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/egs/ami/ASR/shared b/egs/ami/ASR/shared new file mode 120000 index 000000000..4cbd91a7e --- /dev/null +++ b/egs/ami/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared \ No newline at end of file From 61032e70e097aea63d191183466d0f1b16f9e16e Mon Sep 17 00:00:00 2001 From: abb128 <65567823+abb128@users.noreply.github.com> Date: Sat, 26 Nov 2022 04:10:37 +0200 Subject: [PATCH 05/27] Fix exception in find_checkpoints (#668) --- icefall/checkpoint.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 8aa0a8eeb..f0663a1df 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -292,7 +292,15 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]: """ checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) pattern = re.compile(r"checkpoint-([0-9]+).pt") - iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints] + iter_checkpoints = [] + for c in checkpoints: + result = pattern.search(c) + if not result: + logging.warn(f"Invalid checkpoint filename {c}") + continue + + iter_checkpoints.append((int(result.group(1)), c)) + # iter_checkpoints is a list of tuples. 
Each tuple contains # two elements: (iteration_number, checkpoint-iteration_number.pt) From 9cf79cac3f23757e25b499947f045efb0f4d71a6 Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Sat, 26 Nov 2022 21:48:17 +0800 Subject: [PATCH 06/27] message formatting --- .../ASR/pruned_transducer_stateless7/optim.py | 76 +++++++++++-------- .../ASR/pruned_transducer_stateless7/train.py | 10 +-- 2 files changed, 45 insertions(+), 41 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index ab55381d7..790752fe1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -42,7 +42,7 @@ class BatchedOptimizer(Optimizer): super(BatchedOptimizer, self).__init__(params, defaults) @contextlib.contextmanager - def batched_params(self, param_group, group_params_names=None): + def batched_params(self, param_group, group_params_names): """ This function returns (technically, yields) a list of of tuples (p, state), where @@ -85,7 +85,9 @@ class BatchedOptimizer(Optimizer): batches_names[key].append(named_p) batches_names_keys = list(batches_names.keys()) - sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i]) + sorted_idx = sorted( + range(len(batches_names)), key=lambda i: batches_names_keys[i] + ) batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] @@ -174,7 +176,7 @@ class ScaledAdam(BatchedOptimizer): size_update_period=4, clipping_update_period=100, parameters_names=None, - show_dominant_parameters=False, + show_dominant_parameters=True, ): defaults = dict( @@ -211,7 +213,7 @@ class ScaledAdam(BatchedOptimizer): loss = closure() batch = True - assert len(self.param_groups) == len(self.parameters_names) + assert len(self.param_groups) == len(self.parameters_names) for group, group_params_names in zip(self.param_groups, self.parameters_names): @@ -381,42 +383,52 @@ class ScaledAdam(BatchedOptimizer): return ans def _show_gradient_dominating_parameter(self, pairs, tot_sumsq): - # ori means calculated with state["param_rms"] - # cur means calculated with "param_rms" of current param. - # bt is short batch - # all_sumsq_ori_rms - all_sumsq_ori = {} - all_sumsq_cur = {} + all_sumsq_orig = {} for (p, state, batch_param_names) in pairs: # p is a stacked batch parameters. - grad = p.grad + batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars - batch_sumsq_ori = grad**2 # sum() to change shape [1] to [] - batch_sumsq_cur = batch_sumsq_ori # sum() to change shape [1] to [] + batch_sumsq_orig = batch_grad**2 # Dummpy values used by following `zip` statement. 
- batch_rms_ori = torch.zeros(p.shape[0]) - batch_rms_cur = batch_rms_ori + batch_rms_orig = torch.ones(p.shape[0]) else: - batch_rms_ori = state["param_rms"] - batch_sumsq_ori = ((grad * batch_rms_ori) ** 2).sum(dim=list(range(1, grad.ndim))) + batch_rms_orig = state["param_rms"] + batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( + dim=list(range(1, batch_grad.ndim)) + ) - batch_rms_cur = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - batch_sumsq_cur = ((grad * batch_rms_cur) ** 2).sum(dim=list(range(1, grad.ndim))) + for name, sumsq_orig, rms, grad in zip( + batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad + ): - for name, sumsq_ori, sumsq_cur in zip( - batch_param_names, batch_sumsq_ori, batch_sumsq_cur): + proportion_orig = sumsq_orig / tot_sumsq + all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) - proportion_ori = sumsq_ori / tot_sumsq - proportion_cur = sumsq_cur / tot_sumsq - - all_sumsq_ori[name] = (proportion_ori, sumsq_ori) - all_sumsq_cur[name] = (proportion_cur, sumsq_cur) - - for rms_type, all_sumsq in zip(("ori", "cur"), (all_sumsq_ori, all_sumsq_cur)): - sorted_by_proportion = {k: v for k, v in sorted(all_sumsq.items(), key=lambda item: item[1][0], reverse=True)} - dominant_param_name = next(iter(sorted_by_proportion)) - dominant_proportion, dominant_sumsq = sorted_by_proportion[dominant_param_name] - logging.info(f"Dominant sumsq with {rms_type}_rms: {dominant_param_name} {dominant_proportion} {dominant_sumsq} {tot_sumsq}") + assert torch.isclose( + sum([value[0] for value in all_sumsq_orig.values()]).cpu(), + torch.tensor(1.0), + ) + sorted_by_proportion = { + k: v + for k, v in sorted( + all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True + ) + } + dominant_param_name = next(iter(sorted_by_proportion)) + ( + dominant_proportion, + dominant_sumsq, + dominant_rms, + dominant_grad, + ) = sorted_by_proportion[dominant_param_name] + logging.info( + f"Parameter Dominanting tot_sumsq {dominant_param_name}" + f" with proportion {dominant_proportion:.2f}," + f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" + f"={dominant_sumsq:.3e}," + f" grad_sumsq = {(dominant_grad**2).sum():.3e}," + f" orig_rms_sq={(dominant_rms**2).item():.3e}" + ) def _step_one_batch( self, group: dict, p: Tensor, state: dict, clipping_scale: float diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 8375b1a18..e5a3e68df 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -368,13 +368,6 @@ def get_parser(): help="Whether to use half precision training.", ) - parser.add_argument( - "--show-dominant-parameters", - type=str2bool, - default=False, - help="Whether to show dominant parameters.", - ) - add_model_arguments(parser) return parser @@ -998,8 +991,7 @@ def run(rank, world_size, args): parameters_names = [] parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()]) optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, - clipping_scale=2.0, parameters_names=parameters_names, - show_dominant_parameters=params.show_dominant_parameters) + clipping_scale=2.0, parameters_names=parameters_names) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) From 6693d907d3ddd5c5eade144b55a57c8831d6d9b2 Mon Sep 17 00:00:00 2001 From: huangruizhe Date: Sat, 26 Nov 2022 22:26:09 -0500 Subject: [PATCH 07/27] shuffle full Librispeech 
data (#574) * shuffled full/partial librispeech data * fixed the code style issue * Shuffled full librispeech data off-line * Fixed style, addressed comments, and removed redandunt codes * Used the suggested version of black * Propagated the changes to other folders for librispeech (except conformer_mmi and streaming_conformer_ctc) --- egs/librispeech/ASR/conformer_ctc/train.py | 6 +++--- egs/librispeech/ASR/conformer_ctc2/train.py | 6 +++--- .../ASR/conv_emformer_transducer_stateless/train.py | 6 +++--- .../ASR/conv_emformer_transducer_stateless2/train.py | 6 +++--- egs/librispeech/ASR/lstm_transducer_stateless/train.py | 6 +++--- .../ASR/lstm_transducer_stateless2/train.py | 6 +++--- egs/librispeech/ASR/prepare.sh | 5 +++++ .../ASR/pruned_stateless_emformer_rnnt2/train.py | 6 +++--- .../ASR/pruned_transducer_stateless/train.py | 6 +++--- .../ASR/pruned_transducer_stateless2/train.py | 6 +++--- .../ASR/pruned_transducer_stateless3/train.py | 6 +++--- .../ASR/pruned_transducer_stateless4/train.py | 6 +++--- .../ASR/pruned_transducer_stateless5/train.py | 6 +++--- .../ASR/pruned_transducer_stateless6/train.py | 6 +++--- egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py | 10 ++++++++++ egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 8 ++++---- egs/librispeech/ASR/transducer/train.py | 6 +++--- egs/librispeech/ASR/transducer_lstm/train.py | 6 +++--- egs/librispeech/ASR/transducer_stateless/train.py | 6 +++--- egs/librispeech/ASR/transducer_stateless2/train.py | 6 +++--- .../ASR/transducer_stateless_multi_datasets/train.py | 6 +++--- 21 files changed, 73 insertions(+), 58 deletions(-) diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 1449bc310..99fe64793 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -687,10 +687,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index ceea0c22c..121fdb256 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -928,10 +928,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index 213115854..6bb5505aa 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -970,10 +970,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - 
train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 6a019fd63..8462ae92a 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -970,10 +970,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py index a54108f6d..feb81d500 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -954,10 +954,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 8736384b4..4fc4fa7f8 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -1108,10 +1108,10 @@ def run(rank, world_size, args): librispeech = LibriSpeech(manifest_dir=args.manifest_dir) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() train_cuts = filter_short_and_long_utterances(train_cuts, sp) diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 8668af0e4..542bbcdd8 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -123,6 +123,11 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then touch data/fbank/.librispeech.done fi + cat <(gunzip -c data/fbank/librispeech_cuts_train-clean-100.jsonl.gz) \ + <(gunzip -c data/fbank/librispeech_cuts_train-clean-360.jsonl.gz) \ + <(gunzip -c data/fbank/librispeech_cuts_train-other-500.jsonl.gz) | \ + shuf | gzip -c > data/fbank/librispeech_cuts_train-all-shuf.jsonl.gz + if [ ! 
-e data/fbank/.librispeech-validated.done ]; then log "Validating data/fbank for LibriSpeech" parts=( diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index ed3fa1521..3601e1e11 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -882,10 +882,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index 4dabbccc1..cf4032027 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -873,10 +873,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 86333fc97..6c19f2cb0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -931,10 +931,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 281ba4650..fdafa5a87 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -1065,10 +1065,10 @@ def run(rank, world_size, args): librispeech = LibriSpeech(manifest_dir=args.manifest_dir) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() train_cuts = filter_short_and_long_utterances(train_cuts, sp) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index cb56c8294..9bd7df401 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -978,10 +978,10 @@ def run(rank, world_size, args): 
librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 436620744..847c80ab0 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -1009,10 +1009,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 8f4d3b879..57753599a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -970,10 +970,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 95d1b273a..c5787835d 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -414,6 +414,16 @@ class LibriSpeechAsrDataModule: self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz" ) + @lru_cache() + def train_all_shuf_cuts(self) -> CutSet: + logging.info( + "About to get the shuffled train-clean-100, \ + train-clean-360 and train-other-500 cuts" + ) + return load_manifest_lazy( + self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz" + ) + @lru_cache() def dev_clean_cuts(self) -> CutSet: logging.info("About to get dev-clean cuts") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 071ac792b..0aa1587ba 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -173,7 +173,7 @@ def get_params() -> AttributeDict: { "exp_dir": Path("tdnn_lstm_ctc/exp"), "lang_dir": Path("data/lang_phone"), - "lr": 1e-3, + "lr": 1e-4, "feature_dim": 80, "weight_decay": 5e-4, "subsampling_factor": 3, @@ -557,10 +557,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + 
train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index 674ea10a6..29625754e 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -614,10 +614,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 57bda63fd..792708bc0 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -620,10 +620,10 @@ def run(rank, world_size, args): librispeech = LibriSpeechAsrDataModule(args) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index bcb883fa5..8db9b59e7 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -641,10 +641,10 @@ def run(rank, world_size, args): if params.print_diagnostics: diagnostic = diagnostics.attach_diagnostics(model) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py index 68e247f23..1c3a33870 100755 --- a/egs/librispeech/ASR/transducer_stateless2/train.py +++ b/egs/librispeech/ASR/transducer_stateless2/train.py @@ -629,10 +629,10 @@ def run(rank, world_size, args): if params.print_diagnostics: diagnostic = diagnostics.attach_diagnostics(model) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py index 88987d91c..dafccd088 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py @@ -752,10 +752,10 @@ def run(rank, world_size, args): 
librispeech = LibriSpeech(manifest_dir=args.manifest_dir) - train_cuts = librispeech.train_clean_100_cuts() if params.full_libri: - train_cuts += librispeech.train_clean_360_cuts() - train_cuts += librispeech.train_other_500_cuts() + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() train_cuts = filter_short_and_long_utterances(train_cuts) From 4fee3e7f1ea6c2aefe7594e325ede1e530e54d3d Mon Sep 17 00:00:00 2001 From: Guo Liyong Date: Mon, 28 Nov 2022 16:55:18 +0800 Subject: [PATCH 08/27] impove comment --- .../ASR/pruned_transducer_stateless7/optim.py | 63 +++++++++++++------ .../ASR/pruned_transducer_stateless7/train.py | 12 +++- 2 files changed, 54 insertions(+), 21 deletions(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 790752fe1..ff8fbb32c 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -64,13 +64,15 @@ class BatchedOptimizer(Optimizer): you can do: with self.batched_params(group["params"]) as batches: - for p, state in batches: + for p, state, p_names in batches: ... Args: group: a parameter group, which is a list of parameters; should be - one of self.groups. + one of self.param_groups. + group_params_names: name for each parameter in group, + which is List[str]. """ batches = defaultdict( list @@ -79,6 +81,7 @@ class BatchedOptimizer(Optimizer): list ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str + assert len(param_group) == len(group_params_names) for p, named_p in zip(param_group, group_params_names): key = (str(p.dtype), *p.shape) batches[key].append(p) @@ -94,9 +97,9 @@ class BatchedOptimizer(Optimizer): stacked_params_dict = dict() # turn batches into a list, in deterministic order. - # pairs will contain pairs of (stacked_param, state), one for each batch - # in `batches`. - pairs = [] + # tuples will contain tuples of (stacked_param, state, stacked_params_names), + # one for each batch in `batches`. + tuples = [] for batch, batch_names in zip(batches, batches_names): p = batch[0] @@ -110,11 +113,11 @@ class BatchedOptimizer(Optimizer): ) p_stacked.grad = grad stacked_params_dict[key] = p_stacked - pairs.append((p_stacked, state, batch_names)) + tuples.append((p_stacked, state, batch_names)) - yield pairs # <-- calling code will do the actual optimization here! + yield tuples # <-- calling code will do the actual optimization here! - for ((stacked_params, _state, _names), batch) in zip(pairs, batches): + for ((stacked_params, _state, _names), batch) in zip(tuples, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) @@ -179,6 +182,11 @@ class ScaledAdam(BatchedOptimizer): show_dominant_parameters=True, ): + assert parameters_names is not None, ( + "Please prepare parameters_names," + "which is a List[List[str]]. 
Each List[str] is for a group" + "and each str is for a parameter" + ) defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -193,6 +201,7 @@ class ScaledAdam(BatchedOptimizer): ) super(ScaledAdam, self).__init__(params, defaults) + assert len(self.param_groups) == len(parameters_names) self.parameters_names = parameters_names self.show_dominant_parameters = show_dominant_parameters @@ -213,7 +222,6 @@ class ScaledAdam(BatchedOptimizer): loss = closure() batch = True - assert len(self.param_groups) == len(self.parameters_names) for group, group_params_names in zip(self.param_groups, self.parameters_names): @@ -292,7 +300,7 @@ class ScaledAdam(BatchedOptimizer): state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) def _get_clipping_scale( - self, group: dict, pairs: List[Tuple[Tensor, dict, List[str]]] + self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] ) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients @@ -300,12 +308,16 @@ class ScaledAdam(BatchedOptimizer): Args: group: the parameter group, an item in self.param_groups - pairs: a list of pairs of (param, state) where param is a batched set of parameters, with a .grad - (1st dim is batch dim) and state is the state-dict where optimization parameters are kept. + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". """ - assert len(pairs) >= 1 + assert len(tuples) >= 1 clipping_scale = group["clipping_scale"] - (first_p, first_state, _) = pairs[0] + (first_p, first_state, _) = tuples[0] step = first_state["step"] if clipping_scale is None or step == 0: # no clipping. return early on step == 0 because the other @@ -314,7 +326,7 @@ class ScaledAdam(BatchedOptimizer): clipping_update_period = group["clipping_update_period"] tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in pairs: + for (p, state, param_names) in tuples: grad = p.grad if grad.is_sparse: raise RuntimeError( @@ -379,12 +391,27 @@ class ScaledAdam(BatchedOptimizer): ) if self.show_dominant_parameters: assert p.shape[0] == len(param_names) - self._show_gradient_dominating_parameter(pairs, tot_sumsq) + self._show_gradient_dominating_parameter(tuples, tot_sumsq) return ans - def _show_gradient_dominating_parameter(self, pairs, tot_sumsq): + def _show_gradient_dominating_parameter( + self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor + ): + """ + Show information of parameter wihch dominanting tot_sumsq. + + Args: + tuples: a list of tuples of (param, state, param_names) + where param is a batched set of parameters, + with a .grad (1st dim is batch dim) + and state is the state-dict where optimization parameters are kept. + param_names is a List[str] while each str is name for a parameter + in batched set of parameters "param". + tot_sumsq: sumsq of all parameters. Though it's could be calculated + from tuples, we still pass it to save some time. + """ all_sumsq_orig = {} - for (p, state, batch_param_names) in pairs: + for (p, state, batch_param_names) in tuples: # p is a stacked batch parameters. 
batch_grad = p.grad if p.numel() == p.shape[0]: # a batch of scalars diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index e5a3e68df..31a3a0505 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -989,9 +989,15 @@ def run(rank, world_size, args): model = DDP(model, device_ids=[rank], find_unused_parameters=True) parameters_names = [] - parameters_names.append([name_param_pair[0] for name_param_pair in model.named_parameters()]) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, - clipping_scale=2.0, parameters_names=parameters_names) + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) From ece728d895c11545eb3232caa4f6a1c907192064 Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 28 Nov 2022 22:34:02 +0800 Subject: [PATCH 09/27] Apply delay penalty on k2 ctc loss (#669) * add init files * fix bug, apply delay penalty * fix decoding code and getting timestamps * add option applying delay penalty on ctc log-prob * fix bug of streaming decoding * minor change for bpe-based case * add test_model.py * add README.md * add CI --- .flake8 | 2 +- ...n-librispeech-conformer-ctc3-2022-11-28.sh | 119 ++ ...-librispeech-conformer-ctc3-2022-11-28.yml | 151 +++ egs/librispeech/ASR/RESULTS.md | 102 +- .../ASR/conformer_ctc3/__init__.py | 1 + .../ASR/conformer_ctc3/asr_datamodule.py | 1 + .../ASR/conformer_ctc3/conformer.py | 1 + egs/librispeech/ASR/conformer_ctc3/decode.py | 1004 +++++++++++++++ .../ASR/conformer_ctc3/encoder_interface.py | 1 + egs/librispeech/ASR/conformer_ctc3/export.py | 292 +++++ .../ASR/conformer_ctc3/jit_pretrained.py | 406 ++++++ egs/librispeech/ASR/conformer_ctc3/lstmp.py | 1 + egs/librispeech/ASR/conformer_ctc3/model.py | 122 ++ egs/librispeech/ASR/conformer_ctc3/optim.py | 1 + .../ASR/conformer_ctc3/pretrained.py | 458 +++++++ egs/librispeech/ASR/conformer_ctc3/scaling.py | 1 + .../ASR/conformer_ctc3/scaling_converter.py | 1 + .../ASR/conformer_ctc3/test_model.py | 82 ++ egs/librispeech/ASR/conformer_ctc3/train.py | 1108 +++++++++++++++++ icefall/bpe_graph_compiler.py | 5 +- icefall/char_graph_compiler.py | 3 +- icefall/checkpoint.py | 2 +- icefall/graph_compiler.py | 4 + icefall/utils.py | 51 +- 24 files changed, 3876 insertions(+), 43 deletions(-) create mode 100755 .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh create mode 100644 .github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml create mode 120000 egs/librispeech/ASR/conformer_ctc3/__init__.py create mode 120000 egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py create mode 120000 egs/librispeech/ASR/conformer_ctc3/conformer.py create mode 100755 egs/librispeech/ASR/conformer_ctc3/decode.py create mode 120000 egs/librispeech/ASR/conformer_ctc3/encoder_interface.py create mode 100755 egs/librispeech/ASR/conformer_ctc3/export.py create mode 100755 egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py create mode 120000 egs/librispeech/ASR/conformer_ctc3/lstmp.py create mode 100644 egs/librispeech/ASR/conformer_ctc3/model.py create mode 120000 egs/librispeech/ASR/conformer_ctc3/optim.py create mode 100755 egs/librispeech/ASR/conformer_ctc3/pretrained.py create mode 120000 
egs/librispeech/ASR/conformer_ctc3/scaling.py create mode 120000 egs/librispeech/ASR/conformer_ctc3/scaling_converter.py create mode 100755 egs/librispeech/ASR/conformer_ctc3/test_model.py create mode 100755 egs/librispeech/ASR/conformer_ctc3/train.py diff --git a/.flake8 b/.flake8 index 609fa2c03..a0f44263c 100644 --- a/.flake8 +++ b/.flake8 @@ -11,7 +11,7 @@ per-file-ignores = egs/*/ASR/*/scaling.py: E501, egs/librispeech/ASR/lstm_transducer_stateless*/*.py: E501, E203 egs/librispeech/ASR/conv_emformer_transducer_stateless*/*.py: E501, E203 - egs/librispeech/ASR/conformer_ctc2/*py: E501, + egs/librispeech/ASR/conformer_ctc*/*py: E501, egs/librispeech/ASR/RESULTS.md: E999, # invalid escape sequence (cause by tex formular), W605 diff --git a/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh new file mode 100755 index 000000000..27944807f --- /dev/null +++ b/.github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh @@ -0,0 +1,119 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conformer-ctc3-2022-11-27 + +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/*" +git lfs pull --include "exp/jit_trace.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Decode with models exported by torch.jit.trace()" + +for m in ctc-decoding 1best; do + ./conformer_ctc3/jit_pretrained.py \ + --model-filename $repo/exp/jit_trace.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +log "Export to torchscript model" + +./conformer_ctc3/export.py \ + --exp-dir $repo/exp \ + --lang-dir $repo/data/lang_bpe_500 \ + --jit-trace 1 \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.trace()" + +for m in ctc-decoding 1best; do + ./conformer_ctc3/jit_pretrained.py \ + --model-filename $repo/exp/jit_trace.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for m in ctc-decoding 1best; do + ./conformer_ctc3/pretrained.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo 
"GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p conformer_ctc3/exp + ln -s $PWD/$repo/exp/pretrained.pt conformer_ctc3/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh conformer_ctc3/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in ctc-decoding 1best; do + log "Decoding with $method" + ./conformer_ctc3/decode.py \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --exp-dir conformer_ctc3/exp/ \ + --max-duration $max_duration \ + --decoding-method $method \ + --lm-dir data/lm + done + + rm conformer_ctc3/exp/*.pt +fi diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml new file mode 100644 index 000000000..21f396c32 --- /dev/null +++ b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml @@ -0,0 +1,151 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-librispeech-conformer-ctc3-2022-11-28 +# zipformer + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_librispeech_2022_11_28_conformer_ctc3: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Cache LibriSpeech test-clean and test-other datasets + id: libri-test-clean-and-test-other-data + uses: actions/cache@v2 + with: + path: | + ~/tmp/download + key: cache-libri-test-clean-and-test-other + + - name: Download LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true' + shell: bash + run: | + 
.github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh + + - name: Prepare manifests for LibriSpeech test-clean and test-other + shell: bash + run: | + .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh + + - name: Cache LibriSpeech test-clean and test-other fbank features + id: libri-test-clean-and-test-other-fbank + uses: actions/cache@v2 + with: + path: | + ~/tmp/fbank-libri + key: cache-libri-fbank-test-clean-and-test-other-v2 + + - name: Compute fbank for LibriSpeech test-clean and test-other + if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-conformer-ctc3-2022-11-28.sh + + - name: Display decoding results for librispeech conformer_ctc3 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/librispeech/ASR/ + tree ./conformer_ctc3/exp + + cd conformer_ctc3 + echo "results for conformer_ctc3" + echo "===ctc-decoding===" + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===1best===" + find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for librispeech conformer_ctc3 + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-conformer_ctc3-2022-11-28 + path: egs/librispeech/ASR/conformer_ctc3/exp/ diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 030e47b86..efd60ba81 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,106 @@ ## Results +### LibriSpeech BPE training results (Conformer CTC, supporting delay penalty) + +#### [conformer_ctc3](./conformer_ctc3) + +It implements Conformer model training with CTC loss. +For streaming mode, it supports symbol delay penalty. + +See for more details. + +##### training on full librispeech + +This model contains 12 encoder layers. The number of model parameters is 77352694. 
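The recipe pairs a Conformer encoder with a plain CTC head: a linear projection over the encoder frames followed by log-softmax, which is the kind of (N, T, C) output that `model.get_ctc_output(...)` feeds to the decoders later in this patch. The sketch below only illustrates that overall shape; the class name and constructor arguments are assumptions for illustration and are not copied from `conformer_ctc3/model.py`.

```python
import torch
import torch.nn as nn


class CtcConformerSketch(nn.Module):
    """Illustrative only: a Conformer encoder followed by a CTC head."""

    def __init__(self, encoder: nn.Module, encoder_dim: int, vocab_size: int) -> None:
        super().__init__()
        self.encoder = encoder
        # Maps each encoder frame to per-token scores, including the blank.
        self.ctc_output_module = nn.Linear(encoder_dim, vocab_size)

    def get_ctc_output(self, encoder_out: torch.Tensor) -> torch.Tensor:
        # (N, T, C) log-probabilities consumed by the CTC loss and by decoding.
        return nn.functional.log_softmax(self.ctc_output_module(encoder_out), dim=-1)
```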
+ +The WERs are: + +| | test-clean | test-other | comment | +|-------------------------------------|------------|------------|----------------------| +| ctc-decoding | 3.09 | 7.62 | --epoch 25 --avg 7 | +| 1best | 2.87 | 6.44 | --epoch 25 --avg 7 | +| nbest | 2.88 | 6.5 | --epoch 25 --avg 7 | +| nbest-rescoring | 2.71 | 6.1 | --epoch 25 --avg 7 | +| whole-lattice-rescoring | 2.71 | 6.04 | --epoch 25 --avg 7 | + +The training command is: + +```bash +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 25 \ + --start-epoch 1 \ + --exp-dir conformer_ctc3/full \ + --full-libri 1 \ + --max-duration 300 \ + --master-port 12345 +``` + +The tensorboard log can be found at + + +The decoding command using different methods is: +```bash +for method in ctc-decoding 1best nbest nbest-rescoring whole-lattice-rescoring; do + ./conformer_ctc3/decode.py \ + --epoch 25 \ + --avg 7 \ + --exp-dir conformer_ctc3/exp \ + --max-duration 300 \ + --decoding-method $method \ + --manifest-dir data/fbank \ + --lm-dir data/lm \ +done +``` + +Pretrained models, training logs, decoding logs, and decoding results +are available at + + +The command to train a streaming model with symbol delay penalty is: +```bash +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir conformer_ctc3/exp \ + --full-libri 1 \ + --dynamic-chunk-training 1 \ + --causal-convolution 1 \ + --short-chunk-size 25 \ + --num-left-chunks 4 \ + --max-duration 300 \ + --delay-penalty 0.1 +``` +To evaluate symbol delay, you should: +(1) Generate cuts with word-time alignments: +```bash +./local/add_alignment_librispeech.py \ + --alignments-dir data/alignment \ + --cuts-in-dir data/fbank \ + --cuts-out-dir data/fbank_ali +``` +(2) Set the argument "--manifest-dir data/fbank_ali" while decoding. +For example: +```bash +./conformer_ctc3/decode.py \ + --epoch 25 \ + --avg 7 \ + --exp-dir ./conformer_ctc3/exp \ + --max-duration 300 \ + --decoding-method ctc-decoding \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --manifest-dir data/fbank_ali +``` +Note: It supports to calculate symbol delay with following decoding methods: + - ctc-greedy-search + - ctc-decoding + - 1best + + ### pruned_transducer_stateless8 (zipformer + multidataset) See for more details. 
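The `--delay-penalty` option used in the streaming training command above corresponds to the "apply delay penalty on ctc log-prob" change in this patch: during training, the CTC log-probs are biased so that emitting non-blank symbols on later frames becomes more expensive, which reduces symbol delay. The function below is only a minimal sketch of that idea; its name and the linear-in-frame-index weighting are assumptions and do not reproduce the code added to `conformer_ctc3`.

```python
import torch


def delay_penalized_log_probs(
    ctc_log_probs: torch.Tensor, penalty: float, blank_id: int = 0
) -> torch.Tensor:
    """Sketch only. ctc_log_probs: (N, T, C), the output of log_softmax."""
    T = ctc_log_probs.size(1)
    frame_idx = torch.arange(T, device=ctc_log_probs.device, dtype=ctc_log_probs.dtype)
    # Later frames receive a larger penalty, so alignments that emit
    # symbols earlier become relatively cheaper for the CTC loss.
    biased = ctc_log_probs - penalty * frame_idx.view(1, T, 1)
    # Leave the blank column untouched; only non-blank emissions are penalized.
    biased[..., blank_id] = ctc_log_probs[..., blank_id]
    return biased
```

In the command above the strength is set with `--delay-penalty 0.1`; larger values are expected to trade some accuracy for lower symbol delay.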
@@ -115,7 +216,6 @@ done ``` - ### LibriSpeech BPE training results (Pruned Stateless LSTM RNN-T + gradient filter) #### [lstm_transducer_stateless3](./lstm_transducer_stateless3) diff --git a/egs/librispeech/ASR/conformer_ctc3/__init__.py b/egs/librispeech/ASR/conformer_ctc3/__init__.py new file mode 120000 index 000000000..b24e5e357 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/__init__.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/__init__.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py b/egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/conformer.py b/egs/librispeech/ASR/conformer_ctc3/conformer.py new file mode 120000 index 000000000..3b84b9573 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/conformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/conformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py new file mode 100755 index 000000000..8eca2ae02 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -0,0 +1,1004 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) decode in non-streaming mode (take ctc-decoding as an example) +./conformer_ctc3/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./conformer_ctc3/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(2) decode in streaming mode (take ctc-decoding as an example) +./conformer_ctc3/decode.py \ + --epoch 30 \ + --avg 15 \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --exp-dir ./conformer_ctc3/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding + +To evaluate symbol delay, you should: +(1) Generate cuts with word-time alignments: +./local/add_alignment_librispeech.py \ + --alignments-dir data/alignment \ + --cuts-in-dir data/fbank \ + --cuts-out-dir data/fbank_ali +(2) Set the argument "--manifest-dir data/fbank_ali" while decoding. 
+For example: +./conformer_ctc3/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./conformer_ctc3/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding \ + --simulate-streaming 1 \ + --causal-convolution 1 \ + --decode-chunk-size 16 \ + --left-context 64 \ + --manifest-dir data/fbank_ali +Note: It supports calculating symbol delay with following decoding methods: + - ctc-greedy-search + - ctc-decoding + - 1best +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_ctc_model, get_params + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + DecodingResults, + get_texts, + get_texts_with_timestamp, + make_pad_mask, + parse_hyp_and_timestamp, + setup_logger, + store_transcripts_and_timestamps, + str2bool, + write_error_stats_with_timestamps, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless4/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) ctc-greedy-search. It only use CTC output and a sentence piece + model for decoding. It produces the same results with ctc-decoding. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. 
Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def ctc_greedy_search( + ctc_probs: torch.Tensor, + nnet_output_lens: torch.Tensor, +) -> List[List[int]]: + """Apply CTC greedy search + Args: + ctc_probs (torch.Tensor): (batch, max_len, feat_dim) + nnet_output_lens (torch.Tensor): (batch, ) + Returns: + List[List[int]]: best path result + """ + topk_prob, topk_index = ctc_probs.topk(1, dim=2) # (B, maxlen, 1) + topk_index = topk_index.squeeze(2) # (B, maxlen) + mask = make_pad_mask(nnet_output_lens) + topk_index = topk_index.masked_fill_(mask, 0) # (B, maxlen) + hyps = [hyp.tolist() for hyp in topk_index] + scores = topk_prob.max(1) + ret_hyps = [] + timestamps = [] + for i in range(len(hyps)): + hyp, time = remove_duplicates_and_blank(hyps[i]) + ret_hyps.append(hyp) + timestamps.append(time) + return ret_hyps, timestamps, scores + + +def remove_duplicates_and_blank(hyp: List[int]) -> Tuple[List[int], List[int]]: + # modified from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py + new_hyp: List[int] = [] + time: List[int] = [] + cur = 0 + while cur < len(hyp): + if hyp[cur] != 0: + new_hyp.append(hyp[cur]) + time.append(cur) + prev = cur + while cur < len(hyp) and hyp[cur] == hyp[prev]: + cur += 1 + return new_hyp, time + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: 
Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    batch: dict,
+    word_table: k2.SymbolTable,
+    sos_id: int,
+    eos_id: int,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, Tuple[List[List[str]], List[List[float]]]]:
+    """Decode one batch and return the result in a dict. The dict has the
+    following format:
+
+    - key: It indicates the setting used for decoding. For example,
+      if no rescoring is used, the key is the string `no_rescore`.
+      If LM rescoring is used, the key is the string `lm_scale_xxx`,
+      where `xxx` is the value of `lm_scale`. An example key is
+      `lm_scale_0.7`.
+    - value: It contains the decoding result. `len(value)` equals the
+      batch size. `value[i]` is the decoding result for the i-th
+      utterance in the given batch.
+
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - If params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
+        - If params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
+        - If params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
+        - If params.decoding_method is "whole-lattice-rescoring", it uses whole-lattice
+          LM rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      sos_id:
+        The token ID of the SOS.
+      eos_id:
+        The token ID of the EOS.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See the description above for the format of
+      the returned dict. Note: If nothing is decoded, None is returned.
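+
+      For illustration only (hypothetical values), an entry of the returned
+      dict for a batch with a single utterance might look like:
+
+          "ctc-decoding": ([["HELLO", "WORLD"]], [[0.24, 0.71]])
+
+      i.e. the decoded words of each utterance together with their start
+      times in seconds.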
+ """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(feature, feature_lens) + + nnet_output = model.get_ctc_output(encoder_out) + # nnet_output is (N, T, C) + + if params.decoding_method == "ctc-greedy-search": + hyps, timestamps, _ = ctc_greedy_search( + nnet_output, + encoder_out_lens, + ) + res = DecodingResults(hyps=hyps, timestamps=timestamps) + hyps, timestamps = parse_hyp_and_timestamp( + res=res, + sp=bpe_model, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + ) + key = "ctc-greedy-search" + return {key: (hyps, timestamps)} + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + if params.decoding_method in ["1best", "nbest", "nbest-oracle"]: + hlg_scale_list = [0.2, 0.4, 0.6, 0.8, 1.0] + + ori_scores = decoding_graph.scores.clone() + + ans = {} + for hlg_scale in hlg_scale_list: + decoding_graph.scores = ori_scores * hlg_scale + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + key_suffix = f"-HLG-scale-{hlg_scale}" + + if params.decoding_method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle-{params.num_paths}-nbest-scale-{params.nbest_scale}" # noqa + timestamps = [[] for _ in range(len(hyps))] + ans[key + key_suffix] = (hyps, timestamps) + + elif params.decoding_method in ["1best", "nbest"]: + if params.decoding_method == "1best": + best_path = one_best_decoding( + lattice=lattice, + use_double_scores=params.use_double_scores, + ) + key = "no-rescore" + res = get_texts_with_timestamp(best_path) + hyps, timestamps = parse_hyp_and_timestamp( + res=res, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + word_table=word_table, + ) + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + timestamps = [[] for _ in range(len(hyps))] + + ans[key + key_suffix] = (hyps, timestamps) + + return ans + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.decoding_method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. 
+ # + # token_ids is a lit-of-list of IDs + res = get_texts_with_timestamp(best_path) + hyps, timestamps = parse_hyp_and_timestamp( + res=res, + sp=bpe_model, + subsampling_factor=params.subsampling_factor, + frame_shift_ms=params.frame_shift_ms, + ) + key = "ctc-decoding" + return {key: (hyps, timestamps)} + + assert params.decoding_method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.decoding_method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.decoding_method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + else: + assert False, f"Unsupported decoding method: {params.decoding_method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + timestamps = [[] for _ in range(len(hyps))] + ans[lm_scale_str] = (hyps, timestamps) + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.decoding_method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.decoding_method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.decoding_method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.decoding_method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
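+
+    # `results` maps a decoding-setting key (e.g. "ctc-decoding" or
+    # "lm_scale_0.7") to a list with one entry per decoded utterance, each
+    # entry being a (cut_id, ref_words, hyp_words, ref_timestamps,
+    # hyp_timestamps) tuple.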
+ + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + timestamps_ref = [] + for cut in batch["supervisions"]["cut"]: + for s in cut.supervisions: + time = [] + if s.alignment is not None and "word" in s.alignment: + time = [ + aliword.start + for aliword in s.alignment["word"] + if aliword.symbol != "" + ] + timestamps_ref.append(time) + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + for name, (hyps, timestamps_hyp) in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) and len(timestamps_hyp) == len( + timestamps_ref + ) + for cut_id, hyp_words, ref_text, time_hyp, time_ref in zip( + cut_ids, hyps, texts, timestamps_hyp, timestamps_ref + ): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[ + str, + List[Tuple[List[str], List[str], List[str], List[float], List[float]]], + ], +): + test_set_wers = dict() + test_set_delays = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts_and_timestamps(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
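+        # In addition to writing those statistics,
+        # write_error_stats_with_timestamps() returns the WER and the mean and
+        # variance of the symbol delay (the offset between hypothesis and
+        # reference word start times), which are summarized further below.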
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer, mean_delay, var_delay = write_error_stats_with_timestamps( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + test_set_delays[key] = (mean_delay, var_delay) + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) + delays_info = ( + params.res_dir + / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(delays_info, "w") as f: + print("settings\tsymbol-delay", file=f) + for key, val in test_set_delays: + print( + "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]), + file=f, + ) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_delays: + s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + assert params.decoding_method in ( + "ctc-greedy-search", + "ctc-decoding", + "1best", + "nbest", + "nbest-rescoring", + "whole-lattice-rescoring", + "nbest-oracle", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + params.vocab_size = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + + if params.decoding_method in ["ctc-decoding", "ctc-greedy-search"]: + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model 
= spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
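+            # (Roughly speaking, rescore_with_whole_lattice() removes the old
+            # LM scores from the lattice and adds the scores of this G, so G
+            # needs to expose them under the same `lm_scores` attribute that
+            # HLG uses.)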
+ G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_ctc_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
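+    # With return_cuts=True the dataloader keeps the original cuts in
+    # batch["supervisions"]["cut"]; decode_dataset() reads the cut ids and the
+    # word alignments (used for the reference timestamps) from them.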
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_ctc3/encoder_interface.py b/egs/librispeech/ASR/conformer_ctc3/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/export.py b/egs/librispeech/ASR/conformer_ctc3/export.py new file mode 100755 index 000000000..c5b95d981 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/export.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: + +(1) Export to torchscript model using torch.jit.trace() + +./conformer_ctc3/export.py \ + --exp-dir ./conformer_ctc3/exp \ + --lang-dir data/lang_bpe_500 \ + --epoch 20 \ + --avg 10 \ + --jit-trace 1 + +It will generates the file: `jit_trace.pt`. + +(2) Export `model.state_dict()` + +./conformer_ctc3/export.py \ + --exp-dir ./conformer_ctc3/exp \ + --lang-dir data/lang_bpe_500 \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. 
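+
+A minimal sketch of doing that in Python (the values below are illustrative;
+see `conformer_ctc3/test_model.py` for the model-related params):
+
+    from train import get_ctc_model, get_params
+    from icefall.checkpoint import load_checkpoint
+
+    params = get_params()
+    params.vocab_size = 500
+    params.dynamic_chunk_training = False
+    params.short_chunk_size = 25
+    params.num_left_chunks = 4
+    params.causal_convolution = False
+    model = get_ctc_model(params)
+    load_checkpoint("./conformer_ctc3/exp/pretrained.pt", model)
+    model.eval()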
+ +To use the generated file with `conformer_ctc3/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./conformer_ctc3/decode.py \ + --exp-dir ./conformer_ctc3/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 100 \ + --lang-dir data/lang_bpe_500 +""" + +import argparse +import logging +from pathlib import Path + +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_ctc_model, get_params + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless4/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--jit-trace", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--streaming-model", + type=str2bool, + default=False, + help="""Whether to export a streaming model, if the models in exp-dir + are streaming model, this should be True. 
+ """, + ) + + add_model_arguments(parser) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + params.vocab_size = num_classes + + if params.streaming_model: + assert params.causal_convolution + + logging.info(params) + + logging.info("About to create model") + model = get_ctc_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit_trace: + # TODO: will support streaming mode + assert not params.streaming_model + convert_scaled_to_non_scaled(model, inplace=True) + + logging.info("Using torch.jit.trace()") + + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + traced_model = torch.jit.trace(model, (x, x_lens)) + + filename = params.exp_dir / "jit_trace.pt" + traced_model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.trace()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename 
= params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py new file mode 100755 index 000000000..c96defd23 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py @@ -0,0 +1,406 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Usage (for non-streaming mode): + +(1) ctc-decoding +./conformer_ctc3/pretrained.py \ + --checkpoint conformer_ctc3/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + test_wavs/1089-134686-0001.wav + +(2) 1best +./conformer_ctc3/pretrained.py \ + --checkpoint conformer_ctc3/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + test_wavs/1089-134686-0001.wav + +(3) nbest-rescoring +./conformer_ctc3/pretrained.py \ + --checkpoint conformer_ctc3/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + test_wavs/1089-134686-0001.wav + +(4) whole-lattice-rescoring +./conformer_ctc3/pretrained.py \ + --checkpoint conformer_ctc3/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + test_wavs/1089-134686-0001.wav +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the torchscript model.", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. 
+ Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + params.vocab_size = params.num_classes + + logging.info(f"{params}") + + device = torch.device("cpu") + + logging.info(f"device: {device}") + + model = torch.jit.load(args.model_filename) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + nnet_output, _ = model(features, feature_lengths) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc3/lstmp.py b/egs/librispeech/ASR/conformer_ctc3/lstmp.py new file mode 120000 index 000000000..4f377cd01 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/model.py b/egs/librispeech/ASR/conformer_ctc3/model.py new file mode 100644 index 000000000..f56df2006 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/model.py @@ -0,0 +1,122 @@ +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import math +from typing import Tuple + +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear + + +class CTCModel(nn.Module): + """It implements https://www.cs.toronto.edu/~graves/icml_2006.pdf + "Connectionist Temporal Classification: Labelling Unsegmented + Sequence Data with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + encoder_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. 
Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + encoder_dim: + The feature embedding dimension. + vocab_size: + The vocabulary size. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + + self.encoder = encoder + self.ctc_output_module = nn.Sequential( + nn.Dropout(p=0.1), + ScaledLinear(encoder_dim, vocab_size), + ) + + def get_ctc_output( + self, + encoder_out: torch.Tensor, + delay_penalty: float = 0.0, + blank_threshold: float = 0.99, + ): + """Compute ctc log-prob and optionally (delay_penalty > 0) apply delay penalty. + We first split utterance into sub-utterances according to the + blank probs, and then add sawtooth-like "blank-bonus" values to + the blank probs. + See https://github.com/k2-fsa/icefall/pull/669 for details. + + Args: + encoder_out: + A tensor with shape of (N, T, C). + delay_penalty: + A constant used to scale the delay penalty score. + blank_threshold: + The threshold used to split utterance into sub-utterances. + """ + output = self.ctc_output_module(encoder_out) + log_prob = nn.functional.log_softmax(output, dim=-1) + + if self.training and delay_penalty > 0: + T_arange = torch.arange(encoder_out.shape[1]).to(device=encoder_out.device) + # split into sub-utterances using the blank-id + mask = log_prob[:, :, 0] >= math.log(blank_threshold) # (B, T) + mask[:, 0] = True + cummax_out = (T_arange * mask).cummax(dim=-1)[0] # (B, T) + # the sawtooth "blank-bonus" value + penalty = T_arange - cummax_out # (B, T) + penalty_all = torch.zeros_like(log_prob) + penalty_all[:, :, 0] = delay_penalty * penalty + # apply latency penalty on probs + log_prob = log_prob + penalty_all + + return log_prob + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + warmup: float = 1.0, + delay_penalty: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + delay_penalty: + A constant used to scale the delay penalty score. + """ + encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup) + assert torch.all(encoder_out_lens > 0) + nnet_output = self.get_ctc_output(encoder_out, delay_penalty=delay_penalty) + return nnet_output, encoder_out_lens diff --git a/egs/librispeech/ASR/conformer_ctc3/optim.py b/egs/librispeech/ASR/conformer_ctc3/optim.py new file mode 120000 index 000000000..e2deb4492 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/pretrained.py b/egs/librispeech/ASR/conformer_ctc3/pretrained.py new file mode 100755 index 000000000..3628d6a5f --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/pretrained.py @@ -0,0 +1,458 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +Usage (for non-streaming mode): + +(1) ctc-decoding +./conformer_ctc3/pretrained.py \ + --checkpoint conformer_ctc3/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + test_wavs/1089-134686-0001.wav + +(2) 1best +./conformer_ctc3/pretrained.py \ + --checkpoint conformer_ctc3/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + test_wavs/1089-134686-0001.wav + +(3) nbest-rescoring +./conformer_ctc3/pretrained.py \ + --checkpoint conformer_ctc3/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + test_wavs/1089-134686-0001.wav + +(4) whole-lattice-rescoring +./conformer_ctc3/pretrained.py \ + --checkpoint conformer_ctc3/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + test_wavs/1089-134686-0001.wav +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_ctc_model, get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. 
+ We call it HLG decoding + n-gram LM rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + params.vocab_size = params.num_classes + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("About to create model") + model = get_ctc_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + # model forward + if params.simulate_streaming: + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=features, + x_lens=feature_lengths, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) + nnet_output = model.get_ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + 
"whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conformer_ctc3/scaling.py b/egs/librispeech/ASR/conformer_ctc3/scaling.py new file mode 120000 index 000000000..09d802cc4 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/scaling_converter.py b/egs/librispeech/ASR/conformer_ctc3/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc3/test_model.py b/egs/librispeech/ASR/conformer_ctc3/test_model.py new file mode 100755 index 000000000..b97b7eed8 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/test_model.py @@ -0,0 +1,82 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./conformer_ctc3/test_model.py +""" + +import torch + +from train import get_params, get_ctc_model + + +def test_model(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.unk_id = 2 + + params.dynamic_chunk_training = False + params.short_chunk_size = 25 + params.num_left_chunks = 4 + params.causal_convolution = False + + model = get_ctc_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + features = torch.randn(2, 100, 80) + feature_lengths = torch.full((2,), 100) + model(x=features, x_lens=feature_lengths) + + +def test_model_streaming(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.unk_id = 2 + + params.dynamic_chunk_training = True + params.short_chunk_size = 25 + params.num_left_chunks = 4 + params.causal_convolution = True + + model = get_ctc_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + features = torch.randn(2, 100, 80) + feature_lengths = torch.full((2,), 100) + encoder_out, _ = model.encoder(x=features, x_lens=feature_lengths) + model.get_ctc_output(encoder_out) + + +def main(): + test_model() + test_model_streaming() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py new file mode 100755 index 000000000..fb3b740c1 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc3/train.py @@ -0,0 +1,1108 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir conformer_ctc3/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir conformer_ctc3/exp \ + --full-libri 1 \ + --max-duration 550 + +# train a streaming model +./conformer_ctc3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir conformer_ctc3/exp \ + --full-libri 1 \ + --dynamic-chunk-training 1 \ + --causal-convolution 1 \ + --short-chunk-size 25 \ + --num-left-chunks 4 \ + --max-duration 300 \ + --delay-penalty 0.0 +""" + +import argparse +import copy +import logging +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import CTCModel +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.graph_compiler import CtcTrainingGraphCompiler +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
+        """,
+    )
+
+    parser.add_argument(
+        "--num-left-chunks",
+        type=int,
+        default=4,
+        help="How much left context can be seen in chunks when computing attention.",
+    )
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--world-size",
+        type=int,
+        default=1,
+        help="Number of GPUs for DDP training.",
+    )
+
+    parser.add_argument(
+        "--master-port",
+        type=int,
+        default=12354,
+        help="Master port to use for DDP training.",
+    )
+
+    parser.add_argument(
+        "--tensorboard",
+        type=str2bool,
+        default=True,
+        help="Whether to log various information to tensorboard.",
+    )
+
+    parser.add_argument(
+        "--num-epochs",
+        type=int,
+        default=30,
+        help="Number of epochs to train.",
+    )
+
+    parser.add_argument(
+        "--start-epoch",
+        type=int,
+        default=1,
+        help="""Resume training from this epoch. It should be positive.
+        If larger than 1, it will load the checkpoint from
+        exp-dir/epoch-{start_epoch-1}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--start-batch",
+        type=int,
+        default=0,
+        help="""If positive, --start-epoch is ignored and
+        it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="conformer_ctc3/exp",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, logs, etc., are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_bpe_500",
+        help="""The lang dir.
+        It contains language related input files such as
+        "lexicon.txt"
+        """,
+    )
+
+    parser.add_argument(
+        "--initial-lr",
+        type=float,
+        default=0.003,
+        help="""The initial learning rate. This value should not need to be
+        changed.""",
+    )
+
+    parser.add_argument(
+        "--lr-batches",
+        type=float,
+        default=5000,
+        help="""Number of steps that affects how rapidly the learning rate decreases.
+        We suggest not to change this.""",
+    )
+
+    parser.add_argument(
+        "--lr-epochs",
+        type=float,
+        default=6,
+        help="""Number of epochs that affects how rapidly the learning rate decreases.
+        """,
+    )
+
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=42,
+        help="The seed for random generators; it is intended for reproducibility.",
+    )
+
+    parser.add_argument(
+        "--print-diagnostics",
+        type=str2bool,
+        default=False,
+        help="Accumulate stats on activations, print them and exit.",
+    )
+
+    parser.add_argument(
+        "--save-every-n",
+        type=int,
+        default=8000,
+        help="""Save a checkpoint after processing this number of batches,
+        periodically. We save a checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoints to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch, where `xxx` is the epoch number counting from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--keep-last-k",
+        type=int,
+        default=20,
+        help="""Only keep this number of checkpoints on disk.
+        For instance, if it is 3, there are only 3 checkpoints
+        in the exp-dir with filenames `checkpoint-xxx.pt`.
+        It does not affect checkpoints with name `epoch-xxx.pt`.
+        """,
+    )
+
+    parser.add_argument(
+        "--average-period",
+        type=int,
+        default=100,
+        help="""Update the averaged model, namely `model_avg`, after processing
+        this number of batches. `model_avg` is a separate version of model,
+        in which each floating-point parameter is the average of all the
+        parameters from the start of training.
Each time we take the average,
+        we do: `model_avg = model * (average_period / batch_idx_train) +
+        model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+        """,
+    )
+
+    parser.add_argument(
+        "--use-fp16",
+        type=str2bool,
+        default=False,
+        help="Whether to use half precision training.",
+    )
+
+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant used to scale the symbol delay penalty,
+        to encourage symbols to be emitted earlier for streaming models.
+        It is almost the same as the `delay_penalty` in our `rnnt_loss`; see
+        https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
+    parser.add_argument(
+        "--nnet-delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant to penalize symbol delay, which is applied on
+        the nnet_output after log-softmax.
+        We recommend using --delay-penalty instead.
+        See https://github.com/k2-fsa/icefall/pull/669 for details.""",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def get_params() -> AttributeDict:
+    """Return a dict containing training parameters.
+
+    All training related parameters that are not passed from the commandline
+    are saved in the variable `params`.
+
+    Commandline options are merged into `params` after they are parsed, so
+    you can also access them via `params`.
+
+    Explanation of options saved in `params`:
+
+        - best_train_loss: Best training loss so far. It is used to select
+                           the model that has the lowest training loss. It is
+                           updated during the training.
+
+        - best_valid_loss: Best validation loss so far. It is used to select
+                           the model that has the lowest validation loss. It is
+                           updated during the training.
+
+        - best_train_epoch: It is the epoch that has the best training loss.
+
+        - best_valid_epoch: It is the epoch that has the best validation loss.
+
+        - batch_idx_train: Used to write statistics to tensorboard. It
+                           contains the number of batches trained so far across
+                           epochs.
+
+        - log_interval:  Print training loss if `batch_idx % log_interval` is 0
+
+        - reset_interval: Reset statistics if `batch_idx % reset_interval` is 0
+
+        - valid_interval:  Run validation if `batch_idx % valid_interval` is 0
+
+        - feature_dim: The model input dim. It has to match the one used
+                       in computing features.
+
+        - subsampling_factor:  The subsampling factor for the model.
+
+        - encoder_dim: Hidden dim for multi-head attention model.
+
+        - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+        - warm_step: The warm_step for Noam optimizer.
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "encoder_dim": 512, + "nhead": 8, + "dim_feedforward": 2048, + "num_encoder_layers": 12, + # parameters for loss + "beam_size": 10, + "reduction": "sum", + "use_double_scores": True, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + dynamic_chunk_training=params.dynamic_chunk_training, + short_chunk_size=params.short_chunk_size, + num_left_chunks=params.num_left_chunks, + causal=params.causal_convolution, + ) + return encoder + + +def get_ctc_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + model = CTCModel( + encoder=encoder, + encoder_dim=params.encoder_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+        The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute CTC loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of CTCModel (with a
+        Conformer encoder) in our case.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+      warmup:
+        A floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
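+    Returns:
+      Return a tuple of two elements: the summed CTC loss over utterances
+      (a scalar tensor) and a `MetricsTracker` object holding statistics
+      for this batch.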
+    """
+    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+    feature = batch["inputs"]
+    # at entry, feature is (N, T, C)
+    assert feature.ndim == 3
+    feature = feature.to(device)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    with torch.set_grad_enabled(is_training):
+        nnet_output, encoder_out_lens = model(
+            feature,
+            feature_lens,
+            warmup=warmup,
+            delay_penalty=params.nnet_delay_penalty if warmup >= 1.0 else 0,
+        )
+        assert torch.all(encoder_out_lens > 0)
+
+        # NOTE: We need `encode_supervisions` to sort sequences with
+        # different duration in decreasing order, required by
+        # `k2.intersect_dense` called in `k2.ctc_loss`
+        supervision_segments, texts = encode_supervisions(
+            supervisions, subsampling_factor=params.subsampling_factor
+        )
+
+        if isinstance(graph_compiler, BpeCtcTrainingGraphCompiler):
+            # Works with a BPE model
+            token_ids = graph_compiler.texts_to_ids(texts)
+            decoding_graph = graph_compiler.compile(token_ids)
+        elif isinstance(graph_compiler, CtcTrainingGraphCompiler):
+            # Works with a phone lexicon
+            decoding_graph = graph_compiler.compile(texts)
+        else:
+            raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}")
+
+        dense_fsa_vec = k2.DenseFsaVec(
+            nnet_output,
+            supervision_segments,
+            allow_truncate=params.subsampling_factor - 1,
+        )
+
+        ctc_loss = k2.ctc_loss(
+            decoding_graph=decoding_graph,
+            dense_fsa_vec=dense_fsa_vec,
+            output_beam=params.beam_size,
+            delay_penalty=params.delay_penalty if warmup >= 1.0 else 0.0,
+            reduction=params.reduction,
+            use_double_scores=params.use_double_scores,
+        )
+        ctc_loss_is_finite = torch.isfinite(ctc_loss)
+        if not torch.all(ctc_loss_is_finite):
+            logging.info("Not all losses are finite!\n" f"ctc_loss: {ctc_loss}")
+            ctc_loss = ctc_loss[ctc_loss_is_finite]
+
+            # If the CTC losses of all utterances are inf or nan,
+            # we stop the training process by raising an exception
+            if torch.all(~ctc_loss_is_finite):
+                raise ValueError(
+                    "There are too many utterances in this batch "
+                    "leading to inf or nan losses."
+                )
+        loss = ctc_loss.sum()
+
+    assert loss.requires_grad == is_training
+
+    info = MetricsTracker()
+    # info["frames"] is an approximate number for two reasons:
+    # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+    # (2) If some utterances in the batch lead to inf/nan loss, they
+    #     are filtered out.
+    info["frames"] = supervision_segments[:, 2].sum().item()
+    # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances`  # noqa
+    info["utterances"] = feature.size(0)
+    # averaged input duration in frames over utterances
+    info["utt_duration"] = feature_lens.sum().item()
+    # averaged padding proportion over utterances
+    info["utt_pad_proportion"] = (
+        ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+    )
+
+    # Note: We use reduction=sum while computing the loss.
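+    # So info["loss"] is a sum over utterances; the per-frame training loss
+    # reported later is tot_loss["loss"] / tot_loss["frames"].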
+    info["loss"] = loss.detach().cpu().item()
+
+    return loss, info
+
+
+def compute_validation_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
+    valid_dl: torch.utils.data.DataLoader,
+    world_size: int = 1,
+) -> MetricsTracker:
+    """Run the validation process."""
+    model.eval()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(valid_dl):
+        loss, loss_info = compute_loss(
+            params=params,
+            model=model,
+            graph_compiler=graph_compiler,
+            batch=batch,
+            is_training=False,
+        )
+        assert loss.requires_grad is False
+        tot_loss = tot_loss + loss_info
+
+    if world_size > 1:
+        tot_loss.reduce(loss.device)
+
+    loss_value = tot_loss["loss"] / tot_loss["frames"]
+    if loss_value < params.best_valid_loss:
+        params.best_valid_epoch = params.cur_epoch
+        params.best_valid_loss = loss_value
+
+    return tot_loss
+
+
+def train_one_epoch(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    optimizer: torch.optim.Optimizer,
+    scheduler: LRSchedulerType,
+    graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler],
+    train_dl: torch.utils.data.DataLoader,
+    valid_dl: torch.utils.data.DataLoader,
+    scaler: GradScaler,
+    model_avg: Optional[nn.Module] = None,
+    tb_writer: Optional[SummaryWriter] = None,
+    world_size: int = 1,
+    rank: int = 0,
+) -> None:
+    """Train the model for one epoch.
+
+    The training loss from the mean of all frames is saved in
+    `params.train_loss`. It runs the validation process every
+    `params.valid_interval` batches.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The model for training.
+      optimizer:
+        The optimizer we are using.
+      scheduler:
+        The learning rate scheduler; we call step() after every batch.
+      graph_compiler:
+        It is used to build a decoding graph from a ctc topo and training
+        transcript. The training transcript is contained in the given `batch`,
+        while the ctc topo is built when this compiler is instantiated.
+      train_dl:
+        Dataloader for the training dataset.
+      valid_dl:
+        Dataloader for the validation dataset.
+      scaler:
+        The scaler used for mixed precision training.
+      model_avg:
+        The stored model averaged from the start of training.
+      tb_writer:
+        Writer to write log messages to tensorboard.
+      world_size:
+        Number of nodes in DDP training. If it is 1, DDP is disabled.
+      rank:
+        The rank of the node in DDP training. If no DDP is used, it should
+        be set to 0.
+    """
+    model.train()
+
+    tot_loss = MetricsTracker()
+
+    for batch_idx, batch in enumerate(train_dl):
+        params.batch_idx_train += 1
+        batch_size = len(batch["supervisions"]["text"])
+
+        with torch.cuda.amp.autocast(enabled=params.use_fp16):
+            loss, loss_info = compute_loss(
+                params=params,
+                model=model,
+                graph_compiler=graph_compiler,
+                batch=batch,
+                is_training=True,
+                warmup=(params.batch_idx_train / params.model_warm_step),
+            )
+            # summary stats
+            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+        # NOTE: We use reduction==sum and loss is computed over utterances
+        # in the batch and there is no normalization to it so far.
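+        # Scaling the loss before backward() keeps fp16 gradients from
+        # underflowing; scaler.step() unscales them before the optimizer
+        # update (both are pass-throughs when --use-fp16 is false).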
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + if params.print_diagnostics and batch_idx == 30: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + params.vocab_size = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + if "lang_bpe" in str(params.lang_dir): + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + elif "lang_phone" in str(params.lang_dir): + graph_compiler = CtcTrainingGraphCompiler( + lexicon, + device=device, + ) + # Manually add the sos/eos ID with their default values + # from the BPE recipe which we're adapting here. 
+ graph_compiler.sos_id = 1 + graph_compiler.eos_id = 1 + else: + raise ValueError( + f"Unsupported type of lang dir (we expected it to have " + f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" + ) + + if params.dynamic_chunk_training: + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_ctc_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: Union[BpeCtcTrainingGraphCompiler, CtcTrainingGraphCompiler], + params: AttributeDict, + warmup: float, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + warmup=warmup, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/icefall/bpe_graph_compiler.py b/icefall/bpe_graph_compiler.py index e76b7ea32..d9659c2dd 100644 --- a/icefall/bpe_graph_compiler.py +++ b/icefall/bpe_graph_compiler.py @@ -83,11 +83,12 @@ class BpeCtcTrainingGraphCompiler(object): Args: piece_ids: It is a list-of-list integer IDs. - modified: + modified: See :func:`k2.ctc_graph` for its meaning. Return: Return an FsaVec, which is the result of composing a CTC topology with linear FSAs constructed from the given piece IDs. """ - return k2.ctc_graph(piece_ids, modified=modified, device=self.device) + graph = k2.ctc_graph(piece_ids, modified=modified, device=self.device) + return graph diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py index c31db6e4c..5f9571d42 100644 --- a/icefall/char_graph_compiler.py +++ b/icefall/char_graph_compiler.py @@ -117,4 +117,5 @@ class CharCtcTrainingGraphCompiler(object): CTC topology with linear FSAs constructed from the given piece IDs. """ - return k2.ctc_graph(token_ids, modified=modified, device=self.device) + graph = k2.ctc_graph(token_ids, modified=modified, device=self.device) + return graph diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index f0663a1df..c83c56a53 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -298,7 +298,7 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]: if not result: logging.warn(f"Invalid checkpoint filename {c}") continue - + iter_checkpoints.append((int(result.group(1)), c)) # iter_checkpoints is a list of tuples. 
Each tuple contains diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py index e2ff03f61..84be81254 100644 --- a/icefall/graph_compiler.py +++ b/icefall/graph_compiler.py @@ -79,6 +79,10 @@ class CtcTrainingGraphCompiler(object): fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops) + self.ctc_topo._is_repeat_token_ = ( + self.ctc_topo.labels != self.ctc_topo.aux_labels + ) + decoding_graph = k2.compose( self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False ) diff --git a/icefall/utils.py b/icefall/utils.py index b4d8e9a51..d852491c8 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -670,8 +670,8 @@ def write_error_stats_with_timestamps( all_delay = [] for cut_id, ref, hyp, time_ref, time_hyp in results: ali = kaldialign.align(ref, hyp, ERR) - has_time_ref = len(time_ref) > 0 - if has_time_ref: + has_time = len(time_ref) > 0 and len(time_hyp) > 0 + if has_time: # pointer to timestamp_hyp p_hyp = 0 # pointer to timestamp_ref @@ -680,28 +680,28 @@ def write_error_stats_with_timestamps( if ref_word == ERR: ins[hyp_word] += 1 words[hyp_word][3] += 1 - if has_time_ref: + if has_time: p_hyp += 1 elif hyp_word == ERR: dels[ref_word] += 1 words[ref_word][4] += 1 - if has_time_ref: + if has_time: p_ref += 1 elif hyp_word != ref_word: subs[(ref_word, hyp_word)] += 1 words[ref_word][1] += 1 words[hyp_word][2] += 1 - if has_time_ref: + if has_time: p_hyp += 1 p_ref += 1 else: words[ref_word][0] += 1 num_corr += 1 - if has_time_ref: + if has_time: all_delay.append(time_hyp[p_hyp] - time_ref[p_ref]) p_hyp += 1 p_ref += 1 - if has_time_ref: + if has_time: assert p_hyp == len(hyp), (p_hyp, len(hyp)) assert p_ref == len(ref), (p_ref, len(ref)) @@ -1327,10 +1327,9 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]: def parse_hyp_and_timestamp( res: DecodingResults, - decoding_method: str, - sp: spm.SentencePieceProcessor, subsampling_factor: int, frame_shift_ms: float = 10, + sp: Optional[spm.SentencePieceProcessor] = None, word_table: Optional[k2.SymbolTable] = None, ) -> Tuple[List[List[str]], List[List[float]]]: """Parse hypothesis and timestamp. @@ -1338,51 +1337,29 @@ def parse_hyp_and_timestamp( Args: res: A DecodingResults object. - decoding_method: - Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_LG - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - sp: - The BPE model. subsampling_factor: The integer subsampling factor. frame_shift_ms: The float frame shift used for feature extraction. + sp: + The BPE model. word_table: The word symbol table. Returns: Return a list of hypothesis and timestamp. 
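+      Timestamps are in seconds; each entry is computed from a frame index
+      using `subsampling_factor` and `frame_shift_ms`.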
""" - assert decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_LG", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - hyps = [] timestamps = [] N = len(res.hyps) assert len(res.timestamps) == N, (len(res.timestamps), N) use_word_table = False - if ( - decoding_method == "fast_beam_search_nbest_LG" - and decoding_method == "fast_beam_search_LG" - ): - assert word_table is not None + if word_table is not None: + assert sp is None use_word_table = True + else: + assert sp is not None and word_table is None for i in range(N): time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms) From 4b5bc480e8a5ac253dcd22b08dfa59083dadd6fd Mon Sep 17 00:00:00 2001 From: marcoyang1998 <45973641+marcoyang1998@users.noreply.github.com> Date: Wed, 30 Nov 2022 17:26:05 +0800 Subject: [PATCH 10/27] Add low-order density ratio in RNNLM shallow fusion (#678) * Support LODR in RNNLM shallow fusion * fix style * fix code style * update workflow and CI * update results * propagate changes to stateless3 * add decoding results for stateless3+giga * fix CI --- ...-lstm-transducer-stateless2-2022-09-03.yml | 67 ++++- ...-lstm-transducer-stateless2-2022-09-03.yml | 15 +- egs/librispeech/ASR/RESULTS.md | 87 ++++++ .../ASR/lstm_transducer_stateless2/decode.py | 51 +++- .../beam_search.py | 264 ++++++++++++++++++ .../pruned_transducer_stateless3/decode.py | 181 +++++++++++- 6 files changed, 646 insertions(+), 19 deletions(-) diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index 6ce92d022..ac5b15979 100755 --- a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -16,6 +16,7 @@ log "Downloading pre-trained model from $repo_url" git lfs install git clone $repo_url repo=$(basename $repo_url) +abs_repo=$(realpath $repo) log "Display test files" tree $repo/ @@ -178,21 +179,27 @@ echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm log "Download pre-trained RNN-LM model from ${lm_repo_url}" - git clone $lm_repo_url + GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url lm_repo=$(basename $lm_repo_url) pushd $lm_repo git lfs pull --include "exp/pretrained.pt" - cd exp - ln -s pretrained.pt epoch-88.pt + mv exp/pretrained.pt exp/epoch-88.pt popd + mkdir -p lstm_transducer_stateless2/exp + ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh lstm_transducer_stateless2/exp + + log "Decoding test-clean and test-other" + ./lstm_transducer_stateless2/decode.py \ --use-averaged-model 0 \ - --epoch 99 \ + --epoch 999 \ --avg 1 \ - --exp-dir $repo/exp \ - --lang-dir $repo/data/lang_bpe_500 \ - --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --exp-dir lstm_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method modified_beam_search_rnnlm_shallow_fusion \ --beam 4 \ @@ -204,6 +211,52 @@ if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"shallow-fusion" ]]; then --rnn-lm-tie-weights 1 fi +if [[ x"${GITHUB_EVENT_LABEL_NAME}" == x"LODR" ]]; then + bigram_repo_url=https://huggingface.co/marcoyang/librispeech_bigram + log "Download bi-gram LM from 
${bigram_repo_url}" + GIT_LFS_SKIP_SMUDGE=1 git clone $bigram_repo_url + bigramlm_repo=$(basename $bigram_repo_url) + pushd $bigramlm_repo + git lfs pull --include "2gram.fst.txt" + cp 2gram.fst.txt $abs_repo/data/lang_bpe_500/. + popd + + lm_repo_url=https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm + log "Download pre-trained RNN-LM model from ${lm_repo_url}" + GIT_LFS_SKIP_SMUDGE=1 git clone $lm_repo_url + lm_repo=$(basename $lm_repo_url) + pushd $lm_repo + git lfs pull --include "exp/pretrained.pt" + mv exp/pretrained.pt exp/epoch-88.pt + popd + + mkdir -p lstm_transducer_stateless2/exp + ln -sf $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh lstm_transducer_stateless2/exp + + log "Decoding test-clean and test-other" + + ./lstm_transducer_stateless2/decode.py \ + --use-averaged-model 0 \ + --epoch 999 \ + --avg 1 \ + --exp-dir lstm_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search_rnnlm_LODR \ + --beam 4 \ + --rnn-lm-scale 0.3 \ + --rnn-lm-exp-dir $lm_repo/exp \ + --rnn-lm-epoch 88 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 \ + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 +fi + if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" ]]; then mkdir -p lstm_transducer_stateless2/exp ln -s $PWD/$repo/exp/pretrained.pt lstm_transducer_stateless2/exp/epoch-999.pt diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index a90841fb6..5f0acf9b8 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -18,7 +18,7 @@ on: jobs: run_librispeech_lstm_transducer_stateless2_2022_09_03: - if: github.event.label.name == 'ready' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' + if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' runs-on: ${{ matrix.os }} strategy: matrix: @@ -139,9 +139,20 @@ jobs: find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 find modified_beam_search_rnnlm_shallow_fusion -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + - name: Display decoding results for lstm_transducer_stateless2 + if: github.event.label.name == 'LODR' + shell: bash + run: | + cd egs/librispeech/ASR + tree lstm_transducer_stateless2/exp + cd lstm_transducer_stateless2/exp + echo "===modified_beam_search_rnnlm_LODR===" + find modified_beam_search_rnnlm_LODR -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find modified_beam_search_rnnlm_LODR -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + - name: Upload decoding results for lstm_transducer_stateless2 uses: actions/upload-artifact@v2 - if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' + if: github.event_name == 'schedule' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'LODR' with: name: torch-${{ matrix.torch 
}}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-lstm_transducer_stateless2-2022-09-03
          path: egs/librispeech/ASR/lstm_transducer_stateless2/exp/
diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md
index efd60ba81..c2ea3d050 100644
--- a/egs/librispeech/ASR/RESULTS.md
+++ b/egs/librispeech/ASR/RESULTS.md
@@ -318,6 +318,7 @@ The WERs are:
 | greedy search (max sym per frame 1) | 2.78 | 7.36 | --iter 468000 --avg 16 |
 | modified_beam_search | 2.73 | 7.15 | --iter 468000 --avg 16 |
 | modified_beam_search + RNNLM shallow fusion | 2.42 | 6.46 | --iter 468000 --avg 16 |
+| modified_beam_search + RNNLM shallow fusion + LODR | 2.28 | 5.94 | --iter 468000 --avg 16 |
 | fast_beam_search | 2.76 | 7.31 | --iter 468000 --avg 16 |
 | greedy search (max sym per frame 1) | 2.77 | 7.35 | --iter 472000 --avg 18 |
 | modified_beam_search | 2.75 | 7.08 | --iter 472000 --avg 18 |
@@ -393,6 +394,34 @@ for iter in 472000; do
 done
 done
 
+You may also decode using LODR + RNNLM shallow fusion. This decoding method is proposed in <https://arxiv.org/abs/2203.16776>.
+It subtracts the internal language model score during shallow fusion, which is approximated by a bi-gram model. The bi-gram can be
+generated by `generate-lm.sh`, or you may download it from <https://huggingface.co/marcoyang/librispeech_bigram>.
+
+The decoding command is as follows:
+
+```bash
+for iter in 472000; do
+  for avg in 8 10 12 14 16 18; do
+    ./lstm_transducer_stateless2/decode.py \
+      --iter $iter \
+      --avg $avg \
+      --exp-dir ./lstm_transducer_stateless2/exp \
+      --max-duration 600 \
+      --decoding-method modified_beam_search_rnnlm_LODR \
+      --beam 4 \
+      --rnn-lm-scale 0.4 \
+      --rnn-lm-exp-dir /path/to/RNNLM \
+      --rnn-lm-epoch 99 \
+      --rnn-lm-avg 1 \
+      --rnn-lm-num-layers 3 \
+      --rnn-lm-tie-weights 1 \
+      --tokens-ngram 2 \
+      --ngram-lm-scale -0.16
+  done
+done
+```
+
 Pretrained models, training logs, decoding logs, and decoding results
 are available at
@@ -1912,6 +1939,8 @@ subset so that the gigaspeech dataloader never exhausts.
 |-------------------------------------|------------|------------|---------------------------------------------|
 | greedy search (max sym per frame 1) | 2.03 | 4.70 | --iter 1224000 --avg 14 --max-duration 600 |
 | modified beam search | 2.00 | 4.63 | --iter 1224000 --avg 14 --max-duration 600 |
+| modified beam search + rnnlm shallow fusion | 1.94 | 4.20 | --iter 1224000 --avg 14 --max-duration 600 |
+| modified beam search + LODR | 1.83 | 4.03 | --iter 1224000 --avg 14 --max-duration 600 |
 | fast beam search | 2.10 | 4.68 | --iter 1224000 --avg 14 --max-duration 600 |
 
 The training commands are:
@@ -1957,6 +1986,64 @@ for iter in 1224000; do
 done
 done
 ```
+You may also decode using shallow fusion with an external RNNLM. To do so you need to
+download a well-trained RNNLM from <https://huggingface.co/ezerhouni/icefall-librispeech-rnn-lm>.
+
+```bash
+rnn_lm_scale=0.3
+
+for iter in 1224000; do
+  for avg in 14; do
+    for method in modified_beam_search_rnnlm_shallow_fusion ; do
+      ./pruned_transducer_stateless3/decode.py \
+        --iter $iter \
+        --avg $avg \
+        --exp-dir ./pruned_transducer_stateless3/exp-0.9/ \
+        --max-duration 600 \
+        --decoding-method $method \
+        --max-sym-per-frame 1 \
+        --beam 4 \
+        --max-contexts 32 \
+        --rnn-lm-scale $rnn_lm_scale \
+        --rnn-lm-exp-dir /path/to/RNNLM \
+        --rnn-lm-epoch 99 \
+        --rnn-lm-avg 1 \
+        --rnn-lm-num-layers 3 \
+        --rnn-lm-tie-weights 1
+    done
+  done
+done
+```
+
+If you want to try out LODR decoding, use the following command. This assumes you have a bi-gram LM trained on LibriSpeech text. You can also download the bi-gram LM from <https://huggingface.co/marcoyang/librispeech_bigram> and put it under the directory `data/lang_bpe_500`.
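+
+Note that `--ngram-lm-scale` is negative. During the search, each candidate
+token is scored roughly as `asr_score + rnn_lm_scale * rnnlm_score +
+ngram_lm_scale * bigram_score` (all log-probabilities), so the bi-gram
+estimate of the internal language model is subtracted from the fused score.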
+ +```bash +rnn_lm_scale=0.4 + +for iter in 1224000; do + for avg in 14; do + for method in modified_beam_search_rnnlm_LODR ; do + ./pruned_transducer_stateless3/decode.py \ + --iter $iter \ + --avg $avg \ + --exp-dir ./pruned_transducer_stateless3/exp-0.9/ \ + --max-duration 600 \ + --decoding-method $method \ + --max-sym-per-frame 1 \ + --beam 4 \ + --max-contexts 32 \ + --rnn-lm-scale $rnn_lm_scale \ + --rnn-lm-exp-dir /path/to/RNNLM \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 \ + --tokens-ngram 2 \ + --ngram-lm-scale -0.14 + done + done +done +``` The pretrained models, training logs, decoding logs, and decoding results can be found at diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index 69f695fef..fa5bf1825 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -107,8 +107,25 @@ Usage: --rnn-lm-avg 1 \ --rnn-lm-num-layers 3 \ --rnn-lm-tie-weights 1 -""" +(9) modified beam search with RNNLM shallow fusion + LODR +./lstm_transducer_stateless2/decode.py \ + --epoch 35 \ + --avg 15 \ + --max-duration 600 \ + --exp-dir ./lstm_transducer_stateless2/exp \ + --decoding-method modified_beam_search_rnnlm_LODR \ + --beam 4 \ + --max-contexts 4 \ + --rnn-lm-scale 0.4 \ + --rnn-lm-exp-dir /path/to/RNNLM/exp \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 \ + --tokens-ngram 2 \ + --ngram-lm-scale -0.16 \ +""" import argparse import logging @@ -132,6 +149,7 @@ from beam_search import ( greedy_search_batch, modified_beam_search, modified_beam_search_ngram_rescoring, + modified_beam_search_rnnlm_LODR, modified_beam_search_rnnlm_shallow_fusion, ) from librispeech import LibriSpeech @@ -235,7 +253,8 @@ def get_parser(): - fast_beam_search_nbest_oracle - fast_beam_search_nbest_LG - modified_beam_search_ngram_rescoring - - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion + - modified_beam_search_rnnlm_shallow_fusion + - modified_beam_search_rnnlm_LODR If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`. """, @@ -394,7 +413,8 @@ def get_parser(): type=int, default=3, help="""Token Ngram used for rescoring. - Used only when the decoding method is modified_beam_search_ngram_rescoring""", + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", ) parser.add_argument( @@ -402,7 +422,8 @@ def get_parser(): type=int, default=500, help="""ID of the backoff symbol. 
- Used only when the decoding method is modified_beam_search_ngram_rescoring""", + Used only when the decoding method is + modified_beam_search_ngram_rescoring""", ) add_model_arguments(parser) @@ -572,6 +593,20 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_rnnlm_LODR": + hyp_tokens = modified_beam_search_rnnlm_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -760,6 +795,7 @@ def main(): "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", + "modified_beam_search_rnnlm_LODR", "modified_beam_search_ngram_rescoring", "modified_beam_search_rnnlm_shallow_fusion", ) @@ -788,6 +824,9 @@ def main(): if "rnnlm" in params.decoding_method: params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if "LODR" in params.decoding_method: + params.suffix += "-LODR" + if params.use_averaged_model: params.suffix += "-use-averaged-model" @@ -901,7 +940,7 @@ def main(): model.eval() # only load N-gram LM when needed - if "ngram" in params.decoding_method: + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: lm_filename = f"{params.tokens_ngram}gram.fst.txt" logging.info(f"lm filename: {lm_filename}") ngram_lm = NgramLm( @@ -910,6 +949,7 @@ def main(): is_binary=False, ) logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale else: ngram_lm = None ngram_lm_scale = None @@ -933,7 +973,6 @@ def main(): ) rnn_lm_model.to(device) rnn_lm_model.eval() - else: rnn_lm_model = None rnn_lm_scale = 0.0 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 5e9428b60..59c8ed5b5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -2083,3 +2083,267 @@ def modified_beam_search_rnnlm_shallow_fusion( tokens=ans, timestamps=ans_timestamps, ) + + +def modified_beam_search_rnnlm_LODR( + model: Transducer, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + sp: spm.SentencePieceProcessor, + LODR_lm: NgramLm, + LODR_lm_scale: float, + rnnlm: RnnLmModel, + rnnlm_scale: float, + beam: int = 4, +) -> List[List[int]]: + """This function implements LODR (https://arxiv.org/abs/2203.16776) with + `modified_beam_search`. It uses a bi-gram language model as the estimate + of the internal language model and subtracts its score during shallow fusion + with an external language model. This implementation uses a RNNLM as the + external language model. + + Args: + model (Transducer): + The transducer model + encoder_out (torch.Tensor): + Encoder output in (N,T,C) + encoder_out_lens (torch.Tensor): + A 1-D tensor of shape (N,), containing the number of + valid frames in encoder_out before padding. + sp: + Sentence piece generator. + LODR_lm: + A low order n-gram LM + LODR_lm_scale: + The scale of the LODR_lm + rnnlm (RnnLmModel): + RNNLM, the external language model + rnnlm_scale (float): + scale of RNNLM in shallow fusion + beam (int, optional): + Beam size. Defaults to 4. + + Returns: + Return a list-of-list of token IDs. ans[i] is the decoding results + for the i-th utterance. 
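+
+    Note:
+      `LODR_lm_scale` is expected to be negative (e.g., -0.16), so the
+      bi-gram (source-domain) LM score is subtracted from the fused score.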
+ + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + assert rnnlm is not None + lm_scale = rnnlm_scale + vocab_size = rnnlm.vocab_size + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + sos_id = sp.piece_to_id("") + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + # get initial lm score and lm state by scoring the "sos" token + sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) + init_score, init_states = rnnlm.score_token(sos_token) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + state=init_states, # state of the RNNLM + lm_score=init_score.reshape(-1), + state_cost=NgramLmStateCost( + LODR_lm + ), # state of the source domain ngram + ) + ) + + rnnlm.clean_cache() + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] # get batch + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + """ + for all hyps with a non-blank new token, score this token. + It is a little confusing here because this for-loop + looks very similar to the one below. Here, we go through all + top-k tokens and only add the non-blanks ones to the token_list. + The RNNLM will score those tokens given the LM states. Note that + the variable `scores` is the LM score after seeing the new + non-blank token. 
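+
+        Concretely, for a non-blank token t the hypothesis score is later
+        updated as
+            log_prob += rnnlm_scale * rnnlm_logp(t) + LODR_lm_scale * delta,
+        where delta is the increment of the bi-gram state cost caused by t.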
+ """ + token_list = [] + hs = [] + cs = [] + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + assert new_token != 0, new_token + token_list.append([new_token]) + # store the LSTM states + hs.append(hyp.state[0]) + cs.append(hyp.state[1]) + + # forward RNNLM to get new states and scores + if len(token_list) != 0: + tokens_to_score = ( + torch.tensor(token_list) + .to(torch.int64) + .to(device) + .reshape(-1, 1) + ) + + hs = torch.cat(hs, dim=1).to(device) + cs = torch.cat(cs, dim=1).to(device) + scores, lm_states = rnnlm.score_token(tokens_to_score, (hs, cs)) + + count = 0 # index, used to locate score and lm states + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + ys = hyp.ys[:] + + # current score of hyp + lm_score = hyp.lm_score + state = hyp.state + + hyp_log_prob = topk_log_probs[k] # get score of current hyp + new_token = topk_token_indexes[k] + if new_token not in (blank_id, unk_id): + + ys.append(new_token) + state_cost = hyp.state_cost.forward_one_step(new_token) + + # calculate the score of the latest token + current_ngram_score = ( + state_cost.lm_score - hyp.state_cost.lm_score + ) + + assert current_ngram_score <= 0.0, ( + state_cost.lm_score, + hyp.state_cost.lm_score, + ) + # score = score + RNNLM_score - LODR_score + # LODR_LM_scale is a negative number here + hyp_log_prob += ( + lm_score[new_token] * lm_scale + + LODR_lm_scale * current_ngram_score + ) # add the lm score + + lm_score = scores[count] + state = ( + lm_states[0][:, count, :].unsqueeze(1), + lm_states[1][:, count, :].unsqueeze(1), + ) + count += 1 + else: + state_cost = hyp.state_cost + + new_hyp = Hypothesis( + ys=ys, + log_prob=hyp_log_prob, + state=state, + lm_score=lm_score, + state_cost=state_cost, + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 03137501f..e00aab34a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 # -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang +# Xiaoyu Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -90,8 +91,40 @@ Usage: --beam 20.0 \ --max-contexts 8 \ --max-states 64 -""" +(8) modified beam search (with RNNLM shallow fusion) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir 
./pruned_transducer_stateless3/exp \
+    --max-duration 600 \
+    --decoding-method modified_beam_search_rnnlm_shallow_fusion \
+    --beam 4 \
+    --rnn-lm-scale 0.3 \
+    --rnn-lm-exp-dir /path/to/RNNLM \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1
+
+(9) modified beam search with RNNLM shallow fusion + LODR
+./pruned_transducer_stateless3/decode.py \
+    --epoch 28 \
+    --avg 15 \
+    --max-duration 600 \
+    --exp-dir ./pruned_transducer_stateless3/exp \
+    --decoding-method modified_beam_search_rnnlm_LODR \
+    --beam 4 \
+    --max-contexts 4 \
+    --rnn-lm-scale 0.4 \
+    --rnn-lm-exp-dir /path/to/RNNLM/exp \
+    --rnn-lm-epoch 99 \
+    --rnn-lm-avg 1 \
+    --rnn-lm-num-layers 3 \
+    --rnn-lm-tie-weights 1 \
+    --tokens-ngram 2 \
+    --ngram-lm-scale -0.16
+"""


 import argparse
 import logging
@@ -116,10 +149,14 @@ from beam_search import (
     greedy_search,
     greedy_search_batch,
     modified_beam_search,
+    modified_beam_search_ngram_rescoring,
+    modified_beam_search_rnnlm_LODR,
+    modified_beam_search_rnnlm_shallow_fusion,
 )
 from librispeech import LibriSpeech
 from train import add_model_arguments, get_params, get_transducer_model

+from icefall import NgramLm
 from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint
 from icefall.lexicon import Lexicon
 from icefall.rnn_lm.model import RnnLmModel
@@ -202,6 +239,9 @@ def get_parser():
           - fast_beam_search_nbest
           - fast_beam_search_nbest_oracle
           - fast_beam_search_nbest_LG
+          - modified_beam_search_ngram_rescoring
+          - modified_beam_search_rnnlm_shallow_fusion
+          - modified_beam_search_rnnlm_LODR
         If you use fast_beam_search_nbest_LG, you have to specify `--lang-dir`, which should contain `LG.pt`.
         """,
     )
@@ -263,6 +303,7 @@ def get_parser():
         default=2,
         help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )
+
     parser.add_argument(
         "--max-sym-per-frame",
         type=int,
@@ -341,6 +382,15 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--rnn-lm-scale",
+        type=float,
+        default=0.0,
+        help="""Used only when the decoding method is modified_beam_search_rnnlm_shallow_fusion
+        or modified_beam_search_rnnlm_LODR. It specifies the scale of the RNN LM.
+        """,
+    )
+
     parser.add_argument(
         "--rnn-lm-exp-dir",
         type=str,
@@ -397,6 +447,24 @@ def get_parser():
         """,
     )

+    parser.add_argument(
+        "--tokens-ngram",
+        type=int,
+        default=3,
+        help="""Token Ngram used for rescoring.
+        Used only when the decoding method is
+        modified_beam_search_ngram_rescoring""",
+    )
+
+    parser.add_argument(
+        "--backoff-id",
+        type=int,
+        default=500,
+        help="""ID of the backoff symbol.
+        Used only when the decoding method is
+        modified_beam_search_ngram_rescoring""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -410,7 +478,10 @@ def decode_one_batch(
     word_table: Optional[k2.SymbolTable] = None,
     decoding_graph: Optional[k2.Fsa] = None,
     G: Optional[k2.Fsa] = None,
-    rnn_lm_model: torch.nn.Module = None,
+    ngram_lm: Optional[NgramLm] = None,
+    ngram_lm_scale: float = 1.0,
+    rnn_lm_model: Optional[RnnLmModel] = None,
+    rnnlm_scale: float = 1.0,
 ) -> Dict[str, List[List[str]]]:
     """Decode one batch and return the result in a dict. The dict has the
     following format:
@@ -444,6 +515,14 @@ def decode_one_batch(
        fast_beam_search_nbest, fast_beam_search_nbest_oracle, or
        fast_beam_search_with_nbest_rescoring.
        It an FsaVec containing an acceptor.
+      rnn_lm_model:
+        An RNN LM which can be used for rescoring or shallow fusion.
+      rnnlm_scale:
+        The scale of the RNN LM.
+      ngram_lm:
+        An n-gram LM. Used in LODR decoding.
+      ngram_lm_scale:
+        The scale of the n-gram language model.
Returns: Return the decoding result. See above description for the format of the returned dict. @@ -607,6 +686,43 @@ def decode_one_batch( nbest_scale=params.nbest_scale, temperature=params.temperature, ) + elif params.decoding_method == "modified_beam_search_ngram_rescoring": + hyp_tokens = modified_beam_search_ngram_rescoring( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": + hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + rnnlm=rnn_lm_model, + rnnlm_scale=rnnlm_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_rnnlm_LODR": + hyp_tokens = modified_beam_search_rnnlm_LODR( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + LODR_lm=ngram_lm, + LODR_lm_scale=ngram_lm_scale, + rnnlm=rnn_lm_model, + rnnlm_scale=rnnlm_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) else: batch_size = encoder_out.size(0) @@ -693,7 +809,10 @@ def decode_dataset( word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, G: Optional[k2.Fsa] = None, - rnn_lm_model: torch.nn.Module = None, + ngram_lm: Optional[NgramLm] = None, + ngram_lm_scale: float = 1.0, + rnn_lm_model: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -749,7 +868,10 @@ def decode_dataset( decoding_graph=decoding_graph, batch=batch, G=G, + ngram_lm=ngram_lm, + ngram_lm_scale=ngram_lm_scale, rnn_lm_model=rnn_lm_model, + rnnlm_scale=rnnlm_scale, ) for name, hyps in hyps_dict.items(): @@ -900,6 +1022,9 @@ def main(): "modified_beam_search", "fast_beam_search_with_nbest_rescoring", "fast_beam_search_with_nbest_rnn_rescoring", + "modified_beam_search_rnnlm_LODR", + "modified_beam_search_ngram_rescoring", + "modified_beam_search_rnnlm_shallow_fusion", ) params.res_dir = params.exp_dir / params.decoding_method @@ -930,6 +1055,13 @@ def main(): params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" params.suffix += f"-temperature-{params.temperature}" + if "rnnlm" in params.decoding_method: + params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + if "LODR" in params.decoding_method: + params.suffix += "-LODR" + if "ngram" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") logging.info("Decoding started") @@ -1048,6 +1180,44 @@ def main(): word_table = None rnn_lm_model = None + # only load N-gram LM when needed + if "ngram" in params.decoding_method or "LODR" in params.decoding_method: + lm_filename = f"{params.tokens_ngram}gram.fst.txt" + logging.info(f"lm filename: {lm_filename}") + ngram_lm = NgramLm( + str(params.lang_dir / lm_filename), + backoff_id=params.backoff_id, + is_binary=False, + ) + logging.info(f"num states: {ngram_lm.lm.num_states}") + ngram_lm_scale = params.ngram_lm_scale + else: + ngram_lm = None + ngram_lm_scale = None + + # only load rnnlm if used + if "rnnlm" in params.decoding_method: + rnn_lm_scale = params.rnn_lm_scale + + rnn_lm_model = RnnLmModel( + 
vocab_size=params.vocab_size,
+            embedding_dim=params.rnn_lm_embedding_dim,
+            hidden_dim=params.rnn_lm_hidden_dim,
+            num_layers=params.rnn_lm_num_layers,
+            tie_weights=params.rnn_lm_tie_weights,
+        )
+        assert params.rnn_lm_avg == 1
+
+        load_checkpoint(
+            f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt",
+            rnn_lm_model,
+        )
+        rnn_lm_model.to(device)
+        rnn_lm_model.eval()
+    else:
+        rnn_lm_model = None
+        rnn_lm_scale = 0.0
+
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

@@ -1074,7 +1244,10 @@ def main():
             word_table=word_table,
             decoding_graph=decoding_graph,
             G=G,
+            ngram_lm=ngram_lm,
+            ngram_lm_scale=ngram_lm_scale,
             rnn_lm_model=rnn_lm_model,
+            rnnlm_scale=rnn_lm_scale,
         )

         save_results(

From 556c63fbb741bcbc1669ec6848e06b08480d001f Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 1 Dec 2022 08:58:18 +0800
Subject: [PATCH 11/27] Describe how to fix segfault in doc (#719)

---
 docs/source/installation/index.rst | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/docs/source/installation/index.rst b/docs/source/installation/index.rst
index c4474c3d9..5b9fb2664 100644
--- a/docs/source/installation/index.rst
+++ b/docs/source/installation/index.rst
@@ -393,6 +393,17 @@ Now let us run the training part:
   We use ``export CUDA_VISIBLE_DEVICES=""`` so that ``icefall`` uses CPU
   even if there are GPUs available.

+.. hint::
+
+   In case you get a ``Segmentation fault (core dumped)`` error, please use:
+
+      .. code-block:: bash
+
+        export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+
+   See more at `<https://github.com/k2-fsa/icefall/issues/674>`_ if you are
+   interested.
+
 The training log is given below:

 .. code-block::

From 2bca7032afb0d5b9eb60f7bcf3bc15ad1e8d8a83 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Thu, 1 Dec 2022 15:57:43 +0800
Subject: [PATCH 12/27] Update RNNLM training scripts (#720)

* Update RNNLM training scripts

* Fix a typo

* Fix CI
---
 .github/workflows/run-ptb-rnn-lm.yml         | 67 ++++++++++++++++++++
 egs/librispeech/ASR/local/train_bpe_model.py |  4 ++
 egs/ptb/LM/prepare.sh                        | 38 ++++++-----
 egs/ptb/LM/rnn_lm                            |  1 +
 egs/ptb/LM/train-rnn-lm.sh                   | 67 ++++++++++++++++++++
 icefall/rnn_lm/compute_perplexity.py         |  2 +-
 icefall/rnn_lm/dataset.py                    |  4 +-
 icefall/rnn_lm/train.py                      | 10 +--
 8 files changed, 170 insertions(+), 23 deletions(-)
 create mode 100644 .github/workflows/run-ptb-rnn-lm.yml
 create mode 120000 egs/ptb/LM/rnn_lm
 create mode 100755 egs/ptb/LM/train-rnn-lm.sh

diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml
new file mode 100644
index 000000000..8ebc2e79b
--- /dev/null
+++ b/.github/workflows/run-ptb-rnn-lm.yml
@@ -0,0 +1,67 @@
+name: run-ptb-rnn-lm-training
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
+jobs:
+  run_ptb_rnn_lm_training:
+    if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: ["3.8"]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: '**/requirements-ci.txt'
+
+      - name: Install Python
dependencies + run: | + grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Prepare data + shell: bash + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + cd egs/ptb/LM + ./prepare.sh + + - name: Run training + shell: bash + run: | + export PYTHONPATH=$PWD:$PYTHONPATH + cd egs/ptb/LM + ./train-rnn-lm.sh --world-size 1 --num-epochs 5 --use-epoch 4 --use-avg 2 + + - name: Upload pretrained models + uses: actions/upload-artifact@v2 + if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule' + with: + name: python-${{ matrix.python-version }}-ubuntu-rnn-lm-ptb + path: egs/ptb/LM/my-rnnlm-exp/ diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index 42aba9572..7f6f47e16 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -89,6 +89,10 @@ def main(): bos_id=-1, eos_id=-1, ) + else: + print(f"{model_file} exists - skipping") + return + shutil.copyfile(model_file, f"{lang_dir}/bpe.model") diff --git a/egs/ptb/LM/prepare.sh b/egs/ptb/LM/prepare.sh index 91c3c667a..69fab999a 100755 --- a/egs/ptb/LM/prepare.sh +++ b/egs/ptb/LM/prepare.sh @@ -22,9 +22,9 @@ dl_dir=$PWD/download # if the array contains xxx, yyy vocab_sizes=( 500 - 1000 - 2000 - 5000 + # 1000 + # 2000 + # 5000 ) # All files generated by this script are saved in "data". @@ -42,11 +42,14 @@ log "dl_dir: $dl_dir" if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then log "Stage -1: Download data" + + # Caution: The downloaded data has already been normalized for LM training. + if [ ! -f $dl_dir/.complete ]; then - url=https://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data/ - wget --no-verbose --directory-prefix $dl_dir $url/ptb.train.txt - wget --no-verbose --directory-prefix $dl_dir $url/ptb.valid.txt - wget --no-verbose --directory-prefix $dl_dir $url/ptb.test.txt + url=http://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data + wget --directory-prefix $dl_dir $url/ptb.train.txt + wget --directory-prefix $dl_dir $url/ptb.valid.txt + wget --directory-prefix $dl_dir $url/ptb.test.txt touch $dl_dir/.complete fi fi @@ -54,11 +57,15 @@ fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Train BPE model" + # Caution: You have to use the same bpe model for training your acoustic model + # Caution: You have to use the same bpe model for training your acoustic model + # Caution: You have to use the same bpe model for training your acoustic model + for vocab_size in ${vocab_sizes[@]}; do - out_dir=data/bpe_${vocab_size} - mkdir -p $out_dir + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir ./local/train_bpe_model.py \ - --out-dir $out_dir \ + --lang-dir $lang_dir \ --vocab-size $vocab_size \ --transcript $dl_dir/ptb.train.txt done @@ -69,20 +76,21 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then # Note: ptb.train.txt has already been normalized for vocab_size in ${vocab_sizes[@]}; do - out_dir=data/bpe_${vocab_size} + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} mkdir -p $out_dir ./local/prepare_lm_training_data.py \ - --bpe-model $out_dir/bpe.model \ + --bpe-model $lang_dir/bpe.model \ --lm-data $dl_dir/ptb.train.txt \ --lm-archive $out_dir/lm_data.pt ./local/prepare_lm_training_data.py \ - --bpe-model 
$out_dir/bpe.model \ + --bpe-model $lang_dir/bpe.model \ --lm-data $dl_dir/ptb.valid.txt \ --lm-archive $out_dir/lm_data-valid.pt ./local/prepare_lm_training_data.py \ - --bpe-model $out_dir/bpe.model \ + --bpe-model $lang_dir/bpe.model \ --lm-data $dl_dir/ptb.test.txt \ --lm-archive $out_dir/lm_data-test.pt done @@ -98,7 +106,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then # in a sentence. for vocab_size in ${vocab_sizes[@]}; do - out_dir=data/bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} mkdir -p $out_dir ./local/sort_lm_training_data.py \ --in-lm-data $out_dir/lm_data.pt \ diff --git a/egs/ptb/LM/rnn_lm b/egs/ptb/LM/rnn_lm new file mode 120000 index 000000000..87f29771e --- /dev/null +++ b/egs/ptb/LM/rnn_lm @@ -0,0 +1 @@ +../../../icefall/rnn_lm \ No newline at end of file diff --git a/egs/ptb/LM/train-rnn-lm.sh b/egs/ptb/LM/train-rnn-lm.sh new file mode 100755 index 000000000..29c609ee1 --- /dev/null +++ b/egs/ptb/LM/train-rnn-lm.sh @@ -0,0 +1,67 @@ +#!/usr/bin/env bash + +# Please run ./prepare.sh first + +stage=-1 +stop_stage=100 + +# Number of GPUs to use for training +world_size=1 + +# Number of epochs to train +num_epochs=20 + +# Use this epoch for computing ppl +use_epoch=19 + +# number of models to average for computing ppl +use_avg=2 + +exp_dir=./my-rnnlm-exp + +. shared/parse_options.sh || exit 1 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Training RNN LM" + + ./rnn_lm/train.py \ + --exp-dir $exp_dir \ + --start-epoch 0 \ + --num-epochs $num_epochs \ + --world-size $world_size \ + --use-fp16 0 \ + --vocab-size 500 \ + \ + --lm-data ./data/lm_training_bpe_500/sorted_lm_data.pt \ + --lm-data-valid ./data/lm_training_bpe_500/sorted_lm_data-valid.pt \ + \ + --embedding-dim 800 \ + --hidden-dim 200 \ + --num-layers 2 \ + --tie-weights false \ + --batch-size 50 +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Computing perplexity" + + ./rnn_lm/compute_perplexity.py \ + --exp-dir $exp_dir \ + --epoch $use_epoch \ + --avg $use_avg \ + --vocab-size 500 \ + \ + --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt \ + \ + --embedding-dim 800 \ + --hidden-dim 200 \ + --num-layers 2 \ + --tie-weights false \ + --batch-size 50 +fi diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py index 550801a8f..f75a89590 100755 --- a/icefall/rnn_lm/compute_perplexity.py +++ b/icefall/rnn_lm/compute_perplexity.py @@ -20,7 +20,7 @@ Usage: ./rnn_lm/compute_perplexity.py \ --epoch 4 \ --avg 2 \ - --lm-data ./data/bpe_500/sorted_lm_data-test.pt + --lm-data ./data/lm_training_bpe_500/sorted_lm_data-test.pt """ diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py index 4bf982503..53be53f64 100644 --- a/icefall/rnn_lm/dataset.py +++ b/icefall/rnn_lm/dataset.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey, Fangjun Kuang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -194,7 +194,7 @@ def get_dataloader( batch_size=params.batch_size, ) if is_distributed: - sampler = DistributedSampler(dataset, shuffle=True, drop_last=False) + sampler = DistributedSampler(dataset, shuffle=True, drop_last=True) else: sampler = None diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index 3ba5bfbee..803da99d6 
100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -24,7 +24,7 @@ Usage: --use-fp16 0 \ --embedding-dim 800 \ --hidden-dim 200 \ - --num-layers 2\ + --num-layers 2 \ --batch-size 400 """ @@ -83,7 +83,7 @@ def get_parser(): parser.add_argument( "--num-epochs", type=int, - default=10, + default=30, help="Number of epochs to train.", ) @@ -110,14 +110,14 @@ def get_parser(): parser.add_argument( "--use-fp16", type=str2bool, - default=False, + default=True, help="Whether to use half precision training.", ) parser.add_argument( "--batch-size", type=int, - default=50, + default=400, ) parser.add_argument( @@ -165,7 +165,7 @@ def get_parser(): parser.add_argument( "--tie-weights", type=str2bool, - default=False, + default=True, help="""True to share the weights between the input embedding layer and the last output linear layer """, From 04c9fc9c9f9e481cbfae18bb34252b878ff51f6a Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 2 Dec 2022 09:18:28 +0800 Subject: [PATCH 13/27] Fix for older versions of k2 (#725) --- icefall/graph_compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py index 84be81254..0dcd777ad 100644 --- a/icefall/graph_compiler.py +++ b/icefall/graph_compiler.py @@ -81,7 +81,7 @@ class CtcTrainingGraphCompiler(object): self.ctc_topo._is_repeat_token_ = ( self.ctc_topo.labels != self.ctc_topo.aux_labels - ) + ).int() decoding_graph = k2.compose( self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False From 6533f359c998cee6fcb618f7b221cbfee05512e8 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 2 Dec 2022 10:53:06 +0800 Subject: [PATCH 14/27] Fix CI (#726) * Fix CI * Disable shuffle for yesno. See https://github.com/k2-fsa/icefall/issues/197 --- .github/workflows/build-doc.yml | 4 ++ .github/workflows/run-aishell-2022-06-20.yml | 4 ++ .../workflows/run-gigaspeech-2022-05-13.yml | 4 ++ .../workflows/run-librispeech-2022-03-12.yml | 4 ++ .../workflows/run-librispeech-2022-04-29.yml | 4 ++ .../workflows/run-librispeech-2022-05-13.yml | 4 ++ .../run-librispeech-2022-11-11-stateless7.yml | 4 ++ .../run-librispeech-2022-11-14-stateless8.yml | 4 ++ ...-librispeech-conformer-ctc3-2022-11-28.yml | 4 ++ ...-lstm-transducer-stateless2-2022-09-03.yml | 4 ++ ...runed-transducer-stateless3-2022-05-13.yml | 4 ++ ...aming-transducer-stateless2-2022-06-26.yml | 4 ++ ...peech-transducer-stateless2-2022-04-19.yml | 4 ++ .../run-pretrained-conformer-ctc.yml | 4 ++ ...-transducer-stateless-librispeech-100h.yml | 4 ++ ...r-stateless-librispeech-multi-datasets.yml | 4 ++ ...ransducer-stateless-modified-2-aishell.yml | 4 ++ ...-transducer-stateless-modified-aishell.yml | 4 ++ .../run-pretrained-transducer-stateless.yml | 4 ++ .../workflows/run-pretrained-transducer.yml | 4 ++ .github/workflows/run-ptb-rnn-lm.yml | 4 ++ ...netspeech-pruned-transducer-stateless2.yml | 6 +- .github/workflows/run-yesno-recipe.yml | 10 +++- .github/workflows/style_check.yml | 4 ++ .github/workflows/test.yml | 60 ++++++++----------- egs/librispeech/ASR/local/train_bpe_model.py | 1 - .../beam_search.py | 13 +--- .../test_scaling.py | 8 --- egs/yesno/ASR/tdnn/asr_datamodule.py | 2 +- 29 files changed, 128 insertions(+), 60 deletions(-) diff --git a/.github/workflows/build-doc.yml b/.github/workflows/build-doc.yml index dd0969f51..d7fe2c964 100644 --- a/.github/workflows/build-doc.yml +++ b/.github/workflows/build-doc.yml @@ -26,6 +26,10 @@ on: pull_request: types: [labeled] +concurrency: + group: build_doc-${{ 
github.ref }} + cancel-in-progress: true + jobs: build-doc: if: github.event.label.name == 'doc' || github.event_name == 'push' diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml index e46b01a08..1865a0da8 100644 --- a/.github/workflows/run-aishell-2022-06-20.yml +++ b/.github/workflows/run-aishell-2022-06-20.yml @@ -34,6 +34,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_aishell_2022_06_20-${{ github.ref }} + cancel-in-progress: true + jobs: run_aishell_2022_06_20: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-gigaspeech-2022-05-13.yml b/.github/workflows/run-gigaspeech-2022-05-13.yml index c631927fa..e438c5dba 100644 --- a/.github/workflows/run-gigaspeech-2022-05-13.yml +++ b/.github/workflows/run-gigaspeech-2022-05-13.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_gigaspeech_2022_05_13-${{ github.ref }} + cancel-in-progress: true + jobs: run_gigaspeech_2022_05_13: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-librispeech-2022-03-12.yml b/.github/workflows/run-librispeech-2022-03-12.yml index 5df710006..3ba6850cd 100644 --- a/.github/workflows/run-librispeech-2022-03-12.yml +++ b/.github/workflows/run-librispeech-2022-03-12.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_2022_03_12-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_2022_03_12: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-librispeech-2022-04-29.yml b/.github/workflows/run-librispeech-2022-04-29.yml index 24c062442..595b410b8 100644 --- a/.github/workflows/run-librispeech-2022-04-29.yml +++ b/.github/workflows/run-librispeech-2022-04-29.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_2022_04_29-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_2022_04_29: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-librispeech-2022-05-13.yml b/.github/workflows/run-librispeech-2022-05-13.yml index 29215ec25..eb0b06a2d 100644 --- a/.github/workflows/run-librispeech-2022-05-13.yml +++ b/.github/workflows/run-librispeech-2022-05-13.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_2022_05_13-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_2022_05_13: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml index 3b98b500e..365e2761a 100644 --- a/.github/workflows/run-librispeech-2022-11-11-stateless7.yml +++ b/.github/workflows/run-librispeech-2022-11-11-stateless7.yml @@ -33,6 +33,10 @@ on: # 
nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_2022_11_11_zipformer-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_2022_11_11_zipformer: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml index eaab35189..acb11a8f4 100644 --- a/.github/workflows/run-librispeech-2022-11-14-stateless8.yml +++ b/.github/workflows/run-librispeech-2022-11-14-stateless8.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_2022_11_14_zipformer_stateless8-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_2022_11_14_zipformer_stateless8: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml index 21f396c32..d763fb1c5 100644 --- a/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml +++ b/.github/workflows/run-librispeech-conformer-ctc3-2022-11-28.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_2022_11_28_conformer_ctc3-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_2022_11_28_conformer_ctc3: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index 5f0acf9b8..59f116fde 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -16,6 +16,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_lstm_transducer_stateless2_2022_09_03-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_lstm_transducer_stateless2_2022_09_03: if: github.event.label.name == 'ready' || github.event.label.name == 'LODR' || github.event.label.name == 'shallow-fusion' || github.event.label.name == 'ncnn' || github.event.label.name == 'onnx' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml index 66a2c240b..2c2bcab0c 100644 --- a/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml +++ b/.github/workflows/run-librispeech-pruned-transducer-stateless3-2022-05-13.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_pruned_transducer_stateless3_2022_05_13-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_pruned_transducer_stateless3_2022_05_13: if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git 
a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml index 55428861c..ac7e58b20 100644 --- a/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml +++ b/.github/workflows/run-librispeech-streaming-transducer-stateless2-2022-06-26.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_streaming_2022_06_26-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_streaming_2022_06_26: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml index f520405e1..575727e22 100644 --- a/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml +++ b/.github/workflows/run-librispeech-transducer-stateless2-2022-04-19.yml @@ -33,6 +33,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_librispeech_2022_04_19-${{ github.ref }} + cancel-in-progress: true + jobs: run_librispeech_2022_04_19: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-pretrained-conformer-ctc.yml b/.github/workflows/run-pretrained-conformer-ctc.yml index 9bc6a481f..7dbfd2bd9 100644 --- a/.github/workflows/run-pretrained-conformer-ctc.yml +++ b/.github/workflows/run-pretrained-conformer-ctc.yml @@ -23,6 +23,10 @@ on: pull_request: types: [labeled] +concurrency: + group: run_pre_trained_conformer_ctc-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_conformer_ctc: if: github.event.label.name == 'ready' || github.event_name == 'push' diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml index 7a0f30b0f..d6b3de8d4 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-100h.yml @@ -32,6 +32,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_transducer_stateless_multi_datasets_librispeech_100h: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml index 797f3fe50..749fb3fca 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-librispeech-multi-datasets.yml @@ -32,6 +32,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_transducer_stateless_multi_datasets_librispeech_960h: if: github.event.label.name == 
'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml index 29e665881..92bf6feb8 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-2-aishell.yml @@ -23,6 +23,10 @@ on: pull_request: types: [labeled] +concurrency: + group: run_pre_trained_transducer_stateless_modified_2_aishell-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_transducer_stateless_modified_2_aishell: if: github.event.label.name == 'ready' || github.event_name == 'push' diff --git a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml index 6193f28e7..e51da8bd8 100644 --- a/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml +++ b/.github/workflows/run-pretrained-transducer-stateless-modified-aishell.yml @@ -23,6 +23,10 @@ on: pull_request: types: [labeled] +concurrency: + group: run_pre_trained_transducer_stateless_modified_aishell-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_transducer_stateless_modified_aishell: if: github.event.label.name == 'ready' || github.event_name == 'push' diff --git a/.github/workflows/run-pretrained-transducer-stateless.yml b/.github/workflows/run-pretrained-transducer-stateless.yml index 32208076c..2103d0510 100644 --- a/.github/workflows/run-pretrained-transducer-stateless.yml +++ b/.github/workflows/run-pretrained-transducer-stateless.yml @@ -32,6 +32,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_pre_trained_transducer_stateless-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_transducer_stateless: if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-pretrained-transducer.yml b/.github/workflows/run-pretrained-transducer.yml index 965d0f655..902319b55 100644 --- a/.github/workflows/run-pretrained-transducer.yml +++ b/.github/workflows/run-pretrained-transducer.yml @@ -23,6 +23,10 @@ on: pull_request: types: [labeled] +concurrency: + group: run_pre_trained_transducer-${{ github.ref }} + cancel-in-progress: true + jobs: run_pre_trained_transducer: if: github.event.label.name == 'ready' || github.event_name == 'push' diff --git a/.github/workflows/run-ptb-rnn-lm.yml b/.github/workflows/run-ptb-rnn-lm.yml index 8ebc2e79b..47ed958f2 100644 --- a/.github/workflows/run-ptb-rnn-lm.yml +++ b/.github/workflows/run-ptb-rnn-lm.yml @@ -16,6 +16,10 @@ on: # nightly build at 15:50 UTC time every day - cron: "50 15 * * *" +concurrency: + group: run_ptb_rnn_lm_training-${{ github.ref }} + cancel-in-progress: true + jobs: run_ptb_rnn_lm_training: if: github.event.label.name == 'ready' || github.event.label.name == 'rnnlm' || github.event_name == 'push' || github.event_name == 'schedule' diff --git a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml index d96a3bfe6..8a7be0b80 100644 --- a/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml +++ b/.github/workflows/run-wenetspeech-pruned-transducer-stateless2.yml 
@@ -23,8 +23,12 @@ on: pull_request: types: [labeled] +concurrency: + group: run_wenetspeech_pruned_transducer_stateless2-${{ github.ref }} + cancel-in-progress: true + jobs: - run_librispeech_pruned_transducer_stateless3_2022_05_13: + run_wenetspeech_pruned_transducer_stateless2: if: github.event.label.name == 'onnx' || github.event.label.name == 'ready' || github.event_name == 'push' || github.event.label.name == 'wenetspeech' runs-on: ${{ matrix.os }} strategy: diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index ce77c47df..ed343aee5 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -21,11 +21,15 @@ on: branches: - master pull_request: - types: [labeled] + branches: + - master + +concurrency: + group: run-yesno-recipe-${{ github.ref }} + cancel-in-progress: true jobs: run-yesno-recipe: - if: github.event.label.name == 'ready' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: matrix: @@ -61,7 +65,7 @@ jobs: - name: Install Python dependencies run: | - grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install pip uninstall -y protobuf pip install --no-binary protobuf protobuf diff --git a/.github/workflows/style_check.yml b/.github/workflows/style_check.yml index 45d261ccc..fc1dcbfd4 100644 --- a/.github/workflows/style_check.yml +++ b/.github/workflows/style_check.yml @@ -24,6 +24,10 @@ on: branches: - master +concurrency: + group: style_check-${{ github.ref }} + cancel-in-progress: true + jobs: style_check: runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 04fc0265f..4dbe99827 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -21,26 +21,23 @@ on: branches: - master pull_request: - types: [labeled] + branches: + - master + +concurrency: + group: test-${{ github.ref }} + cancel-in-progress: true jobs: test: - if: github.event.label.name == 'ready' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: matrix: - # os: [ubuntu-18.04, macos-10.15] - # disable macOS test for now. 
- os: [ubuntu-18.04] - python-version: [3.7, 3.8] - torch: ["1.8.0", "1.11.0"] - torchaudio: ["0.8.0", "0.11.0"] - k2-version: ["1.15.1.dev20220427"] - exclude: - - torch: "1.8.0" - torchaudio: "0.11.0" - - torch: "1.11.0" - torchaudio: "0.8.0" + os: [ubuntu-latest] + python-version: ["3.8"] + torch: ["1.10.0"] + torchaudio: ["0.10.0"] + k2-version: ["1.23.2.dev20221201"] fail-fast: false @@ -67,11 +64,7 @@ jobs: # numpy 1.20.x does not support python 3.6 pip install numpy==1.19 pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - if [[ ${{ matrix.torchaudio }} == "0.11.0" ]]; then - pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - else - pip install torchaudio==${{ matrix.torchaudio }} - fi + pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ pip install git+https://github.com/lhotse-speech/lhotse @@ -81,7 +74,6 @@ jobs: pip install kaldifst pip install onnxruntime - pip install -r requirements.txt - name: Install graphviz @@ -124,16 +116,14 @@ jobs: cd ../transducer_stateless pytest -v -s - if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then - cd ../transducer - pytest -v -s + cd ../transducer + pytest -v -s - cd ../transducer_stateless2 - pytest -v -s + cd ../transducer_stateless2 + pytest -v -s - cd ../transducer_lstm - pytest -v -s - fi + cd ../transducer_lstm + pytest -v -s - name: Run tests if: startsWith(matrix.os, 'macos') @@ -164,13 +154,11 @@ jobs: cd ../transducer_stateless pytest -v -s - if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then - cd ../transducer - pytest -v -s + cd ../transducer + pytest -v -s - cd ../transducer_stateless2 - pytest -v -s + cd ../transducer_stateless2 + pytest -v -s - cd ../transducer_lstm - pytest -v -s - fi + cd ../transducer_lstm + pytest -v -s diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index 7f6f47e16..43142aee4 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -93,7 +93,6 @@ def main(): print(f"{model_file} exists - skipping") return - shutil.copyfile(model_file, f"{lang_dir}/bpe.model") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 59c8ed5b5..b324cc9b7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -2230,9 +2230,7 @@ def modified_beam_search_rnnlm_LODR( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor( - shape=log_probs_shape, value=log_probs - ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) """ for all hyps with a non-blank new token, score this token. 
It is a little confusing here because this for-loop @@ -2267,10 +2265,7 @@ def modified_beam_search_rnnlm_LODR( # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( - torch.tensor(token_list) - .to(torch.int64) - .to(device) - .reshape(-1, 1) + torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) ) hs = torch.cat(hs, dim=1).to(device) @@ -2304,9 +2299,7 @@ def modified_beam_search_rnnlm_LODR( state_cost = hyp.state_cost.forward_one_step(new_token) # calculate the score of the latest token - current_ngram_score = ( - state_cost.lm_score - hyp.state_cost.lm_score - ) + current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score assert current_ngram_score <= 0.0, ( state_cost.lm_score, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py index e9dfe6d5e..42de2410a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_scaling.py @@ -52,17 +52,9 @@ def test_scaled_conv2d(): torch.jit.script(conv2d) -def test_activation_balancer(): - act = ActivationBalancer( - channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 - ) - torch.jit.script(act) - - def main(): test_scaled_conv1d() test_scaled_conv2d() - test_activation_balancer() if __name__ == "__main__": diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index 85e5f1358..3c1682fa1 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -121,7 +121,7 @@ class YesNoAsrDataModule(DataModule): group.add_argument( "--shuffle", type=str2bool, - default=True, + default=False, help="When enabled (=default), the examples will be " "shuffled for each epoch.", ) From 6f719816673761ceda0bfe6bece5a44b151ead46 Mon Sep 17 00:00:00 2001 From: Amir Hussein <36240131+AmirHussein96@users.noreply.github.com> Date: Thu, 1 Dec 2022 21:58:34 -0500 Subject: [PATCH 15/27] MGB2 (#396) * mgb2 * mgb2 * adding pruned transducer stateless to mgb2 * update display_manifest_statistics.py * . * stateless transducer MGB-2 * Update README.md * Update RESULTS.md * Update prepare_lang_bpe.py * Update asr_datamodule.py * .nfs removed * Adding symlink * . 
* resolving conflicts * Update .gitignore * black formatting * Update compile_hlg.py * Update compute_fbank_musan.py * Update convert_transcript_words_to_tokens.py * Update download_lm.py * Update generate_unique_lexicon.py * adding simlinks * fixing symbolic links --- .gitignore | 20 + egs/mgb2/ASR/README.md | 43 + egs/mgb2/ASR/RESULTS.md | 236 ++++ egs/mgb2/ASR/conformer_ctc/__init__.py | 0 egs/mgb2/ASR/conformer_ctc/ali.py | 395 ++++++ egs/mgb2/ASR/conformer_ctc/asr_datamodule.py | 372 ++++++ egs/mgb2/ASR/conformer_ctc/compile_hlg.py | 1 + .../ASR/conformer_ctc/compute_fbank_musan.py | 1 + egs/mgb2/ASR/conformer_ctc/conformer.py | 1 + .../convert_transcript_words_to_tokens.py | 1 + egs/mgb2/ASR/conformer_ctc/decode.py | 695 ++++++++++ egs/mgb2/ASR/conformer_ctc/download_lm.py | 1 + egs/mgb2/ASR/conformer_ctc/export.py | 1 + .../conformer_ctc/generate_unique_lexicon.py | 1 + egs/mgb2/ASR/conformer_ctc/label_smoothing.py | 1 + egs/mgb2/ASR/conformer_ctc/pretrained.py | 430 ++++++ egs/mgb2/ASR/conformer_ctc/subsampling.py | 1 + .../ASR/conformer_ctc/test_label_smoothing.py | 1 + .../ASR/conformer_ctc/test_subsampling.py | 1 + .../ASR/conformer_ctc/test_transformer.py | 1 + egs/mgb2/ASR/conformer_ctc/train.py | 766 +++++++++++ egs/mgb2/ASR/conformer_ctc/transformer.py | 1 + egs/mgb2/ASR/local/__init__.py | 0 egs/mgb2/ASR/local/compile_hlg.py | 1 + egs/mgb2/ASR/local/compute_fbank_mgb2.py | 101 ++ egs/mgb2/ASR/local/compute_fbank_musan.py | 108 ++ .../convert_transcript_words_to_tokens.py | 103 ++ .../ASR/local/display_manifest_statistics.py | 97 ++ egs/mgb2/ASR/local/generate_unique_lexicon.py | 1 + egs/mgb2/ASR/local/prep_mgb2_lexicon.sh | 30 + egs/mgb2/ASR/local/prepare_lang.py | 1 + egs/mgb2/ASR/local/prepare_lang_bpe.py | 1 + egs/mgb2/ASR/local/prepare_mgb2_lexicon.py | 37 + egs/mgb2/ASR/local/test_prepare_lang.py | 1 + egs/mgb2/ASR/prepare.sh | 234 ++++ .../pruned_transducer_stateless5/__init__.py | 0 .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../pruned_transducer_stateless5/conformer.py | 1 + .../pruned_transducer_stateless5/decode.py | 625 +++++++++ .../pruned_transducer_stateless5/decoder.py | 1 + .../encoder_interface.py | 1 + .../pruned_transducer_stateless5/export.py | 272 ++++ .../pruned_transducer_stateless5/joiner.py | 1 + .../ASR/pruned_transducer_stateless5/model.py | 1 + .../ASR/pruned_transducer_stateless5/optim.py | 1 + .../pretrained.py | 344 +++++ .../pruned_transducer_stateless5/scaling.py | 1 + .../test_model.py | 1 + .../ASR/pruned_transducer_stateless5/train.py | 1176 +++++++++++++++++ egs/mgb2/ASR/shared | 1 + icefall/diagnostics.py | 2 +- 52 files changed, 6114 insertions(+), 1 deletion(-) create mode 100644 egs/mgb2/ASR/README.md create mode 100644 egs/mgb2/ASR/RESULTS.md create mode 100644 egs/mgb2/ASR/conformer_ctc/__init__.py create mode 100755 egs/mgb2/ASR/conformer_ctc/ali.py create mode 100644 egs/mgb2/ASR/conformer_ctc/asr_datamodule.py create mode 120000 egs/mgb2/ASR/conformer_ctc/compile_hlg.py create mode 120000 egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py create mode 120000 egs/mgb2/ASR/conformer_ctc/conformer.py create mode 120000 egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py create mode 100755 egs/mgb2/ASR/conformer_ctc/decode.py create mode 120000 egs/mgb2/ASR/conformer_ctc/download_lm.py create mode 120000 egs/mgb2/ASR/conformer_ctc/export.py create mode 120000 egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py create mode 120000 egs/mgb2/ASR/conformer_ctc/label_smoothing.py create mode 100755 
egs/mgb2/ASR/conformer_ctc/pretrained.py create mode 120000 egs/mgb2/ASR/conformer_ctc/subsampling.py create mode 120000 egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py create mode 120000 egs/mgb2/ASR/conformer_ctc/test_subsampling.py create mode 120000 egs/mgb2/ASR/conformer_ctc/test_transformer.py create mode 100755 egs/mgb2/ASR/conformer_ctc/train.py create mode 120000 egs/mgb2/ASR/conformer_ctc/transformer.py create mode 100644 egs/mgb2/ASR/local/__init__.py create mode 120000 egs/mgb2/ASR/local/compile_hlg.py create mode 100755 egs/mgb2/ASR/local/compute_fbank_mgb2.py create mode 100755 egs/mgb2/ASR/local/compute_fbank_musan.py create mode 100755 egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py create mode 100755 egs/mgb2/ASR/local/display_manifest_statistics.py create mode 120000 egs/mgb2/ASR/local/generate_unique_lexicon.py create mode 100755 egs/mgb2/ASR/local/prep_mgb2_lexicon.sh create mode 120000 egs/mgb2/ASR/local/prepare_lang.py create mode 120000 egs/mgb2/ASR/local/prepare_lang_bpe.py create mode 100755 egs/mgb2/ASR/local/prepare_mgb2_lexicon.py create mode 120000 egs/mgb2/ASR/local/test_prepare_lang.py create mode 100755 egs/mgb2/ASR/prepare.sh create mode 100644 egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py create mode 100755 egs/mgb2/ASR/pruned_transducer_stateless5/decode.py create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py create mode 100755 egs/mgb2/ASR/pruned_transducer_stateless5/export.py create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/model.py create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/optim.py create mode 100755 egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py create mode 120000 egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py create mode 100755 egs/mgb2/ASR/pruned_transducer_stateless5/train.py create mode 120000 egs/mgb2/ASR/shared diff --git a/.gitignore b/.gitignore index 406deff6a..583410f45 100644 --- a/.gitignore +++ b/.gitignore @@ -11,5 +11,25 @@ log *.bak *-bak *bak.py + +# Ignore Mac system files +.DS_store + +# Ignore node_modules folder +node_modules + +# ignore .nfs + +.nfs* + +# Ignore all text files +*.txt + +# Ignore files related to API keys +.env + +# Ignore SASS config files +.sass-cache + *.param *.bin diff --git a/egs/mgb2/ASR/README.md b/egs/mgb2/ASR/README.md new file mode 100644 index 000000000..2bc4b000b --- /dev/null +++ b/egs/mgb2/ASR/README.md @@ -0,0 +1,43 @@ +# MGB2 + +The Multi-Dialect Broadcast News Arabic Speech Recognition (MGB-2): +The second edition of the Multi-Genre Broadcast (MGB-2) Challenge is +an evaluation of speech recognition and lightly supervised alignment +using TV recordings in Arabic. The speech data is broad and multi-genre, +spanning the whole range of TV output, and represents a challenging task for +speech technology. In 2016, the challenge featured two new Arabic tracks based +on TV data from Aljazeera. It was an official challenge at the 2016 IEEE +Workshop on Spoken Language Technology. 
The 1,200 hours MGB-2: from Aljazeera +TV programs have been manually captioned with no timing information. +QCRI Arabic ASR system has been used to recognize all programs. The ASR output +was used to align the manual captioning and produce speech segments for +training speech recognition. More than 20 hours from 2015 programs have been +transcribed verbatim and manually segmented. This data is split into a +development set of 10 hours, and a similar evaluation set of 10 hours. +Both the development and evaluation data have been released in the 2016 MGB +challenge + +Official reference: + +Ali, Ahmed, et al. "The MGB-2 challenge: Arabic multi-dialect broadcast media recognition." +2016 IEEE Spoken Language Technology Workshop (SLT). IEEE, 2016. + +IEEE link: https://ieeexplore.ieee.org/abstract/document/7846277 + +## Stateless Pruned Transducer Performance Record (after 30 epochs) + +| | dev | test | comment | +|------------------------------------|------------|------------|------------------------------------------| +| greedy search | 15.52 | 15.28 | --epoch 18, --avg 5, --max-duration 200 | +| modified beam search | 13.88 | 13.7 | --epoch 18, --avg 5, --max-duration 200 | +| fast beam search | 14.62 | 14.36 | --epoch 18, --avg 5, --max-duration 200 | + +## Conformer-CTC Performance Record (after 40 epochs) + +| Decoding method | dev WER | test WER | +|---------------------------|------------|---------| +| attention-decoder | 15.62 | 15.01 | +| whole-lattice-rescoring | 15.89 | 15.08 | + + +See [RESULTS](/egs/mgb2/ASR/RESULTS.md) for details. diff --git a/egs/mgb2/ASR/RESULTS.md b/egs/mgb2/ASR/RESULTS.md new file mode 100644 index 000000000..2a7ea7664 --- /dev/null +++ b/egs/mgb2/ASR/RESULTS.md @@ -0,0 +1,236 @@ +# Results + + +### MGB2 all data BPE training results (Stateless Pruned Transducer) + +#### 2022-09-07 + +The WERs are + +| | dev | test | comment | +|------------------------------------|------------|------------|------------------------------------------| +| greedy search | 15.52 | 15.28 | --epoch 18, --avg 5, --max-duration 200 | +| modified beam search | 13.88 | 13.7 | --epoch 18, --avg 5, --max-duration 200 | +| fast beam search | 14.62 | 14.36 | --epoch 18, --avg 5, --max-duration 200| + +The training command for reproducing is given below: + +``` +export CUDA_VISIBLE_DEVICES="0,1,2,3" + + + +./pruned_transducer_stateless5/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless5/exp \ + --max-duration 300 \ + --num-buckets 50 +``` + +The tensorboard training log can be found at +https://tensorboard.dev/experiment/YyNv45pfQ0GqWzZ898WOlw/#scalars + +The decoding command is: +``` +epoch=18 +avg=5 +for method in greedy_search modified_beam_search fast_beam_search; do + ./pruned_transducer_stateless5/decode.py \ + --epoch $epoch \ + --beam-size 10 \ + --avg $avg \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 200 \ + --decoding-method $method \ + --max-sym-per-frame 1 \ + --num-encoder-layers 12 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --use-averaged-model True +done +``` + +### MGB2 all data BPE training results (Conformer-CTC) (after 40 epochs) + +#### 2022-06-04 + +You can find a pretrained model, training logs, decoding logs, and decoding results at: +https://huggingface.co/AmirHussein/icefall-asr-mgb2-conformer_ctc-2022-27-06 + +The best WER, as of 2022-06-04, for the MGB2 test dataset is below + +Using whole lattice HLG decoding + 
n-gram LM rescoring + +| | dev | test | +|-----|------------|------------| +| WER | 15.62 | 15.01 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.1 | - | + + +Using n-best (nbest-scale=0.5) attention decoder rescoring + +| | dev | test | +|-----|------------|------------| +| WER | 15.89 | 15.08 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.01 | 0.5 | + + +To reproduce the above result, use the following commands for training: + +Note: the model was trained on a V-100 32GB GPU. + +``` +cd egs/mgb2/ASR +. ./path.sh +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1" +./conformer_ctc/train.py \ + --lang-dir data/lang_bpe_5000 \ + --att-rate 0.8 \ + --lr-factor 10 \ + --concatenate-cuts 0 \ + --world-size 2 \ + --bucketing-sampler 1 \ + --max-duration 100 \ + --start-epoch 0 \ + --num-epochs 40 + +``` + +and the following command for n-best (attention-decoder) decoding + +``` +./conformer_ctc/decode.py \ + --lang-dir data/lang_bpe_5000 \ + --max-duration 30 \ + --concatenate-cuts 0 \ + --bucketing-sampler 1 \ + --num-paths 1000 \ + --epoch 40 \ + --avg 5 \ + --method attention-decoder \ + --nbest-scale 0.5 +``` + +and the following command for whole-lattice decoding + +``` +./conformer_ctc/decode.py \ + --epoch 40 \ + --avg 5 \ + --exp-dir conformer_ctc/exp_5000_att0.8 \ + --lang-dir data/lang_bpe_5000 \ + --max-duration 30 \ + --concatenate-cuts 0 \ + --bucketing-sampler 1 \ + --num-paths 1000 \ + --method whole-lattice-rescoring +``` + + +The tensorboard log for training is available at +https://tensorboard.dev/experiment/QYNzOi52RwOX8yvtpl3hMw/#scalars + + +### MGB2 100h BPE training results (Conformer-CTC) (after 33 epochs) + +#### 2022-06-04 + +The best WERs, as of 2022-06-04, for the MGB2 dev and test sets are below. + +Using whole lattice HLG decoding + n-gram LM rescoring + +| | dev | test | +|-----|------------|------------| +| WER | 25.32 | 23.53 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.1 | - | + + +Using n-best (nbest-scale=0.5) HLG decoding + n-gram LM rescoring + attention decoder rescoring: + +| | dev | test | +|-----|------------|------------| +| WER | 27.87 | 26.12 | + +Scale values used in n-gram LM rescoring and attention rescoring for the best WERs are: +| ngram_lm_scale | attention_scale | +|----------------|-----------------| +| 0.01 | 0.3 | + + +To reproduce the above result, use the following commands for training: + +Note: the model was trained on a V-100 32GB GPU. + +``` +cd egs/mgb2/ASR +. 
./path.sh +./prepare.sh +export CUDA_VISIBLE_DEVICES="0,1" +./conformer_ctc/train.py \ + --lang-dir data/lang_bpe_5000 \ + --att-rate 0.8 \ + --lr-factor 10 \ + --concatenate-cuts 0 \ + --world-size 2 \ + --bucketing-sampler 1 \ + --max-duration 100 \ + --start-epoch 0 \ + --num-epochs 40 + +``` + +and the following command for n-best (attention-decoder) decoding + +``` +./conformer_ctc/decode.py \ + --lang-dir data/lang_bpe_5000 \ + --max-duration 30 \ + --concatenate-cuts 0 \ + --bucketing-sampler 1 \ + --num-paths 1000 \ + --epoch 40 \ + --avg 5 \ + --method attention-decoder \ + --nbest-scale 0.5 +``` + +and the following command for whole-lattice decoding + +``` +./conformer_ctc/decode.py \ + --lang-dir data/lang_bpe_5000 \ + --max-duration 30 \ + --concatenate-cuts 0 \ + --bucketing-sampler 1 \ + --num-paths 1000 \ + --epoch 40 \ + --avg 5 \ + --method whole-lattice-rescoring +``` + +The tensorboard log for training is available at + + + + + diff --git a/egs/mgb2/ASR/conformer_ctc/__init__.py b/egs/mgb2/ASR/conformer_ctc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/mgb2/ASR/conformer_ctc/ali.py b/egs/mgb2/ASR/conformer_ctc/ali.py new file mode 100755 index 000000000..aea962dcd --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/ali.py @@ -0,0 +1,395 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + ./conformer_ctc/ali.py \ + --exp-dir ./conformer_ctc/exp \ + --lang-dir ./data/lang_bpe_500 \ + --epoch 20 \ + --avg 10 \ + --max-duration 300 \ + --dataset train-clean-100 \ + --out-dir data/ali +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import numpy as np +import torch +from asr_datamodule import LibriSpeechAsrDataModule +from conformer import Conformer +from lhotse import CutSet +from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.decode import one_best_decoding +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + encode_supervisions, + get_alignments, + setup_logger, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=34, + help="It specifies the checkpoint to use for decoding. " + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. 
", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--out-dir", + type=str, + required=True, + help="""Output directory. + It contains 3 generated files: + + - labels_xxx.h5 + - aux_labels_xxx.h5 + - cuts_xxx.json.gz + + where xxx is the value of `--dataset`. For instance, if + `--dataset` is `train-clean-100`, it will contain 3 files: + + - `labels_train-clean-100.h5` + - `aux_labels_train-clean-100.h5` + - `cuts_train-clean-100.json.gz` + + Note: Both labels_xxx.h5 and aux_labels_xxx.h5 contain framewise + alignment. The difference is that labels_xxx.h5 contains repeats. + """, + ) + + parser.add_argument( + "--dataset", + type=str, + required=True, + help="""The name of the dataset to compute alignments for. + Possible values are: + - test-clean. + - test-other + - train-clean-100 + - train-clean-360 + - train-other-500 + - dev-clean + - dev-other + """, + ) + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "lm_dir": Path("data/lm"), + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "subsampling_factor": 4, + # Set it to 0 since attention decoder + # is not used for computing alignments + "num_decoder_layers": 0, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "output_beam": 10, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def compute_alignments( + model: torch.nn.Module, + dl: torch.utils.data.DataLoader, + labels_writer: FeaturesWriter, + aux_labels_writer: FeaturesWriter, + params: AttributeDict, + graph_compiler: BpeCtcTrainingGraphCompiler, +) -> CutSet: + """Compute the framewise alignments of a dataset. + + Args: + model: + The neural network model. + dl: + Dataloader containing the dataset. + params: + Parameters for computing alignments. + graph_compiler: + It converts token IDs to decoding graphs. + Returns: + Return a CutSet. Each cut has two custom fields: labels_alignment + and aux_labels_alignment, containing framewise alignments information. + Both are of type `lhotse.array.TemporalArray`. The difference between + the two alignments is that `labels_alignment` contain repeats. + """ + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + num_cuts = 0 + + device = graph_compiler.device + cuts = [] + for batch_idx, batch in enumerate(dl): + feature = batch["inputs"] + + # at entry, feature is [N, T, C] + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + cut_list = supervisions["cut"] + + for cut in cut_list: + assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}" + + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is [N, T, C] + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + # we need also to sort cut_ids as encode_supervisions() + # reorders "texts". 
+ # In general, new2old is an identity map since lhotse sorts the returned + # cuts by duration in descending order + new2old = supervision_segments[:, 0].tolist() + + cut_list = [cut_list[i] for i in new2old] + + token_ids = graph_compiler.texts_to_ids(texts) + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + lattice = k2.intersect_dense( + decoding_graph, + dense_fsa_vec, + params.output_beam, + ) + + best_path = one_best_decoding( + lattice=lattice, + use_double_scores=params.use_double_scores, + ) + + labels_ali = get_alignments(best_path, kind="labels") + aux_labels_ali = get_alignments(best_path, kind="aux_labels") + assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) + for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): + cut.labels_alignment = labels_writer.store_array( + key=cut.id, + value=np.asarray(labels, dtype=np.int32), + # frame shift is 0.01s, subsampling_factor is 4 + frame_shift=0.04, + temporal_dim=0, + start=0, + ) + cut.aux_labels_alignment = aux_labels_writer.store_array( + key=cut.id, + value=np.asarray(aux_labels, dtype=np.int32), + # frame shift is 0.01s, subsampling_factor is 4 + frame_shift=0.04, + temporal_dim=0, + start=0, + ) + + cuts += cut_list + + num_cuts += len(cut_list) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + + return CutSet.from_cuts(cuts) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + + args.enable_spec_aug = False + args.enable_musan = False + args.return_cuts = True + args.concatenate_cuts = False + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-ali") + + logging.info(f"Computing alignments for {params.dataset} - started") + logging.info(params) + + out_dir = Path(params.out_dir) + out_dir.mkdir(exist_ok=True) + + out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" + out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" + out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz" + + for f in ( + out_labels_ali_filename, + out_aux_labels_ali_filename, + out_manifest_filename, + ): + if f.exists(): + logging.info(f"{f} exists - skipping") + return + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + + logging.info("About to create model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + model.to(device) + + if params.avg == 1: + load_checkpoint( + f"{params.exp_dir}/epoch-{params.epoch}.pt", model, strict=False + ) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + 
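# These filenames are the checkpoints of the last `--avg` epochs ending at `--epoch`; + # average_checkpoints() averages their model parameters element-wise before + # the averaged weights are loaded into the model. + 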
logging.info(f"averaging {filenames}") + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + + model.eval() + + librispeech = LibriSpeechAsrDataModule(args) + if params.dataset == "test-clean": + test_clean_cuts = librispeech.test_clean_cuts() + dl = librispeech.test_dataloaders(test_clean_cuts) + elif params.dataset == "test-other": + test_other_cuts = librispeech.test_other_cuts() + dl = librispeech.test_dataloaders(test_other_cuts) + elif params.dataset == "train-clean-100": + train_clean_100_cuts = librispeech.train_clean_100_cuts() + dl = librispeech.train_dataloaders(train_clean_100_cuts) + elif params.dataset == "train-clean-360": + train_clean_360_cuts = librispeech.train_clean_360_cuts() + dl = librispeech.train_dataloaders(train_clean_360_cuts) + elif params.dataset == "train-other-500": + train_other_500_cuts = librispeech.train_other_500_cuts() + dl = librispeech.train_dataloaders(train_other_500_cuts) + elif params.dataset == "dev-clean": + dev_clean_cuts = librispeech.dev_clean_cuts() + dl = librispeech.valid_dataloaders(dev_clean_cuts) + else: + assert params.dataset == "dev-other", f"{params.dataset}" + dev_other_cuts = librispeech.dev_other_cuts() + dl = librispeech.valid_dataloaders(dev_other_cuts) + + logging.info(f"Processing {params.dataset}") + with NumpyHdf5Writer(out_labels_ali_filename) as labels_writer: + with NumpyHdf5Writer(out_aux_labels_ali_filename) as aux_labels_writer: + cut_set = compute_alignments( + model=model, + dl=dl, + labels_writer=labels_writer, + aux_labels_writer=aux_labels_writer, + params=params, + graph_compiler=graph_compiler, + ) + + cut_set.to_file(out_manifest_filename) + + logging.info( + f"For dataset {params.dataset}, its alignments with repeats are " + f"saved to {out_labels_ali_filename}, the alignments without repeats " + f"are saved to {out_aux_labels_ali_filename}, and the cut manifest " + f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}" + ) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py new file mode 100644 index 000000000..8242e986d --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/asr_datamodule.py @@ -0,0 +1,372 @@ +# Copyright 2022 Johns Hopkins University (Amir Hussein) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import ( + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class MGB2AsrDataModule: + + """ + DataModule for k2 ASR experiments. 
+ It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=1, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. 
" + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz") + + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else PrecomputedFeatures(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, max_duration=self.args.max_duration, shuffle=False + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz") + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + + return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + + return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz") diff --git a/egs/mgb2/ASR/conformer_ctc/compile_hlg.py b/egs/mgb2/ASR/conformer_ctc/compile_hlg.py new file mode 120000 index 000000000..471aa7fb4 --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/compile_hlg.py @@ -0,0 +1 @@ 
+../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py b/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/conformer.py b/egs/mgb2/ASR/conformer_ctc/conformer.py new file mode 120000 index 000000000..d1f4209d7 --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/conformer.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py b/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py new file mode 120000 index 000000000..2ce13fd69 --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/convert_transcript_words_to_tokens.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/decode.py b/egs/mgb2/ASR/conformer_ctc/decode.py new file mode 100755 index 000000000..f771d7f1e --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/decode.py @@ -0,0 +1,695 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import pdb +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import MGB2AsrDataModule +from conformer import Conformer + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=50, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=5, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--method", + type=str, + default="attention-decoder", + help="""Decoding method. + Supported values are: + - (0) ctc-decoding. Use CTC decoding. 
It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (1) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (2) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (3) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (4) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + - (5) attention-decoder. Extract n paths from the LM rescored + lattice, the path with the highest score is the decoding result. + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=20, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, attention-decoder, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="The lang dir", + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + "env_info": get_env_info(), + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. An example key is + `lm_scale_0.7` + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + + - params.method is "1best", it uses 1best decoding without LM rescoring. + - params.method is "nbest", it uses nbest decoding without LM rescoring. 
+ - params.method is "nbest-rescoring", it uses nbest LM rescoring. + - params.method is "whole-lattice-rescoring", it uses whole lattice LM + rescoring. + + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. + Returns: + Return the decoding result. See above description for the format of + the returned dict. Note: If it decodes to nothing, then return None. + """ + if HLG is not None: + device = HLG.device + else: + device = H.device + feature = batch["inputs"] + assert feature.ndim == 3 + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + + nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + supervision_segments = torch.stack( + ( + supervisions["sequence_idx"], + supervisions["start_frame"] // params.subsampling_factor, + supervisions["num_frames"] // params.subsampling_factor, + ), + 1, + ).to(torch.int32) + + if H is None: + assert HLG is not None + decoding_graph = HLG + else: + assert HLG is None + assert bpe_model is not None + decoding_graph = H + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=decoding_graph, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "ctc-decoding": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + # Note: `best_path.aux_labels` contains token IDs, not word IDs + # since we are using H, not HLG here. + # + # token_ids is a lit-of-list of IDs + token_ids = get_texts(best_path) + + # hyps is a list of str, e.g., ['xxx yyy zzz', ...] + hyps = bpe_model.decode(token_ids) + + # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ] + hyps = [s.split() for s in hyps] + key = "ctc-decoding" + return {key: hyps} + + if params.method == "nbest-oracle": + # Note: You can also pass rescored lattices to it. + # We choose the HLG decoded lattice for speed reasons + # as HLG decoding is faster and the oracle WER + # is only slightly worse than that of rescored lattices. 
+ best_path = nbest_oracle( + lattice=lattice, + num_paths=params.num_paths, + ref_texts=supervisions["text"], + word_table=word_table, + nbest_scale=params.nbest_scale, + oov="", + ) + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa + return {key: hyps} + + if params.method in ["1best", "nbest"]: + if params.method == "1best": + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + key = "no_rescore" + else: + best_path = nbest_decoding( + lattice=lattice, + num_paths=params.num_paths, + use_double_scores=params.use_double_scores, + nbest_scale=params.nbest_scale, + ) + key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa + + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + return {key: hyps} + + assert params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + ] + + lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] + lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3] + lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0] + + if params.method == "nbest-rescoring": + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=lm_scale_list, + nbest_scale=params.nbest_scale, + ) + elif params.method == "whole-lattice-rescoring": + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=lm_scale_list, + ) + elif params.method == "attention-decoder": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + # TODO: pass `lattice` instead of `rescored_lattice` to + # `rescore_with_attention_decoder` + + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + nbest_scale=params.nbest_scale, + ) + else: + assert False, f"Unsupported decoding method: {params.method}" + + ans = dict() + if best_path_dict is not None: + for lm_scale_str, best_path in best_path_dict.items(): + hyps = get_texts(best_path) + hyps = [[word_table[i] for i in ids] for ids in hyps] + ans[lm_scale_str] = hyps + else: + ans = None + return ans + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + word_table: k2.SymbolTable, + sos_id: int, + eos_id: int, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + HLG: + The decoding graph. Used only when params.method is NOT ctc-decoding. + H: + The ctc topo. Used only when params.method is ctc-decoding. + bpe_model: + The BPE model. Used only when params.method is ctc-decoding. + word_table: + It is the word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + G: + An LM. It is not None when params.method is "nbest-rescoring" + or "whole-lattice-rescoring". In general, the G in HLG + is a 3-gram LM, while this G is a 4-gram LM. 
+ Returns: + Return a dict, whose key may be "no-rescore" if no LM rescoring + is used, or it may be "lm_scale_0.7" if LM rescoring is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + # pdb.set_trace() + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + batch=batch, + word_table=word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + if hyps_dict is not None: + for lm_scale, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[lm_scale].extend(this_batch) + else: + assert len(results) > 0, "It should not decode to empty in the first batch!" + this_batch = [] + hyp_words = [] + for ref_text in texts: + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + for lm_scale in results.keys(): + results[lm_scale].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % 100 == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + if params.method == "attention-decoder": + # Set it to False since there are too many logs. + enable_log = False + else: + enable_log = True + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" + store_transcripts(filename=recog_path, texts=results) + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) + test_set_wers[key] = wer + + if enable_log: + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + MGB2AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + args.lm_dir = Path(args.lm_dir) + + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode") + logging.info("Decoding started") + logging.info(params) + + lexicon = Lexicon(params.lang_dir) + max_token_id = max(lexicon.tokens) + num_classes = max_token_id + 1 # +1 for the blank + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + if params.method == "ctc-decoding": + HLG = None + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + "attention-decoder", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. 
+ G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + else: + G = None + + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.to(device) + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + MGB2 = MGB2AsrDataModule(args) + + test_cuts = MGB2.test_cuts() + dev_cuts = MGB2.dev_cuts() + + test_dl = MGB2.test_dataloaders(test_cuts) + dev_dl = MGB2.test_dataloaders(dev_cuts) + + test_sets = ["test", "dev"] + test_all_dl = [test_dl, dev_dl] + + for test_set, test_dl in zip(test_sets, test_all_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + sos_id=sos_id, + eos_id=eos_id, + ) + + save_results(params=params, test_set_name=test_set, results_dict=results_dict) + + logging.info("Done!") + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/mgb2/ASR/conformer_ctc/download_lm.py b/egs/mgb2/ASR/conformer_ctc/download_lm.py new file mode 120000 index 000000000..c9668bd2d --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/download_lm.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/download_lm.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/export.py b/egs/mgb2/ASR/conformer_ctc/export.py new file mode 120000 index 000000000..60e314d9d --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/export.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py b/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py new file mode 120000 index 000000000..c0aea1403 --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/generate_unique_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/label_smoothing.py b/egs/mgb2/ASR/conformer_ctc/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/pretrained.py b/egs/mgb2/ASR/conformer_ctc/pretrained.py new file mode 
100755 index 000000000..d30ca98d8 --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/pretrained.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Mingshuang Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from conformer import Conformer +from torch.nn.utils.rnn import pad_sequence + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_attention_decoder, + rescore_with_whole_lattice, +) +from icefall.utils import AttributeDict, get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. + (3) attention-decoder - Extract n paths from the rescored + lattice and use the transformer attention decoder for + rescoring. + We call it HLG decoding + n-gram LM rescoring + attention + decoder rescoring. + """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or attention-decoder. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and attention-decoder. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) 
+ """, + ) + + parser.add_argument( + "--attention-decoder-scale", + type=float, + default=1.2, + help=""" + Used only when method is attention-decoder. + It specifies the scale for attention decoder scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is attention-decoder. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--sos-id", + type=int, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the SOS token. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--eos-id", + type=int, + default=1, + help=""" + Used only when method is attention-decoder. + It specifies ID for the EOS token. + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def get_params() -> AttributeDict: + params = AttributeDict( + { + "sample_rate": 16000, + # parameters for conformer + "subsampling_factor": 4, + "vgg_frontend": False, + "use_feat_batchnorm": True, + "feature_dim": 80, + "nhead": 8, + "attention_dim": 512, + "num_decoder_layers": 6, + # parameters for decoding + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + if args.method != "attention-decoder": + # to save memory as the attention decoder + # will not be used + params.num_decoder_layers = 0 + + params.update(vars(args)) + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = Conformer( + num_features=params.feature_dim, + nhead=params.nhead, + d_model=params.attention_dim, + num_classes=params.num_classes, + subsampling_factor=params.subsampling_factor, + num_decoder_layers=params.num_decoder_layers, + vgg_frontend=params.vgg_frontend, + use_feat_batchnorm=params.use_feat_batchnorm, + ) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + # Note: We don't use key padding mask for attention during decoding + with torch.no_grad(): + nnet_output, memory, memory_key_padding_mask = model(features) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = G.to(device) + G = k2.add_epsilon_self_loops(G) + G = 
k2.arc_sort(G) + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "attention-decoder": + logging.info("Use HLG + LM rescoring + attention decoder rescoring") + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=None + ) + best_path_dict = rescore_with_attention_decoder( + lattice=rescored_lattice, + num_paths=params.num_paths, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=params.sos_id, + eos_id=params.eos_id, + nbest_scale=params.nbest_scale, + ngram_lm_scale=params.ngram_lm_scale, + attention_scale=params.attention_decoder_scale, + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/mgb2/ASR/conformer_ctc/subsampling.py b/egs/mgb2/ASR/conformer_ctc/subsampling.py new file mode 120000 index 000000000..16354dc73 --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/subsampling.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py b/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py new file mode 120000 index 000000000..04b959ecf --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/test_label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/test_label_smoothing.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/test_subsampling.py b/egs/mgb2/ASR/conformer_ctc/test_subsampling.py new file mode 120000 index 000000000..98c3be3e6 --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/test_subsampling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/test_subsampling.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/test_transformer.py b/egs/mgb2/ASR/conformer_ctc/test_transformer.py new file mode 120000 index 000000000..8b0990ec6 --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/test_transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/test_transformer.py \ No newline at end of file diff --git a/egs/mgb2/ASR/conformer_ctc/train.py b/egs/mgb2/ASR/conformer_ctc/train.py new file mode 100755 index 000000000..08ffee210 --- /dev/null +++ 
b/egs/mgb2/ASR/conformer_ctc/train.py @@ -0,0 +1,766 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (Amir Hussein) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + + +import argparse +import logging +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import k2 +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import MGB2AsrDataModule +from conformer import Conformer +from lhotse.cut import Cut +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter +from transformer import Noam + +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=50, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + conformer_ctc/exp/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="conformer_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_bpe_500", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--att-rate", + type=float, + default=0.8, + help="""The attention rate. + The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss + """, + ) + + parser.add_argument( + "--num-decoder-layers", + type=int, + default=6, + help="""Number of decoder layer of transformer decoder. + Setting this to 0 will not create the decoder at all (pure CTC model) + """, + ) + + parser.add_argument( + "--lr-factor", + type=float, + default=5.0, + help="The lr_factor for Noam optimizer", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. 
+
+    - best_valid_loss: Best validation loss so far. It is used to select
+                       the model that has the lowest validation loss. It is
+                       updated during the training.
+
+    - best_train_epoch: It is the epoch that has the best training loss.
+
+    - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+                       epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+    - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+    - feature_dim: The model input dim. It has to match the one used
+                   in computing features.
+
+    - subsampling_factor: The subsampling factor for the model.
+
+    - use_feat_batchnorm: Normalization for the input features. It can be a
+                          boolean indicating whether to do batch
+                          normalization, or a float which means just scaling
+                          the input features with this float value.
+                          If given a float value, we will remove the batchnorm
+                          layer in `ConvolutionModule` as well.
+
+    - attention_dim: Hidden dim for the multi-head attention model.
+
+    - nhead: Number of heads of the multi-head attention model.
+
+    - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+    - beam_size: It is used in k2.ctc_loss
+
+    - reduction: It is used in k2.ctc_loss
+
+    - use_double_scores: It is used in k2.ctc_loss
+
+    - weight_decay: The weight_decay for the optimizer.
+
+    - warm_step: The warm_step for the Noam optimizer.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,
+            # parameters for conformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,
+            "use_feat_batchnorm": True,
+            "attention_dim": 512,
+            "nhead": 8,
+            "num_decoder_layers": 6,
+            # parameters for loss
+            "beam_size": 10,
+            "reduction": "sum",
+            "use_double_scores": True,
+            # parameters for Noam
+            "weight_decay": 1e-6,
+            "warm_step": 80000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def load_checkpoint_if_available(
+    params: AttributeDict,
+    model: nn.Module,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
+) -> Optional[dict]:
+    """Load checkpoint from file.
+
+    If params.start_epoch is positive, it will load the checkpoint from
+    `params.start_epoch - 1`. Otherwise, this function does nothing.
+
+    Apart from loading state dict for `model`, `optimizer` and `scheduler`,
+    it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+    and `best_valid_loss` in `params`.
+
+    Args:
+      params:
+        The return value of :func:`get_params`.
+      model:
+        The training model.
+      optimizer:
+        The optimizer that we are using.
+      scheduler:
+        The learning rate scheduler we are using.
+    Returns:
+      Return the saved checkpoint (a dict) if one is loaded; otherwise
+      return None.
+ """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: nn.Module, + batch: dict, + graph_compiler: BpeCtcTrainingGraphCompiler, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute CTC loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + graph_compiler: + It is used to build a decoding graph from a ctc topo and training + transcript. The training transcript is contained in the given `batch`, + while the ctc topo is built when this compiler is instantiated. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. 
+ """ + device = graph_compiler.device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + with torch.set_grad_enabled(is_training): + nnet_output, encoder_memory, memory_mask = model(feature, supervisions) + # nnet_output is (N, T, C) + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + supervision_segments, texts = encode_supervisions( + supervisions, subsampling_factor=params.subsampling_factor + ) + + token_ids = graph_compiler.texts_to_ids(texts) + + decoding_graph = graph_compiler.compile(token_ids) + + dense_fsa_vec = k2.DenseFsaVec( + nnet_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction="none", + use_double_scores=params.use_double_scores, + ) + # filter inf from ctc_loss + ctc_loss = torch.sum( + torch.where( + ctc_loss != float("inf"), + ctc_loss, + torch.tensor(0, dtype=torch.float32).to(device), + ) + ) + + if params.att_rate != 0.0: + with torch.set_grad_enabled(is_training): + mmodel = model.module if hasattr(model, "module") else model + # Note: We need to generate an unsorted version of token_ids + # `encode_supervisions()` called above sorts text, but + # encoder_memory and memory_mask are not sorted, so we + # use an unsorted version `supervisions["text"]` to regenerate + # the token_ids + # + # See https://github.com/k2-fsa/icefall/issues/97 + # for more details + unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + + att_loss = mmodel.decoder_forward( + encoder_memory, + memory_mask, + token_ids=unsorted_token_ids, + sos_id=graph_compiler.sos_id, + eos_id=graph_compiler.eos_id, + ) + loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss + else: + loss = ctc_loss + att_loss = torch.tensor([0]) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + info["frames"] = supervision_segments[:, 2].sum().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + if params.att_rate != 0.0: + info["att_loss"] = att_loss.detach().cpu().item() + + info["loss"] = loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + graph_compiler: BpeCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the 
model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + graph_compiler: + It is used to convert transcripts to FSAs. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + loss, loss_info = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + # if tot_loss is None: + # logging.warning("Batch mismatch. Skipping ...") + # del batch + # del tot_loss + # continue; + # elif tot_loss.isinf() or tot_loss.isnan(): + # logging.warning("NaN or Inf loss. Skipping ...") + # del batch + # del tot_loss + # continue; + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}" + ) + + if batch_idx % params.log_interval == 0: + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + else: + logging.warning( + f"Batch {batch_idx} mismatch in dimentions between the input and the output. Skipping ..." + ) + continue + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+      args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(42)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+    logging.info(params)
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+
+    graph_compiler = BpeCtcTrainingGraphCompiler(
+        params.lang_dir,
+        device=device,
+        sos_token="<sos/eos>",
+        eos_token="<sos/eos>",
+    )
+
+    logging.info("About to create model")
+    model = Conformer(
+        num_features=params.feature_dim,
+        nhead=params.nhead,
+        d_model=params.attention_dim,
+        num_classes=num_classes,
+        subsampling_factor=params.subsampling_factor,
+        num_decoder_layers=params.num_decoder_layers,
+        vgg_frontend=False,
+        use_feat_batchnorm=params.use_feat_batchnorm,
+    )
+
+    checkpoints = load_checkpoint_if_available(params=params, model=model)
+
+    model.to(device)
+    if world_size > 1:
+        model = DDP(model, device_ids=[rank])
+
+    optimizer = Noam(
+        model.parameters(),
+        model_size=params.attention_dim,
+        factor=params.lr_factor,
+        warm_step=params.warm_step,
+        weight_decay=params.weight_decay,
+    )
+
+    if checkpoints:
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    MGB2 = MGB2AsrDataModule(args)
+
+    train_cuts = MGB2.train_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 0.5 seconds and 30 seconds
+        #
+        # Caution: There is a reason to select 30.0 here.
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 0.5 <= c.duration <= 30.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_dl = MGB2.train_dataloaders(train_cuts) + + valid_cuts = MGB2.dev_cuts() + valid_dl = MGB2.test_dataloaders(valid_cuts) + + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + for epoch in range(params.start_epoch, params.num_epochs): + train_dl.sampler.set_epoch(epoch) + + cur_lr = optimizer._rate + if tb_writer is not None: + tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + if rank == 0: + logging.info("epoch {}, learning rate {}".format(epoch, cur_lr)) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + graph_compiler=graph_compiler, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: nn.Module, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: BpeCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 0 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + optimizer.zero_grad() + loss, _ = compute_loss( + params=params, + model=model, + batch=batch, + graph_compiler=graph_compiler, + is_training=True, + ) + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + MGB2AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/mgb2/ASR/conformer_ctc/transformer.py b/egs/mgb2/ASR/conformer_ctc/transformer.py new file mode 120000 index 000000000..1c3f43fcf --- /dev/null +++ b/egs/mgb2/ASR/conformer_ctc/transformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/transformer.py \ No newline at end of file diff --git a/egs/mgb2/ASR/local/__init__.py b/egs/mgb2/ASR/local/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/mgb2/ASR/local/compile_hlg.py b/egs/mgb2/ASR/local/compile_hlg.py new file mode 120000 index 000000000..471aa7fb4 --- /dev/null +++ b/egs/mgb2/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/mgb2/ASR/local/compute_fbank_mgb2.py b/egs/mgb2/ASR/local/compute_fbank_mgb2.py new file mode 100755 index 000000000..6cae69e41 --- /dev/null +++ b/egs/mgb2/ASR/local/compute_fbank_mgb2.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (Amir Hussein) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the MGB2 dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_mgb2(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + + dataset_parts = ( + "train", + "test", + "dev", + ) + manifests = read_manifests_if_cached( + prefix="mgb2", dataset_parts=dataset_parts, output_dir=src_dir + ) + assert manifests is not None + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. 
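+        # Partitions whose cuts file is already present are skipped.  For the
+        # "train" partition the cut set below is augmented with 0.9x and 1.1x
+        # speed-perturbed copies, so the amount of training data roughly
+        # triples before features are computed and stored with
+        # LilcomChunkyWriter.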
+ for partition, m in manifests.items(): + if (output_dir / f"cuts_{partition}.json.gz").is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if "train" in partition: + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + logging.info("About to split cuts into smaller chunks.") + cut_set = cut_set.trim_to_supervisions( + keep_overlapping=False, min_duration=None + ) + cut_set.to_file(output_dir / f"cuts_{partition}.jsonl.gz") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + compute_fbank_mgb2() diff --git a/egs/mgb2/ASR/local/compute_fbank_musan.py b/egs/mgb2/ASR/local/compute_fbank_musan.py new file mode 100755 index 000000000..5d0d69a13 --- /dev/null +++ b/egs/mgb2/ASR/local/compute_fbank_musan.py @@ -0,0 +1,108 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the musan dataset. +It looks for manifests in the directory data/manifests. +The generated fbank features are saved in data/fbank. +""" + +import logging +import os +from pathlib import Path + +import torch +from lhotse import ( + ChunkedLilcomHdf5Writer, + CutSet, + Fbank, + FbankConfig, + LilcomChunkyWriter, + combine, +) +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def compute_fbank_musan(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + + dataset_parts = ( + "music", + "speech", + "noise", + ) + prefix = "musan" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + prefix=prefix, + dataset_parts=dataset_parts, + output_dir=src_dir, + suffix=suffix, + ) + assert manifests is not None + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + ) + + musan_cuts_path = output_dir / "cuts_musan.jsonl.gz" + + if musan_cuts_path.is_file(): + logging.info(f"{musan_cuts_path} already exists - skipping") + return + + logging.info("Extracting features for Musan") + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + # create chunks of Musan with duration 5 - 10 seconds + musan_cuts = ( + CutSet.from_manifests( + recordings=combine(part["recordings"] for part in manifests.values()) + ) + .cut_into_windows(10.0) + .filter(lambda c: c.duration > 5) + .compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/feats_musan", + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + ) + musan_cuts.to_file(musan_cuts_path) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_fbank_musan() diff --git a/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py b/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py new file mode 100755 index 000000000..a8d5117c9 --- /dev/null +++ b/egs/mgb2/ASR/local/convert_transcript_words_to_tokens.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 + +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +""" +Convert a transcript file containing words to a corpus file containing tokens +for LM training with the help of a lexicon. + +If the lexicon contains phones, the resulting LM will be a phone LM; If the +lexicon contains word pieces, the resulting LM will be a word piece LM. + +If a word has multiple pronunciations, the one that appears first in the lexicon +is kept; others are removed. + +If the input transcript is: + + hello zoo world hello + world zoo + foo zoo world hellO + +and if the lexicon is + + SPN + hello h e l l o 2 + hello h e l l o + world w o r l d + zoo z o o + +Then the output is + + h e l l o 2 z o o w o r l d h e l l o 2 + w o r l d z o o + SPN z o o w o r l d SPN +""" + +import argparse +from pathlib import Path +from typing import Dict, List + +from generate_unique_lexicon import filter_multiple_pronunications + +from icefall.lexicon import read_lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--transcript", + type=str, + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", + ) + parser.add_argument("--lexicon", type=str, help="The input lexicon file.") + parser.add_argument("--oov", type=str, default="", help="The OOV word.") + + return parser.parse_args() + + +def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: + """ + Args: + lexicon: + A dict containing pronunciations. Its keys are words and values + are pronunciations (i.e., tokens). 
+ line: + A line of transcript consisting of space(s) separated words. + oov_token: + The pronunciation of the oov word if a word in `line` is not present + in the lexicon. + Returns: + Return None. + """ + s = "" + words = line.strip().split() + for i, w in enumerate(words): + tokens = lexicon.get(w, oov_token) + s += " ".join(tokens) + s += " " + print(s.strip()) + + +def main(): + args = get_args() + assert Path(args.lexicon).is_file() + assert Path(args.transcript).is_file() + assert len(args.oov) > 0 + + # Only the first pronunciation of a word is kept + lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon)) + + lexicon = dict(lexicon) + + assert args.oov in lexicon + + oov_token = lexicon[args.oov] + + with open(args.transcript) as f: + for line in f: + process_line(lexicon=lexicon, line=line, oov_token=oov_token) + + +if __name__ == "__main__": + main() diff --git a/egs/mgb2/ASR/local/display_manifest_statistics.py b/egs/mgb2/ASR/local/display_manifest_statistics.py new file mode 100755 index 000000000..d3e224905 --- /dev/null +++ b/egs/mgb2/ASR/local/display_manifest_statistics.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in a manifest. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. + +See the function `remove_short_and_long_utt()` in transducer/train.py +for usage. 
+""" + + +from lhotse import load_manifest + + +def main(): + # path = "./data/fbank/cuts_train.jsonl.gz" + path = "./data/fbank/cuts_dev.jsonl.gz" + # path = "./data/fbank/cuts_test.jsonl.gz" + + cuts = load_manifest(path) + cuts.describe() + + +if __name__ == "__main__": + main() + +""" +# train + +Cuts count: 1125309 +Total duration (hours): 3403.9 +Speech duration (hours): 3403.9 (100.0%) +*** +Duration statistics (seconds): +mean 10.9 +std 10.1 +min 0.2 +25% 5.2 +50% 7.8 +75% 12.7 +99% 52.0 +99.5% 65.1 +99.9% 99.5 +max 228.9 + + +# test +Cuts count: 5365 +Total duration (hours): 9.6 +Speech duration (hours): 9.6 (100.0%) +*** +Duration statistics (seconds): +mean 6.4 +std 1.5 +min 1.6 +25% 5.3 +50% 6.5 +75% 7.6 +99% 9.5 +99.5% 9.7 +99.9% 10.3 +max 12.4 + +# dev +Cuts count: 5002 +Total duration (hours): 8.5 +Speech duration (hours): 8.5 (100.0%) +*** +Duration statistics (seconds): +mean 6.1 +std 1.7 +min 1.5 +25% 4.8 +50% 6.2 +75% 7.4 +99% 9.5 +99.5% 9.7 +99.9% 10.1 +max 20.3 + +""" diff --git a/egs/mgb2/ASR/local/generate_unique_lexicon.py b/egs/mgb2/ASR/local/generate_unique_lexicon.py new file mode 120000 index 000000000..c0aea1403 --- /dev/null +++ b/egs/mgb2/ASR/local/generate_unique_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file diff --git a/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh b/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh new file mode 100755 index 000000000..3b673db6f --- /dev/null +++ b/egs/mgb2/ASR/local/prep_mgb2_lexicon.sh @@ -0,0 +1,30 @@ +#!/usr/bin/env bash + +# Copyright 2022 QCRI (author: Amir Hussein) +# Apache 2.0 +# This script prepares the graphemic lexicon. + +dir=data/local/dict +lexicon_url1="https://arabicspeech.org/arabicspeech-portal-resources/lexicon/ar-ar_grapheme_lexicon_20160209.bz2"; +lexicon_url2="https://arabicspeech.org/arabicspeech-portal-resources/lexicon/ar-ar_phoneme_lexicon_20140317.bz2"; +stage=0 +lang_dir=download/lm +mkdir -p $lang_dir + +if [ $stage -le 0 ]; then + echo "$0: Downloading text for lexicon... $(date)." + wget --no-check-certificate -P $lang_dir $lexicon_url1 + wget --no-check-certificate -P $lang_dir $lexicon_url2 + bzcat $lang_dir/ar-ar_grapheme_lexicon_20160209.bz2 | sed '1,3d' | awk '{print $1}' > $lang_dir/grapheme_lexicon + bzcat $lang_dir/ar-ar_phoneme_lexicon_20140317.bz2 | sed '1,3d' | awk '{print $1}' >> $lang_dir/phoneme_lexicon + cat download/lm/train/text | cut -d ' ' -f 2- | tr -s " " "\n" | sort -u >> $lang_dir/uniq_words +fi + + +if [ $stage -le 0 ]; then + echo "$0: processing lexicon text and creating lexicon... $(date)." 
+ # remove vowels and rare alef wasla + cat $lang_dir/uniq_words | sed -e 's:[FNKaui\~o\`]::g' -e 's:{:}:g' | sed -r '/^\s*$/d' | sort -u > $lang_dir/grapheme_lexicon.txt +fi + +echo "$0: Lexicon preparation succeeded" diff --git a/egs/mgb2/ASR/local/prepare_lang.py b/egs/mgb2/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/mgb2/ASR/local/prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/mgb2/ASR/local/prepare_lang_bpe.py b/egs/mgb2/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/mgb2/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py b/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py new file mode 100755 index 000000000..99e1fa34d --- /dev/null +++ b/egs/mgb2/ASR/local/prepare_mgb2_lexicon.py @@ -0,0 +1,37 @@ +#!/usr/bin/env python3 + +# Copyright 2022 Amir Hussein +# Apache 2.0 + +# This script prepares givel a column of words lexicon. + +import argparse + + +def get_args(): + parser = argparse.ArgumentParser( + description="""Creates the list of characters and words in lexicon""" + ) + parser.add_argument("input", type=str, help="""Input list of words file""") + parser.add_argument("output", type=str, help="""output graphemic lexicon""") + args = parser.parse_args() + return args + + +def main(): + lex = {} + args = get_args() + with open(args.input, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + characters = list(line) + characters = " ".join(["V" if char == "*" else char for char in characters]) + lex[line] = characters + + with open(args.output, "w", encoding="utf-8") as fp: + for key in sorted(lex): + fp.write(key + " " + lex[key] + "\n") + + +if __name__ == "__main__": + main() diff --git a/egs/mgb2/ASR/local/test_prepare_lang.py b/egs/mgb2/ASR/local/test_prepare_lang.py new file mode 120000 index 000000000..f0f864998 --- /dev/null +++ b/egs/mgb2/ASR/local/test_prepare_lang.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/test_prepare_lang.py \ No newline at end of file diff --git a/egs/mgb2/ASR/prepare.sh b/egs/mgb2/ASR/prepare.sh new file mode 100755 index 000000000..899d15d97 --- /dev/null +++ b/egs/mgb2/ASR/prepare.sh @@ -0,0 +1,234 @@ +#!/usr/bin/env bash +# Copyright 2022 Johns Hopkins University (Amir Hussein) +# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) + +set -eou pipefail +nj=30 +stage=7 +stop_stage=1000 + +# We assume dl_dir (download dir) contains the following +# directories and files. +# +# - $dl_dir/mgb2 +# +# You can download the data from +# +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech +# +# Note: MGB2 is not available for direct +# download, however you can fill out the form and +# download it from https://arabicspeech.org/mgb2 + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + 5000 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. 
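+#
+# Stages are selected with the options parsed by shared/parse_options.sh.
+# Illustrative invocations:
+#
+#   ./prepare.sh --stage 0 --stop-stage 4   # download, manifests, fbank
+#   ./prepare.sh --stage 5 --stop-stage 9   # lexicon, BPE, LMs and HLG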
+mkdir -p data + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Download data" + + # If you have pre-downloaded it to /path/to/MGB2, + # you can create a symlink + # + # ln -sfv /path/to/mgb2 $dl_dir/MGB2 + + # If you have pre-downloaded it to /path/to/musan, + # you can create a symlink + # + # ln -sfv /path/to/musan $dl_dir/ + # + if [ ! -d $dl_dir/musan ]; then + lhotse download musan $dl_dir + fi +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare mgb2 manifest" + # We assume that you have downloaded the mgb2 corpus + # to $dl_dir/mgb2 + mkdir -p data/manifests + + lhotse prepare mgb2 $dl_dir/mgb2 data/manifests + +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Prepare musan manifest" + # We assume that you have downloaded the musan corpus + # to data/musan + mkdir -p data/manifests + lhotse prepare musan $dl_dir/musan data/manifests +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Compute fbank for mgb2" + mkdir -p data/fbank + ./local/compute_fbank_mgb2.py + # shufling the data + gunzip -c data/fbank/cuts_train.jsonl.gz | shuf | gzip -c > data/fbank/cuts_train_shuf.jsonl.gz +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Compute fbank for musan" + mkdir -p data/fbank + ./local/compute_fbank_musan.py +fi + +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare phone based lang" + if [[ ! -e download/lm/train/text ]]; then + # export train text file to build grapheme lexicon + lhotse kaldi export \ + data/manifests/mgb2_recordings_train.jsonl.gz \ + data/manifests/mgb2_supervisions_train.jsonl.gz \ + download/lm/train + fi + + lang_dir=data/lang_phone + mkdir -p $lang_dir + ./local/prep_mgb2_lexicon.sh + python local/prepare_mgb2_lexicon.py $dl_dir/lm/grapheme_lexicon.txt $dl_dir/lm/lexicon.txt + (echo '!SIL SIL'; echo ' SPN'; echo ' SPN'; ) | + cat - $dl_dir/lm/lexicon.txt | + sort | uniq > $lang_dir/lexicon.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang.py --lang-dir $lang_dir + fi +fi + + +if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then + log "Stage 6: Prepare BPE based lang" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p $lang_dir + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt $lang_dir + + if [ ! -f $lang_dir/transcript_words.txt ]; then + log "Generate data for BPE training" + files=$( + find "$dl_dir/lm/train" -name "text" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- | sed -r '/^\s*$/d' + done > $lang_dir/transcript_words.txt + fi + + ./local/train_bpe_model.py \ + --lang-dir $lang_dir \ + --vocab-size $vocab_size \ + --transcript $lang_dir/transcript_words.txt + + if [ ! -f $lang_dir/L_disambig.pt ]; then + ./local/prepare_lang_bpe.py --lang-dir $lang_dir + fi + done +fi + +if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then + log "Stage 7: Prepare bigram P" + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + + if [ ! 
-f $lang_dir/transcript_tokens.txt ]; then + ./local/convert_transcript_words_to_tokens.py \ + --lexicon $lang_dir/lexicon.txt \ + --transcript $lang_dir/transcript_words.txt \ + --oov "" \ + > $lang_dir/transcript_tokens.txt + fi + + if [ ! -f $lang_dir/P.arpa ]; then + ./shared/make_kn_lm.py \ + -ngram-order 2 \ + -text $lang_dir/transcript_tokens.txt \ + -lm $lang_dir/P.arpa + fi + + if [ ! -f $lang_dir/P.fst.txt ]; then + python3 -m kaldilm \ + --read-symbol-table="$lang_dir/tokens.txt" \ + --disambig-symbol='#0' \ + --max-order=2 \ + $lang_dir/P.arpa > $lang_dir/P.fst.txt + fi + done +fi + +if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then + log "Stage 8: Prepare G" + # We assume you have install kaldilm, if not, please install + # it using: pip install kaldilm + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + mkdir -p data/lm + if [ ! -f data/lm/G_3_gram.fst.txt ]; then + # It is used in building HLG + ./shared/make_kn_lm.py \ + -ngram-order 3 \ + -text $lang_dir/transcript_words.txt \ + -lm $lang_dir/G.arpa + + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=3 \ + $lang_dir/G.arpa > data/lm/G_3_gram.fst.txt + fi + + if [ ! -f data/lm/G_4_gram.fst.txt ]; then + # It is used for LM rescoring + ./shared/make_kn_lm.py \ + -ngram-order 4 \ + -text $lang_dir/transcript_words.txt \ + -lm $lang_dir/4-gram.arpa + + python3 -m kaldilm \ + --read-symbol-table="data/lang_phone/words.txt" \ + --disambig-symbol='#0' \ + --max-order=4 \ + $lang_dir/4-gram.arpa > data/lm/G_4_gram.fst.txt + fi + done +fi + +if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then + log "Stage 9: Compile HLG" + ./local/compile_hlg.py --lang-dir data/lang_phone + + for vocab_size in ${vocab_sizes[@]}; do + lang_dir=data/lang_bpe_${vocab_size} + ./local/compile_hlg.py --lang-dir $lang_dir + done +fi diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py b/egs/mgb2/ASR/pruned_transducer_stateless5/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py new file mode 120000 index 000000000..a73848de9 --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -0,0 +1 @@ +../conformer_ctc/asr_datamodule.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py b/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py new file mode 120000 index 000000000..02d01b343 --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/beam_search.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py b/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py new file mode 120000 index 000000000..c7c1a4b6e --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py new file mode 100755 index 000000000..1463f8f67 --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decode.py @@ -0,0 +1,625 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins (authors: Amir Hussein) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# 
Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless5/decode.py \ + --epoch 18 \ + --avg 5 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 200 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless5/decode.py \ + --epoch 18 \ + --avg 5 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 200 \ + --decoding-method beam_search \ + --beam-size 10 + +(3) modified beam search +./pruned_transducer_stateless5/decode.py \ + --epoch 18 \ + --avg 5 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 10 + +(4) fast beam search +./pruned_transducer_stateless5/decode.py \ + --epoch 18 \ + --avg 5 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 200 \ + --decoding-method fast_beam_search \ + --beam-size 10 \ + --max-contexts 4 \ + --max-states 8 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import MGB2AsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_2000/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + MGB2AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) 
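+    # By icefall convention, the two special symbols whose IDs are looked up
+    # below are <blk> and <unk>, as produced by local/train_bpe_model.py.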
+ + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + MGB2 = MGB2AsrDataModule(args) + + test_cuts = MGB2.test_cuts() + dev_cuts = MGB2.dev_cuts() + + test_dl = MGB2.test_dataloaders(test_cuts) + dev_dl = MGB2.test_dataloaders(dev_cuts) + + test_sets = ["test", "dev"] + test_all_dl = [test_dl, dev_dl] + + for test_set, test_dl in zip(test_sets, test_all_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py 
b/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py new file mode 120000 index 000000000..6775ee67e --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/decoder.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py new file mode 120000 index 000000000..972e44ca4 --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/encoder_interface.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/export.py b/egs/mgb2/ASR/pruned_transducer_stateless5/export.py new file mode 100755 index 000000000..7a5d7f680 --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/export.py @@ -0,0 +1,272 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./pruned_transducer_stateless5/export.py \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `pruned_transducer_stateless5/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless5/decode.py \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + assert args.jit is False, "Support torchscript will be added later" + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + 
filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py b/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py new file mode 120000 index 000000000..f5279e151 --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/joiner.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/model.py b/egs/mgb2/ASR/pruned_transducer_stateless5/model.py new file mode 120000 index 000000000..7b417fd89 --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/model.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py b/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py new file mode 120000 index 000000000..210374f22 --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/optim.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py new file mode 100755 index 000000000..77ba0873b --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/pretrained.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +(1) greedy search +./pruned_transducer_stateless5/pretrained.py \ + --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless5/pretrained.py \ + --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless5/pretrained.py \ + --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless5/pretrained.py \ + --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by +./pruned_transducer_stateless5/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + 
encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py b/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py new file mode 120000 index 000000000..ff7bfeda9 --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/scaling.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py b/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py new file mode 120000 index 000000000..b71d7bb81 --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/test_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/test_model.py \ No newline at end of file diff --git a/egs/mgb2/ASR/pruned_transducer_stateless5/train.py b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py new file mode 100755 index 000000000..e1b623353 --- /dev/null +++ b/egs/mgb2/ASR/pruned_transducer_stateless5/train.py @@ -0,0 +1,1176 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins (authors: Amir Hussein) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless5/train.py \ + --world-size 2 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless5/exp \ + --max-duration 200 \ + --num-buckets 50 + +# For mix precision training: + +./pruned_transducer_stateless5/train.py \ + --world-size 2 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless5/exp \ + --max-duration 200 \ + --num-buckets 50 + +""" + +# xxx +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import nvidia_smi +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import MGB2AsrDataModule +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=int, + default=12, + help="Number of conformer encoder layers..", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=2048, + help="Feedforward dimension of the conformer encoder layer.", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads in the conformer encoder layer.", + ) + + parser.add_argument( + "--encoder-dim", + type=int, + default=512, + help="Attention dimension in the conformer encoder layer.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. 
+ """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_2000/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="The initial learning rate. This value should not need " "to be changed.", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. 
The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=10, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=True, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. 
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + # parameters for Noam + "model_warm_step": 80000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+        The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+    warmup: float = 1.0,
+    reduction="none",
+) -> Tuple[Tensor, MetricsTracker, bool]:
+    """
+    Compute the transducer (RNN-T) loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Conformer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
+      warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
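+    Returns:
+      Return a tuple of three elements: the loss, a `MetricsTracker` object
+      holding statistics for logging, and a bool flag that is True if any
+      per-utterance loss in the batch was not finite.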
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + reduction="none", + ) + simple_loss_is_finite = torch.isfinite(simple_loss) + pruned_loss_is_finite = torch.isfinite(pruned_loss) + is_finite = simple_loss_is_finite & pruned_loss_is_finite + inf_flag = False + if not torch.all(is_finite): + inf_flag = True + logging.info( + "Not all losses are finite!\n" + f"simple_loss: {simple_loss}\n" + f"pruned_loss: {pruned_loss}" + ) + display_and_save_batch(batch, params=params, sp=sp) + simple_loss = simple_loss[simple_loss_is_finite] + pruned_loss = pruned_loss[pruned_loss_is_finite] + + simple_loss = simple_loss.sum() + pruned_loss = pruned_loss.sum() + + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + # info["utterances"] = feature.size(0) + # # averaged input duration in frames over utterances + # info["utt_duration"] = feature_lens.sum().item() + # # averaged padding proportion over utterances + # info["utt_pad_proportion"] = ( + # ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + # ) + + # Note: We use reduction=sum while computing the loss. 
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info, inf_flag + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + with torch.no_grad(): + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info, inf_flag = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + + if batch["inputs"].shape[0] == len(batch["supervisions"]["text"]): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info, inf_flag = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
+ if not inf_flag: + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + else: + continue + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + # https://silpara.medium.com/check-gpu-memory-usage-from-python-ccca503322ea + memory_debugging() + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + else: + logging.warning( + f"Batch {batch_idx} mismatch in dimentions between the input and the output. Skipping ..." + ) + continue + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def memory_debugging(): + # memory nvidia debugging + nvidia_smi.nvmlInit() + + deviceCount = nvidia_smi.nvmlDeviceGetCount() + for i in range(deviceCount): + handle = nvidia_smi.nvmlDeviceGetHandleByIndex(i) + info = nvidia_smi.nvmlDeviceGetMemoryInfo(handle) + logging.info( + "Device {}: {}, Memory : ({:.2f}% free): {}(total), {} (free), {} (used)".format( + i, + nvidia_smi.nvmlDeviceGetName(handle), + 100 * info.free / info.total, + info.total, + info.free, + info.used, + ) + ) + + nvidia_smi.nvmlShutdown() + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + MGB2 = MGB2AsrDataModule(args) + train_cuts = MGB2.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 30 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 0.5 <= c.duration <= 30.0 + + def remove_short_and_long_text(c: Cut): + # Keep only text with charachters between 20 and 450 + + return 20 <= len(c.supervisions[0].text) <= 450 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + train_cuts = train_cuts.filter(remove_short_and_long_text) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = MGB2.train_dataloaders(train_cuts, sampler_state_dict=sampler_state_dict) + + valid_cuts = MGB2.dev_cuts() + valid_dl = MGB2.test_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. + with torch.cuda.amp.autocast(enabled=params.use_fp16): + + loss, _, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=0.0, + ) + loss.backward() + # clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + + +def main(): + parser = get_parser() + MGB2AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/mgb2/ASR/shared b/egs/mgb2/ASR/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/mgb2/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 207c12bf1..6589579d1 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -263,7 +263,7 @@ class TensorDiagnostic(object): ans += f", norm={norm:.2g}" mean = stats.mean().item() rms = (stats**2).mean().sqrt().item() - ans += f", mean={mean:.3g}, rms={rms:.3g}" + ans += f", mean={mean:.2g}, rms={rms:.2g}" # OK, "ans" contains the actual stats, e.g. # ans = "percentiles: [0.43 0.46 0.48 0.49 0.49 0.5 0.51 0.52 0.53 0.54 0.59], mean=0.5, rms=0.5" From 7700ddcb38b5ba0d91334947e3cac44825f1cf7c Mon Sep 17 00:00:00 2001 From: Weiji Zhuang Date: Fri, 2 Dec 2022 17:40:42 +0800 Subject: [PATCH 16/27] update multidataset zipformer results (#728) --- egs/librispeech/ASR/RESULTS.md | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index c2ea3d050..0885fb9b6 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -108,21 +108,25 @@ See for more details. [pruned_transducer_stateless8](./pruned_transducer_stateless8) The tensorboard log can be found at - + You can find a pretrained model, training logs, decoding logs, and decoding results at: - + You can use to deploy it. 
Number of model parameters: 70369391, i.e., 70.37 M -| | test-clean | test-other | comment | -|----------------------|------------|-------------|----------------------------------------| -| greedy search | 1.87 | 4.38 | --epoch 16 --avg 2 --max-duration 600 | -| modified beam search | 1.81 | 4.34 | --epoch 16 --avg 2 --max-duration 600 | -| fast beam search | 1.91 | 4.33 | --epoch 16 --avg 2 --max-duration 600 | +| decoding method | test-clean | test-other | comment | +|----------------------|------------|------------|--------------------| +| greedy_search | 1.81 | 4.18 | --epoch 20 --avg 4 | +| fast_beam_search | 1.82 | 4.15 | --epoch 20 --avg 4 | +| modified_beam_search | 1.78 | **4.08** | --epoch 20 --avg 4 | +| greedy_search | 1.84 | 4.3 | --epoch 19 --avg 8 | +| fast_beam_search |**1.77** | 4.25 | --epoch 19 --avg 8 | +| modified_beam_search | 1.81 | 4.16 | --epoch 19 --avg 8 | + The training commands are: ```bash @@ -142,15 +146,15 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" The decoding commands are: ```bash -for m in greedy_search fast_beam_search modified_beam_search ; do - for epoch in 16; do - for avg in 2; do +for m in greedy_search fast_beam_search modified_beam_search; do + for epoch in $(seq 20 -1 10); do + for avg in $(seq 9 -1 1); do ./pruned_transducer_stateless8/decode.py \ --epoch $epoch \ --avg $avg \ --use-averaged-model 1 \ --exp-dir ./pruned_transducer_stateless8/exp \ - --feedforward-dims "1024,1024,2048,2048,1024" \ + --feedforward-dims "1024,1024,2048,2048,1024" \ --max-duration 600 \ --decoding-method $m done From 8eb4b9d96da0432c1c27901f2964da954583d69a Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Sat, 3 Dec 2022 19:01:10 +0800 Subject: [PATCH 17/27] Combining rnnt loss and k2-ctc loss for Dan's Zipformer (#683) * init files * add ctc as auxiliary loss and ctc_decode.py * tuning the scalar of HLG score for 1best, nbest and nbest-oracle * rename to pruned_transducer_stateless7_ctc * fix doc * fix bug, recover the hlg scores * modify ctc_decode.py, move out the hlg scale * fix hlg_scale * add export.py and pretrained.py, and so on * upload files, update README.md and RESULTS.md * add CI test --- ...ed-transducer-stateless7-ctc-2022-12-01.sh | 147 ++ ...-librispeech-2022-12-01-stateless7-ctc.yml | 163 +++ egs/librispeech/ASR/README.md | 1 + egs/librispeech/ASR/RESULTS.md | 79 ++ .../ASR/conformer_ctc3/jit_pretrained.py | 20 +- .../__init__.py | 0 .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../ctc_decode.py | 818 +++++++++++ .../decode.py | 841 +++++++++++ .../decoder.py | 1 + .../encoder_interface.py | 1 + .../export.py | 320 +++++ .../jit_pretrained.py | 271 ++++ .../jit_pretrained_ctc.py | 423 ++++++ .../joiner.py | 1 + .../pruned_transducer_stateless7_ctc/model.py | 198 +++ .../pruned_transducer_stateless7_ctc/optim.py | 1 + .../pretrained.py | 353 +++++ .../pretrained_ctc.py | 441 ++++++ .../scaling.py | 1 + .../scaling_converter.py | 1 + .../test_model.py | 56 + .../pruned_transducer_stateless7_ctc/train.py | 1252 +++++++++++++++++ .../zipformer.py | 1 + icefall/utils.py | 18 +- 26 files changed, 5396 insertions(+), 14 deletions(-) create mode 100755 .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh create mode 100644 .github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py create mode 120000 
egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py create mode 100644 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py create mode 100755 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py create mode 120000 egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh new file mode 100755 index 000000000..6642d5f67 --- /dev/null +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh @@ -0,0 +1,147 @@ +#!/usr/bin/env bash + +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + +repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-ctc-2022-12-01 + +log "Downloading pre-trained model from $repo_url" +git lfs install +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +git lfs pull --include "data/*" +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "exp/pretrained.pt" +ln -s pretrained.pt epoch-99.pt +ls -lh *.pt +popd + +log "Export to torchscript model" +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir $repo/exp \ + --use-averaged-model false \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --jit 1 + +ls -lh $repo/exp/*.pt + +log "Decode with models exported by torch.jit.script()" + +./pruned_transducer_stateless7_ctc/jit_pretrained.py \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --nn-model-filename $repo/exp/cpu_jit.pt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --model-filename $repo/exp/cpu_jit.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + 
--sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless7_ctc/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless7_ctc/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc/pretrained_ctc.py \ + --checkpoint $repo/exp/pretrained.pt \ + --words-file $repo/data/lang_bpe_500/words.txt \ + --HLG $repo/data/lang_bpe_500/HLG.pt \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --G $repo/data/lm/G_4_gram.pt \ + --method $m \ + --sample-rate 16000 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless7_ctc/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless7_ctc/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_bpe_500 data/ + + ls -lh data + ls -lh pruned_transducer_stateless7_ctc/exp + + log "Decoding test-clean and test-other" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless7_ctc/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --use-averaged-model 0 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless7_ctc/exp + done + + for m in ctc-decoding 1best; do + ./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch 999 \ + --avg 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration $max_duration \ + --use-averaged-model 0 \ + --decoding-method $m \ + --hlg-scale 0.6 \ + --lm-dir data/lm + done + + rm pruned_transducer_stateless7_ctc/exp/*.pt +fi diff --git a/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml new file mode 100644 index 000000000..ccd8d50d0 --- /dev/null +++ b/.github/workflows/run-librispeech-2022-12-01-stateless7-ctc.yml @@ -0,0 +1,163 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+name: run-librispeech-2022-12-01-stateless7-ctc
+# zipformer
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    types: [labeled]
+
+  schedule:
+    # minute (0-59)
+    # hour (0-23)
+    # day of the month (1-31)
+    # month (1-12)
+    # day of the week (0-6)
+    # nightly build at 15:50 UTC time every day
+    - cron: "50 15 * * *"
+
+jobs:
+  run_librispeech_2022_12_01_zipformer_ctc:
+    if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule'
+    runs-on: ${{ matrix.os }}
+    strategy:
+      matrix:
+        os: [ubuntu-latest]
+        python-version: [3.8]
+
+      fail-fast: false
+
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: 'pip'
+          cache-dependency-path: '**/requirements-ci.txt'
+
+      - name: Install Python dependencies
+        run: |
+          grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install
+          pip uninstall -y protobuf
+          pip install --no-binary protobuf protobuf
+
+      - name: Cache kaldifeat
+        id: my-cache
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/kaldifeat
+          key: cache-tmp-${{ matrix.python-version }}-2022-09-25
+
+      - name: Install kaldifeat
+        if: steps.my-cache.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/install-kaldifeat.sh
+
+      - name: Cache LibriSpeech test-clean and test-other datasets
+        id: libri-test-clean-and-test-other-data
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/download
+          key: cache-libri-test-clean-and-test-other
+
+      - name: Download LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-data.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/download-librispeech-test-clean-and-test-other-dataset.sh
+
+      - name: Prepare manifests for LibriSpeech test-clean and test-other
+        shell: bash
+        run: |
+          .github/scripts/prepare-librispeech-test-clean-and-test-other-manifests.sh
+
+      - name: Cache LibriSpeech test-clean and test-other fbank features
+        id: libri-test-clean-and-test-other-fbank
+        uses: actions/cache@v2
+        with:
+          path: |
+            ~/tmp/fbank-libri
+          key: cache-libri-fbank-test-clean-and-test-other-v2
+
+      - name: Compute fbank for LibriSpeech test-clean and test-other
+        if: steps.libri-test-clean-and-test-other-fbank.outputs.cache-hit != 'true'
+        shell: bash
+        run: |
+          .github/scripts/compute-fbank-librispeech-test-clean-and-test-other.sh
+
+      - name: Inference with pre-trained model
+        shell: bash
+        env:
+          GITHUB_EVENT_NAME: ${{ github.event_name }}
+          GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }}
+        run: |
+          mkdir -p egs/librispeech/ASR/data
+          ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank
+          ls -lh egs/librispeech/ASR/data/*
+
+          sudo apt-get -qq install git-lfs tree sox
+          export PYTHONPATH=$PWD:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH
+          export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH
+
+          .github/scripts/run-librispeech-pruned-transducer-stateless7-ctc-2022-12-01.sh
+
+      - name: Display decoding results for librispeech pruned_transducer_stateless7_ctc
+        if: github.event_name == 'schedule' || github.event.label.name == 'run-decode'
+        shell: bash
+        run: |
+          cd egs/librispeech/ASR/
+          tree ./pruned_transducer_stateless7_ctc/exp
+
+          cd pruned_transducer_stateless7_ctc
+          echo "results for 
pruned_transducer_stateless7_ctc" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===ctc decoding===" + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/ctc-decoding -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + echo "===1best===" + find exp/1best -name "log-*" -exec grep -n --color "best for test-clean" {} + | sort -n -k2 + find exp/1best -name "log-*" -exec grep -n --color "best for test-other" {} + | sort -n -k2 + + - name: Upload decoding results for librispeech pruned_transducer_stateless7_ctc + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless7-ctc-2022-12-01 + path: egs/librispeech/ASR/pruned_transducer_stateless7_ctc/exp/ diff --git a/egs/librispeech/ASR/README.md b/egs/librispeech/ASR/README.md index e737d68bd..caa23a49f 100644 --- a/egs/librispeech/ASR/README.md +++ b/egs/librispeech/ASR/README.md @@ -23,6 +23,7 @@ The following table lists the differences among them. | `pruned_transducer_stateless5` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + more layers + random combiner| | `pruned_transducer_stateless6` | Conformer(modified) | Embedding + Conv1d | same as pruned_transducer_stateless4 + distillation with hubert| | `pruned_transducer_stateless7` | Zipformer | Embedding + Conv1d | First experiment with Zipformer from Dan| +| `pruned_transducer_stateless7_ctc` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but with extra CTC head| | `pruned_transducer_stateless8` | Zipformer | Embedding + Conv1d | Same as pruned_transducer_stateless7, but using extra data from GigaSpeech| | `pruned_stateless_emformer_rnnt2` | Emformer(from torchaudio) | Embedding + Conv1d | Using Emformer from torchaudio for streaming ASR| | `conv_emformer_transducer_stateless` | ConvEmformer | Embedding + Conv1d | Using ConvEmformer for streaming ASR + mechanisms in reworked model | diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 0885fb9b6..9e5669f6d 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1,5 +1,84 @@ ## Results +### pruned_transducer_stateless7_ctc (zipformer with transducer loss and ctc loss) + +See for more details. 
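+
+This recipe attaches an auxiliary CTC head to the zipformer encoder and
+trains it jointly with the pruned RNN-T loss. Conceptually (a minimal
+sketch, not the actual code in `model.py`; the names below are
+illustrative), the objective is a weighted sum of the two losses, with the
+CTC weight set by the `--ctc-loss-scale` option in the training command
+further down:
+
+```python
+# Illustrative sketch only -- names are hypothetical, not the model.py API.
+import torch
+
+
+def combine_losses(
+    simple_loss: torch.Tensor,
+    pruned_loss: torch.Tensor,
+    ctc_loss: torch.Tensor,
+    ctc_loss_scale: float = 0.2,  # matches --ctc-loss-scale below
+) -> torch.Tensor:
+    # A weighted sum of the transducer terms and the auxiliary CTC term.
+    # (The real train.py also schedules the weight of the simple loss;
+    # this sketch omits that detail.)
+    return simple_loss + pruned_loss + ctc_loss_scale * ctc_loss
+```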
+ +[pruned_transducer_stateless7_ctc](./pruned_transducer_stateless7_ctc) + +The tensorboard log can be found at + + +You can find a pretrained model, training logs, decoding logs, and decoding +results at: + + +Number of model parameters: 70561891, i.e., 70.56 M + +| | test-clean | test-other | comment | +|--------------------------|------------|-------------|--------------------| +| greedy search | 2.23 | 5.19 | --epoch 30 --avg 8 | +| modified beam search | 2.21 | 5.12 | --epoch 30 --avg 8 | +| fast beam search | 2.23 | 5.18 | --epoch 30 --avg 8 | +| ctc decoding | 2.48 | 5.82 | --epoch 30 --avg 9 | +| 1best | 2.43 | 5.22 | --epoch 30 --avg 9 | +| nbest | 2.43 | 5.22 | --epoch 30 --avg 9 | +| nbest rescoring | 2.34 | 5.05 | --epoch 30 --avg 9 | +| whole lattice rescoring | 2.34 | 5.04 | --epoch 30 --avg 9 | + +The training commands are: +```bash +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_ctc/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --full-libri 1 \ + --use-fp16 1 \ + --max-duration 750 \ + --exp-dir pruned_transducer_stateless7_ctc/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --ctc-loss-scale 0.2 \ + --master-port 12535 +``` + +The decoding commands for the transducer branch are: +```bash +for m in greedy_search fast_beam_search modified_beam_search ; do + for epoch in 30; do + for avg in 8; do + ./pruned_transducer_stateless7_ctc/decode.py \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --feedforward-dims "1024,1024,2048,2048,1024" \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` + +The decoding commands for the ctc branch are: +```bash +for m in ctc-decoding nbest nbest-rescoring whole-lattice-rescoring; do + for epoch in 30; do + for avg in 9; do + ./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 100 \ + --decoding-method $m \ + --hlg-scale 0.6 \ + --lm-dir data/lm + done + done +done +``` + + ### LibriSpeech BPE training results (Conformer CTC, supporting delay penalty) #### [conformer_ctc3](./conformer_ctc3) diff --git a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py index c96defd23..5be898e37 100755 --- a/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc3/jit_pretrained.py @@ -23,40 +23,44 @@ Usage (for non-streaming mode): (1) ctc-decoding ./conformer_ctc3/pretrained.py \ - --checkpoint conformer_ctc3/exp/pretrained.pt \ + --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \ --bpe-model data/lang_bpe_500/bpe.model \ --method ctc-decoding \ --sample-rate 16000 \ - test_wavs/1089-134686-0001.wav + /path/to/foo.wav \ + /path/to/bar.wav (2) 1best ./conformer_ctc3/pretrained.py \ - --checkpoint conformer_ctc3/exp/pretrained.pt \ + --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ --method 1best \ --sample-rate 16000 \ - test_wavs/1089-134686-0001.wav + /path/to/foo.wav \ + /path/to/bar.wav (3) nbest-rescoring ./conformer_ctc3/pretrained.py \ - --checkpoint conformer_ctc3/exp/pretrained.pt \ + --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ --G data/lm/G_4_gram.pt \ --method nbest-rescoring \ --sample-rate 16000 \ - test_wavs/1089-134686-0001.wav + /path/to/foo.wav \ + 
/path/to/bar.wav (4) whole-lattice-rescoring ./conformer_ctc3/pretrained.py \ - --checkpoint conformer_ctc3/exp/pretrained.pt \ + --nn-model-filename ./conformer_ctc3/exp/cpu_jit.pt \ --HLG data/lang_bpe_500/HLG.pt \ --words-file data/lang_bpe_500/words.txt \ --G data/lm/G_4_gram.pt \ --method whole-lattice-rescoring \ --sample-rate 16000 \ - test_wavs/1089-134686-0001.wav + /path/to/foo.wav \ + /path/to/bar.wav """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py new file mode 100755 index 000000000..9c23e7d66 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/ctc_decode.py @@ -0,0 +1,818 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Liyong Guo, +# Quandong Wang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) ctc-decoding +./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method ctc-decoding + +(2) 1best +./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --decoding-method 1best + +(3) nbest +./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --decoding-method 1best + +(4) nbest-rescoring +./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --lm-dir data/lm \ + --decoding-method nbest-rescoring + +(5) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc/ctc_decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --hlg-scale 0.8 \ + --lm-dir data/lm \ + --decoding-method whole-lattice-rescoring +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import ( + get_lattice, + nbest_decoding, + nbest_oracle, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. 
", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="ctc-decoding", + help="""Decoding method. + Supported values are: + - (1) ctc-decoding. Use CTC decoding. It uses a sentence piece + model, i.e., lang_dir/bpe.model, to convert word pieces to words. + It needs neither a lexicon nor an n-gram LM. + - (2) 1best. Extract the best path from the decoding lattice as the + decoding result. + - (3) nbest. Extract n paths from the decoding lattice; the path + with the highest score is the decoding result. + - (4) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an n-gram LM (e.g., a 4-gram LM), the path with + the highest score is the decoding result. + - (5) whole-lattice-rescoring. Rescore the decoding lattice with an + n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice + is the decoding result. + you have trained an RNN LM using ./rnn_lm/train.py + - (6) nbest-oracle. Its WER is the lower bound of any n-best + rescoring method can achieve. Useful for debugging n-best + rescoring method. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help="""Number of paths for n-best based decoding method. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""The scale to be applied to `lattice.scores`. + It's needed if you use any kinds of n-best based rescoring. + Used only when "method" is one of the following values: + nbest, nbest-rescoring, and nbest-oracle + A smaller value results in more unique paths. + """, + ) + + parser.add_argument( + "--hlg-scale", + type=float, + default=0.8, + help="""The scale to be applied to `hlg.scores`. + """, + ) + + parser.add_argument( + "--lm-dir", + type=str, + default="data/lm", + help="""The n-gram LM dir. + It should contain either G_4_gram.pt or G_4_gram.fst.txt + """, + ) + + add_model_arguments(parser) + + return parser + + +def get_decoding_params() -> AttributeDict: + """Parameters for decoding.""" + params = AttributeDict( + { + "frame_shift_ms": 10, + "search_beam": 20, + "output_beam": 8, + "min_active_states": 30, + "max_active_states": 10000, + "use_double_scores": True, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + HLG: Optional[k2.Fsa], + H: Optional[k2.Fsa], + bpe_model: Optional[spm.SentencePieceProcessor], + batch: dict, + word_table: k2.SymbolTable, + G: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + - key: It indicates the setting used for decoding. For example, + if no rescoring is used, the key is the string `no_rescore`. + If LM rescoring is used, the key is the string `lm_scale_xxx`, + where `xxx` is the value of `lm_scale`. 
An example key is
+      `lm_scale_0.7`
+    - value: It contains the decoding result. `len(value)` equals to
+      batch size. `value[i]` is the decoding result for the i-th
+      utterance in the given batch.
+
+    Args:
+      params:
+        It's the return value of :func:`get_params`.
+
+        - If params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
+        - If params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
+        - If params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
+        - If params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM
+          rescoring.
+
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      batch:
+        It is the return value from iterating
+        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+        for the format of the `batch`.
+      word_table:
+        The word symbol table.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return the decoding result. See above description for the format of
+      the returned dict. Note: If it decodes to nothing, then return None.
+    """
+    if HLG is not None:
+        device = HLG.device
+    else:
+        device = H.device
+    feature = batch["inputs"]
+    assert feature.ndim == 3
+    feature = feature.to(device)
+    # at entry, feature is (N, T, C)
+
+    supervisions = batch["supervisions"]
+    feature_lens = supervisions["num_frames"].to(device)
+
+    encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
+    nnet_output = model.ctc_output(encoder_out)
+    # nnet_output is (N, T, C)
+
+    supervision_segments = torch.stack(
+        (
+            supervisions["sequence_idx"],
+            supervisions["start_frame"] // params.subsampling_factor,
+            supervisions["num_frames"] // params.subsampling_factor,
+        ),
+        1,
+    ).to(torch.int32)
+
+    if H is None:
+        assert HLG is not None
+        decoding_graph = HLG
+    else:
+        assert HLG is None
+        assert bpe_model is not None
+        decoding_graph = H
+
+    lattice = get_lattice(
+        nnet_output=nnet_output,
+        decoding_graph=decoding_graph,
+        supervision_segments=supervision_segments,
+        search_beam=params.search_beam,
+        output_beam=params.output_beam,
+        min_active_states=params.min_active_states,
+        max_active_states=params.max_active_states,
+        subsampling_factor=params.subsampling_factor,
+    )
+
+    if params.decoding_method == "ctc-decoding":
+        best_path = one_best_decoding(
+            lattice=lattice, use_double_scores=params.use_double_scores
+        )
+        # Note: `best_path.aux_labels` contains token IDs, not word IDs
+        # since we are using H, not HLG here.
+        #
+        # token_ids is a list-of-lists of IDs
+        token_ids = get_texts(best_path)
+
+        # hyps is a list of str, e.g., ['xxx yyy zzz', ...]
+        hyps = bpe_model.decode(token_ids)
+
+        # hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
+        hyps = [s.split() for s in hyps]
+        key = "ctc-decoding"
+        return {key: hyps}
+
+    if params.decoding_method == "nbest-oracle":
+        # Note: You can also pass rescored lattices to it.
+        # We choose the HLG decoded lattice for speed reasons
+        # as HLG decoding is faster and the oracle WER
+        # is only slightly worse than that of rescored lattices.
+        best_path = nbest_oracle(
+            lattice=lattice,
+            num_paths=params.num_paths,
+            ref_texts=supervisions["text"],
+            word_table=word_table,
+            nbest_scale=params.nbest_scale,
+            oov="<UNK>",
+        )
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}"  # noqa
+        return {key: hyps}
+
+    if params.decoding_method in ["1best", "nbest"]:
+        if params.decoding_method == "1best":
+            best_path = one_best_decoding(
+                lattice=lattice, use_double_scores=params.use_double_scores
+            )
+            key = "no_rescore"
+        else:
+            best_path = nbest_decoding(
+                lattice=lattice,
+                num_paths=params.num_paths,
+                use_double_scores=params.use_double_scores,
+                nbest_scale=params.nbest_scale,
+            )
+            key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}"  # noqa
+
+        hyps = get_texts(best_path)
+        hyps = [[word_table[i] for i in ids] for ids in hyps]
+        return {key: hyps}
+
+    assert params.decoding_method in [
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+    ]
+
+    lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
+    lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
+    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
+
+    if params.decoding_method == "nbest-rescoring":
+        best_path_dict = rescore_with_n_best_list(
+            lattice=lattice,
+            G=G,
+            num_paths=params.num_paths,
+            lm_scale_list=lm_scale_list,
+            nbest_scale=params.nbest_scale,
+        )
+    elif params.decoding_method == "whole-lattice-rescoring":
+        best_path_dict = rescore_with_whole_lattice(
+            lattice=lattice,
+            G_with_epsilon_loops=G,
+            lm_scale_list=lm_scale_list,
+        )
+    else:
+        assert False, f"Unsupported decoding method: {params.decoding_method}"
+
+    ans = dict()
+    if best_path_dict is not None:
+        for lm_scale_str, best_path in best_path_dict.items():
+            hyps = get_texts(best_path)
+            hyps = [[word_table[i] for i in ids] for ids in hyps]
+            ans[lm_scale_str] = hyps
+    else:
+        ans = None
+    return ans
+
+
+def decode_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    HLG: Optional[k2.Fsa],
+    H: Optional[k2.Fsa],
+    bpe_model: Optional[spm.SentencePieceProcessor],
+    word_table: k2.SymbolTable,
+    G: Optional[k2.Fsa] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+    """Decode dataset.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      HLG:
+        The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
+      H:
+        The ctc topo. Used only when params.decoding_method is ctc-decoding.
+      bpe_model:
+        The BPE model. Used only when params.decoding_method is ctc-decoding.
+      word_table:
+        It is the word symbol table.
+      G:
+        An LM. It is not None when params.decoding_method is "nbest-rescoring"
+        or "whole-lattice-rescoring". In general, the G in HLG
+        is a 3-gram LM, while this G is a 4-gram LM.
+    Returns:
+      Return a dict, whose key may be "no-rescore" if no LM rescoring
+      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+    """
+    num_cuts = 0
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
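+
+    # `results` maps each decoding-setting name returned by decode_one_batch
+    # to a list of (cut_id, ref_words, hyp_words) tuples.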
+    results = defaultdict(list)
+    for batch_idx, batch in enumerate(dl):
+        texts = batch["supervisions"]["text"]
+        cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+        hyps_dict = decode_one_batch(
+            params=params,
+            model=model,
+            HLG=HLG,
+            H=H,
+            bpe_model=bpe_model,
+            batch=batch,
+            word_table=word_table,
+            G=G,
+        )
+
+        for name, hyps in hyps_dict.items():
+            this_batch = []
+            assert len(hyps) == len(texts)
+            for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+                ref_words = ref_text.split()
+                this_batch.append((cut_id, ref_words, hyp_words))
+
+            results[name].extend(this_batch)
+
+        num_cuts += len(texts)
+
+        if batch_idx % 100 == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results
+
+
+def save_results(
+    params: AttributeDict,
+    test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+    test_set_wers = dict()
+    for key, results in results_dict.items():
+        recog_path = (
+            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        results = sorted(results)
+        store_transcripts(filename=recog_path, texts=results)
+        logging.info(f"The transcripts are stored in {recog_path}")
+
+        # The following prints out WERs, per-word error statistics and aligned
+        # ref/hyp pairs.
+        errs_filename = (
+            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
+        with open(errs_filename, "w") as f:
+            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
+            test_set_wers[key] = wer
+
+        logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+    errs_info = (
+        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+    )
+    with open(errs_info, "w") as f:
+        print("settings\tWER", file=f)
+        for key, val in test_set_wers:
+            print("{}\t{}".format(key, val), file=f)
+
+    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+    note = "\tbest for {}".format(test_set_name)
+    for key, val in test_set_wers:
+        s += "{}\t{}{}\n".format(key, val, note)
+        note = ""
+    logging.info(s)
+
+
+@torch.no_grad()
+def main():
+    parser = get_parser()
+    LibriSpeechAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
+    args.exp_dir = Path(args.exp_dir)
+    args.lang_dir = Path(args.lang_dir)
+    args.lm_dir = Path(args.lm_dir)
+
+    params = get_params()
+    # add decoding params
+    params.update(get_decoding_params())
+    params.update(vars(args))
+
+    assert params.decoding_method in (
+        "ctc-decoding",
+        "1best",
+        "nbest",
+        "nbest-rescoring",
+        "whole-lattice-rescoring",
+        "nbest-oracle",
+    )
+    params.res_dir = params.exp_dir / params.decoding_method
+
+    if params.iter > 0:
+        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+    else:
+        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+    if params.use_averaged_model:
+        params.suffix += "-use-averaged-model"
+
+    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+    logging.info("Decoding started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+
+    logging.info(f"Device: {device}")
+    logging.info(params)
+
+    lexicon = Lexicon(params.lang_dir)
+    max_token_id = max(lexicon.tokens)
+    num_classes = max_token_id + 1  # +1 for the blank
+
+    params.vocab_size = num_classes
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = 0
+
+    if params.decoding_method == "ctc-decoding":
+        HLG = None
+        H = k2.ctc_topo(
max_token=max_token_id, + modified=False, + device=device, + ) + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(str(params.lang_dir / "bpe.model")) + else: + H = None + bpe_model = None + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location=device) + ) + assert HLG.requires_grad is False + + HLG.scores *= params.hlg_scale + if not hasattr(HLG, "lm_scores"): + HLG.lm_scores = HLG.scores.clone() + + if params.decoding_method in ( + "nbest-rescoring", + "whole-lattice-rescoring", + ): + if not (params.lm_dir / "G_4_gram.pt").is_file(): + logging.info("Loading G_4_gram.fst.txt") + logging.warning("It may take 8 minutes.") + with open(params.lm_dir / "G_4_gram.fst.txt") as f: + first_word_disambig_id = lexicon.word_table["#0"] + + G = k2.Fsa.from_openfst(f.read(), acceptor=False) + # G.aux_labels is not needed in later computations, so + # remove it here. + del G.aux_labels + # CAUTION: The following line is crucial. + # Arcs entering the back-off state have label equal to #0. + # We have to change it to 0 here. + G.labels[G.labels >= first_word_disambig_id] = 0 + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set G.properties to None + G.__dict__["_properties"] = None + G = k2.Fsa.from_fsas([G]).to(device) + G = k2.arc_sort(G) + # Save a dummy value so that it can be loaded in C++. + # See https://github.com/pytorch/pytorch/issues/67902 + # for why we need to do this. + G.dummy = 1 + + torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt") + else: + logging.info("Loading pre-compiled G_4_gram.pt") + d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) + G = k2.Fsa.from_dict(d) + + if params.decoding_method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + G = G.to(device) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + else: + G = None + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
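+    # Setting `return_cuts=True` makes each batch carry its cuts in
+    # batch["supervisions"]["cut"], which is where decode_dataset()
+    # reads the cut IDs from.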
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + HLG=HLG, + H=H, + bpe_model=bpe_model, + word_table=lexicon.word_table, + G=G, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py new file mode 100755 index 000000000..32a9b6bb2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decode.py @@ -0,0 +1,841 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7_ctc/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. 
+ """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: 
{params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found 
for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py new file mode 120000 index 000000000..33944d0d2 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/decoder.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py new file mode 100755 index 000000000..59a393739 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/export.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.script() + +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 9 \ + --jit 1 + +It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later +load it by `torch.jit.load("cpu_jit.pt")`. + +Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python +are on CPU. You can use `to("cuda")` to move them to a CUDA device. + +Check +https://github.com/k2-fsa/sherpa +for how to use the exported models outside of icefall. + +(2) Export `model.state_dict()` + +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. 
You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `pruned_transducer_stateless7_ctc/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless7_ctc/decode.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + It will generate a file named cpu_jit.pt + + Check ./jit_pretrained.py for how to use it. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + model.to(device) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.script()") + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torchscript. Export model.state_dict()") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py new file mode 100755 index 000000000..280b95984 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +./pruned_transducer_stateless7_ctc/jit_pretrained.py \ + --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. 
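+
+    For example (illustrative only), read_sound_files(["foo.wav"]) for a mono
+    16 kHz recording returns a list with one 1-D float32 tensor of samples,
+    normalized to [-1, 1] by torchaudio.load() with its default settings.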
+ """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + model: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + context_size = model.decoder.context_size + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.nn_model_filename) + + model.eval() + + model.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = 
fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + + hyps = greedy_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py new file mode 100755 index 000000000..d3343d34a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py @@ -0,0 +1,423 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +Usage of this script: + +(1) ctc-decoding +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(3) nbest-rescoring +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --nn-model-filename ./pruned_transducer_stateless7_ctc/exp/cpu_jit.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import get_params + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--model-filename", + type=str, + required=True, + help="Path to the torchscript model.", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. + We call it HLG decoding + n-gram LM rescoring. + (3) whole-lattice-rescoring - Use an LM to rescore the + decoding lattice and then use 1best to decode the + rescored lattice. + We call it HLG decoding + n-gram LM rescoring. 
+ """, + ) + + parser.add_argument( + "--G", + type=str, + help="""An LM for rescoring. + Used only when method is + whole-lattice-rescoring or nbest-rescoring. + It's usually a 4-gram LM. + """, + ) + + parser.add_argument( + "--num-paths", + type=int, + default=100, + help=""" + Used only when method is attention-decoder. + It specifies the size of n-best list.""", + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=1.3, + help=""" + Used only when method is whole-lattice-rescoring and nbest-rescoring. + It specifies the scale for n-gram LM scores. + (Note: You need to tune it on a dataset.) + """, + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help=""" + Used only when method is nbest-rescoring. + It specifies the scale for lattice.scores when + extracting n-best lists. A smaller value results in + more unique number of paths with the risk of missing + the best path. + """, + ) + + parser.add_argument( + "--num-classes", + type=int, + default=500, + help=""" + Vocab size in the BPE model. + """, + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float = 16000 +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = torch.jit.load(args.model_filename) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. 
+ G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py new file mode 120000 index 000000000..ecfb6dd8a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/joiner.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py new file mode 100644 index 000000000..a6e919e2f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/model.py @@ -0,0 +1,198 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
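+
+# Orientation note: this model is a pruned RNN-T transducer with an auxiliary
+# CTC head (`self.ctc_output` below); forward() returns the CTC log-probs
+# together with the simple and pruned transducer losses.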
+ + +from typing import Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = nn.Linear( + encoder_dim, + vocab_size, + ) + self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) + + self.ctc_output = nn.Sequential( + nn.Dropout(p=0.1), + nn.Linear(encoder_dim, vocab_size), + nn.LogSoftmax(dim=-1), + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return a tuple containing simple loss, pruned loss, and ctc-output. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens = self.encoder(x, x_lens) + assert torch.all(x_lens > 0) + + # compute ctc log-probs + ctc_output = self.ctc_output(encoder_out) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. 
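+        # For example (hypothetical values): if y = [[3, 5], [7]] and
+        # blank_id = 0, then sos_y = [[0, 3, 5], [0, 7]] and
+        # sos_y_padded = [[0, 3, 5], [0, 7, 0]].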
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss, ctc_output) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py new file mode 120000 index 000000000..81ac4a89a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/optim.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py new file mode 100755 index 000000000..2f1b1a49f --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained.py @@ -0,0 +1,353 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. 
+You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7_ctc/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7_ctc/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7_ctc/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7_ctc/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7_ctc/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless7_ctc/exp/pretrained.pt is generated by +./pruned_transducer_stateless7_ctc/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + 
beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py new file mode 100755 index 000000000..74aef1bc7 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/pretrained_ctc.py @@ -0,0 +1,441 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()` +and uses them to decode waves. 
+You can use the following command to get the exported models: + +./pruned_transducer_stateless7_ctc/export.py \ + --exp-dir ./pruned_transducer_stateless7_ctc/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) ctc-decoding +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --method ctc-decoding \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) 1best +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --method 1best \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) nbest-rescoring +./bruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method nbest-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav + + +(4) whole-lattice-rescoring +./pruned_transducer_stateless7_ctc/jit_pretrained_ctc.py \ + --checkpoint ./pruned_transducer_stateless7_ctc/exp/pretrained.pt \ + --HLG data/lang_bpe_500/HLG.pt \ + --words-file data/lang_bpe_500/words.txt \ + --G data/lm/G_4_gram.pt \ + --method whole-lattice-rescoring \ + --sample-rate 16000 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from ctc_decode import get_decoding_params +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_n_best_list, + rescore_with_whole_lattice, +) +from icefall.utils import get_texts + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " "2 means tri-gram", + ) + + parser.add_argument( + "--words-file", + type=str, + help="""Path to words.txt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--HLG", + type=str, + help="""Path to HLG.pt. + Used only when method is not ctc-decoding. + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model. + Used only when method is ctc-decoding. + """, + ) + + parser.add_argument( + "--method", + type=str, + default="1best", + help="""Decoding method. + Possible values are: + (0) ctc-decoding - Use CTC decoding. It uses a sentence + piece model, i.e., lang_dir/bpe.model, to convert + word pieces to words. It needs neither a lexicon + nor an n-gram LM. + (1) 1best - Use the best path as decoding output. Only + the transformer encoder output is used for decoding. + We call it HLG decoding. + (2) nbest-rescoring. Extract n paths from the decoding lattice, + rescore them with an LM, the path with + the highest score is the decoding result. 
+            We call it HLG decoding + n-gram LM rescoring.
+        (3) whole-lattice-rescoring - Use an LM to rescore the
+            decoding lattice and then use 1best to decode the
+            rescored lattice.
+            We call it HLG decoding + whole-lattice LM rescoring.
+        """,
+    )
+
+    parser.add_argument(
+        "--G",
+        type=str,
+        help="""An LM for rescoring.
+        Used only when method is
+        whole-lattice-rescoring or nbest-rescoring.
+        It's usually a 4-gram LM.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-paths",
+        type=int,
+        default=100,
+        help="""
+        Used only when method is nbest-rescoring.
+        It specifies the size of the n-best list.""",
+    )
+
+    parser.add_argument(
+        "--ngram-lm-scale",
+        type=float,
+        default=1.3,
+        help="""
+        Used only when method is whole-lattice-rescoring or nbest-rescoring.
+        It specifies the scale for n-gram LM scores.
+        (Note: You need to tune it on a dataset.)
+        """,
+    )
+
+    parser.add_argument(
+        "--nbest-scale",
+        type=float,
+        default=0.5,
+        help="""
+        Used only when method is nbest-rescoring.
+        It specifies the scale for lattice.scores when
+        extracting n-best lists. A smaller value results in
+        more unique paths, with the risk of missing
+        the best path.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-classes",
+        type=int,
+        default=500,
+        help="""
+        Vocab size in the BPE model.
+        """,
+    )
+
+    parser.add_argument(
+        "--sample-rate",
+        type=int,
+        default=16000,
+        help="The sample rate of the input sound file",
+    )
+
+    parser.add_argument(
+        "sound_files",
+        type=str,
+        nargs="+",
+        help="The input sound file(s) to transcribe. "
+        "Supported formats are those supported by torchaudio.load(). "
+        "For example, wav and flac are supported. "
+        "The sample rate has to be 16kHz.",
+    )
+
+    add_model_arguments(parser)
+
+    return parser
+
+
+def read_sound_files(
+    filenames: List[str], expected_sample_rate: float = 16000
+) -> List[torch.Tensor]:
+    """Read a list of sound files into a list of 1-D float32 torch tensors.
+    Args:
+      filenames:
+        A list of sound filenames.
+      expected_sample_rate:
+        The expected sample rate of the sound files.
+    Returns:
+      Return a list of 1-D float32 torch tensors.
+    """
+    ans = []
+    for f in filenames:
+        wave, sample_rate = torchaudio.load(f)
+        assert sample_rate == expected_sample_rate, (
+            f"expected sample rate: {expected_sample_rate}. 
" f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + # add decoding params + params.update(get_decoding_params()) + params.update(vars(args)) + params.vocab_size = params.num_classes + params.blank_id = 0 + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder( + x=features, + x_lens=feature_lengths, + ) + nnet_output = model.ctc_output(encoder_out) + + batch_size = nnet_output.shape[0] + supervision_segments = torch.tensor( + [[i, 0, nnet_output.shape[1]] for i in range(batch_size)], + dtype=torch.int32, + ) + + if params.method == "ctc-decoding": + logging.info("Use CTC decoding") + bpe_model = spm.SentencePieceProcessor() + bpe_model.load(params.bpe_model) + max_token_id = params.num_classes - 1 + + H = k2.ctc_topo( + max_token=max_token_id, + modified=False, + device=device, + ) + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=H, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + token_ids = get_texts(best_path) + hyps = bpe_model.decode(token_ids) + hyps = [s.split() for s in hyps] + elif params.method in [ + "1best", + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading HLG from {params.HLG}") + HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu")) + HLG = HLG.to(device) + if not hasattr(HLG, "lm_scores"): + # For whole-lattice-rescoring and attention-decoder + HLG.lm_scores = HLG.scores.clone() + + if params.method in [ + "nbest-rescoring", + "whole-lattice-rescoring", + ]: + logging.info(f"Loading G from {params.G}") + G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu")) + G = G.to(device) + if params.method == "whole-lattice-rescoring": + # Add epsilon self-loops to G as we will compose + # it with the whole lattice later + G = k2.add_epsilon_self_loops(G) + G = k2.arc_sort(G) + + # 
G.lm_scores is used to replace HLG.lm_scores during + # LM rescoring. + G.lm_scores = G.scores.clone() + + lattice = get_lattice( + nnet_output=nnet_output, + decoding_graph=HLG, + supervision_segments=supervision_segments, + search_beam=params.search_beam, + output_beam=params.output_beam, + min_active_states=params.min_active_states, + max_active_states=params.max_active_states, + subsampling_factor=params.subsampling_factor, + ) + + if params.method == "1best": + logging.info("Use HLG decoding") + best_path = one_best_decoding( + lattice=lattice, use_double_scores=params.use_double_scores + ) + if params.method == "nbest-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_n_best_list( + lattice=lattice, + G=G, + num_paths=params.num_paths, + lm_scale_list=[params.ngram_lm_scale], + nbest_scale=params.nbest_scale, + ) + best_path = next(iter(best_path_dict.values())) + elif params.method == "whole-lattice-rescoring": + logging.info("Use HLG decoding + LM rescoring") + best_path_dict = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=[params.ngram_lm_scale], + ) + best_path = next(iter(best_path_dict.values())) + + hyps = get_texts(best_path) + word_sym_table = k2.SymbolTable.from_file(params.words_file) + hyps = [[word_sym_table[i] for i in ids] for ids in hyps] + else: + raise ValueError(f"Unsupported decoding method: {params.method}") + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py new file mode 120000 index 000000000..2428b74b9 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py new file mode 120000 index 000000000..b8b8ba432 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py new file mode 100755 index 000000000..e482d2040 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/test_model.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless7_ctc/test_model.py +""" + +from train import get_params, get_transducer_model + + +def test_model_1(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = "2,4,3,2,4" + # params.feedforward_dims = "1024,1024,1536,1536,1024" + params.feedforward_dims = "1024,1024,2048,2048,1024" + params.nhead = "8,8,8,8,8" + params.encoder_dims = "384,384,384,384,384" + params.attention_dims = "192,192,192,192,192" + params.encoder_unmasked_dims = "256,256,256,256,256" + params.zipformer_downsampling_factors = "1,2,4,8,2" + params.cnn_module_kernels = "31,31,31,31,31" + params.decoder_dim = 512 + params.joiner_dim = 512 + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +def main(): + test_model_1() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py new file mode 100755 index 000000000..abfd56e5a --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -0,0 +1,1252 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7_ctc/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7_ctc/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7_ctc/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7_ctc/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import ( + AttributeDict, + MetricsTracker, + encode_supervisions, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7_ctc/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--ctc-loss-scale", + type=float, + default=0.2, + help="Scale for CTC loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. 
+ + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + # parameters for ctc loss + "beam_size": 10, + "use_double_scores": True, + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
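+      Return None if training starts from scratch, i.e., when neither
+      --start-batch nor --start-epoch requires loading a checkpoint.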
+    """
+    if params.start_batch > 0:
+        filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+    elif params.start_epoch > 1:
+        filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+    else:
+        return None
+
+    assert filename.is_file(), f"{filename} does not exist!"
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      sampler:
+        The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute transducer loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Zipformer in our case.
+      sp:
+        The BPE model.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
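+    Returns:
+      Return a tuple of two elements: the loss (a sum over utterances in
+      the batch, since we use reduction=sum) and a MetricsTracker holding
+      statistics such as the number of frames and the value of each loss
+      component.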
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + token_ids = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(token_ids).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss, ctc_output = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + # Compute ctc loss + + # NOTE: We need `encode_supervisions` to sort sequences with + # different duration in decreasing order, required by + # `k2.intersect_dense` called in `k2.ctc_loss` + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + supervision_segments, token_ids = encode_supervisions( + supervisions, + subsampling_factor=params.subsampling_factor, + token_ids=token_ids, + ) + + # Works with a BPE model + decoding_graph = k2.ctc_graph(token_ids, modified=False, device=device) + dense_fsa_vec = k2.DenseFsaVec( + ctc_output, + supervision_segments, + allow_truncate=params.subsampling_factor - 1, + ) + + ctc_loss = k2.ctc_loss( + decoding_graph=decoding_graph, + dense_fsa_vec=dense_fsa_vec, + output_beam=params.beam_size, + reduction="sum", + use_double_scores=params.use_double_scores, + ) + assert ctc_loss.requires_grad == is_training + loss += params.ctc_loss_scale * ctc_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. 
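+    # Hence each of the statistics below is a sum over the batch; the
+    # per-frame average is obtained later by dividing by info["frames"].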
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + info["ctc_loss"] = ctc_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
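+                # scaler.scale() multiplies the loss so that backward()
+                # produces scaled gradients; scaler.step() unscales them
+                # before the actual parameter update.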
+ scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + if params.inf_check: + register_inf_check_hooks(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. 
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py new file mode 120000 index 000000000..79b076556 --- /dev/null +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/zipformer.py @@ -0,0 +1 @@ +../pruned_transducer_stateless7/zipformer.py \ No newline at end of file diff --git a/icefall/utils.py b/icefall/utils.py index d852491c8..99e51a2a9 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -175,11 +175,13 @@ class AttributeDict(dict): def encode_supervisions( - supervisions: dict, subsampling_factor: int -) -> Tuple[torch.Tensor, List[str]]: + supervisions: dict, + subsampling_factor: int, + token_ids: Optional[List[List[int]]] = None, +) -> Tuple[torch.Tensor, Union[List[str], List[List[int]]]]: """ Encodes Lhotse's ``batch["supervisions"]`` dict into - a pair of torch Tensor, and a list of transcription strings. + a pair of torch Tensor, and a list of transcription strings or token indexes The supervision tensor has shape ``(batch_size, 3)``. 
Its second dimension contains information about sequence index [0], @@ -208,10 +210,14 @@ def encode_supervisions( indices = torch.argsort(supervision_segments[:, 2], descending=True) supervision_segments = supervision_segments[indices] - texts = supervisions["text"] - texts = [texts[idx] for idx in indices] - return supervision_segments, texts + if token_ids is None: + texts = supervisions["text"] + res = [texts[idx] for idx in indices] + else: + res = [token_ids[idx] for idx in indices] + + return supervision_segments, res def get_texts( From e6a67270128f607f49c81327190aca63bb3bb4eb Mon Sep 17 00:00:00 2001 From: Senyan Li <1149593720@qq.com> Date: Sat, 3 Dec 2022 23:50:49 +0800 Subject: [PATCH 18/27] Add Tibetan Amdo dialect xbmu_amdo31 in egs (#706) * add egs/xbmu_amdo31 * fix xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py * fix xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py * fix xbmu_amdo31/ASR/prepare.sh * add RESULTS.md and README.md * dix pruned_transducer_stateless5 decode.py * add transducer stateless7 * fix transducer_stateless7 * fix RESULTS.md error * Add pruned_transducer_stateless7 validation set results --- egs/xbmu_amdo31/ASR/README.md | 16 + egs/xbmu_amdo31/ASR/RESULTS.md | 92 ++ egs/xbmu_amdo31/ASR/local/compile_hlg.py | 1 + egs/xbmu_amdo31/ASR/local/compile_lg.py | 1 + .../ASR/local/compute_fbank_musan.py | 1 + .../ASR/local/compute_fbank_xbmu_amdo31.py | 130 ++ .../convert_transcript_words_to_tokens.py | 1 + egs/xbmu_amdo31/ASR/local/filter_cuts.py | 1 + .../ASR/local/generate_unique_lexicon.py | 1 + egs/xbmu_amdo31/ASR/local/prepare_lang.py | 1 + egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py | 1 + .../ASR/local/prepare_lm_training_data.py | 1 + .../ASR/local/sort_lm_training_data.py | 1 + egs/xbmu_amdo31/ASR/local/train_bpe_model.py | 1 + .../ASR/local/validate_bpe_lexicon.py | 1 + egs/xbmu_amdo31/ASR/prepare.sh | 357 +++++ .../pruned_transducer_stateless5/__init__.py | 0 .../asr_datamodule.py | 408 ++++++ .../beam_search.py | 1 + .../pruned_transducer_stateless5/conformer.py | 1 + .../pruned_transducer_stateless5/decode.py | 970 +++++++++++++ .../decode_stream.py | 1 + .../pruned_transducer_stateless5/decoder.py | 1 + .../encoder_interface.py | 1 + .../pruned_transducer_stateless5/export.py | 287 ++++ .../pruned_transducer_stateless5/joiner.py | 1 + .../ASR/pruned_transducer_stateless5/lstmp.py | 1 + .../ASR/pruned_transducer_stateless5/model.py | 1 + .../ASR/pruned_transducer_stateless5/optim.py | 1 + .../pretrained.py | 344 +++++ .../pruned_transducer_stateless5/scaling.py | 1 + .../scaling_converter.py | 1 + .../streaming_beam_search.py | 1 + .../streaming_decode.py | 1 + .../test_model.py | 65 + .../ASR/pruned_transducer_stateless5/train.py | 1187 ++++++++++++++++ .../pruned_transducer_stateless7/__init__.py | 0 .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../pruned_transducer_stateless7/decode.py | 843 ++++++++++++ .../pruned_transducer_stateless7/decoder.py | 1 + .../encoder_interface.py | 1 + .../pruned_transducer_stateless7/export.py | 1 + .../jit_pretrained.py | 1 + .../pruned_transducer_stateless7/joiner.py | 1 + .../ASR/pruned_transducer_stateless7/model.py | 1 + .../ASR/pruned_transducer_stateless7/optim.py | 1 + .../pretrained.py | 355 +++++ .../pruned_transducer_stateless7/scaling.py | 1 + .../scaling_converter.py | 1 + .../test_model.py | 1 + .../ASR/pruned_transducer_stateless7/train.py | 1224 +++++++++++++++++ .../pruned_transducer_stateless7/zipformer.py | 1 + egs/xbmu_amdo31/ASR/shared | 1 + 54 files changed, 6317 
insertions(+) create mode 100644 egs/xbmu_amdo31/ASR/README.md create mode 100644 egs/xbmu_amdo31/ASR/RESULTS.md create mode 120000 egs/xbmu_amdo31/ASR/local/compile_hlg.py create mode 120000 egs/xbmu_amdo31/ASR/local/compile_lg.py create mode 120000 egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py create mode 100755 egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py create mode 120000 egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py create mode 120000 egs/xbmu_amdo31/ASR/local/filter_cuts.py create mode 120000 egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py create mode 120000 egs/xbmu_amdo31/ASR/local/prepare_lang.py create mode 120000 egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py create mode 120000 egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py create mode 120000 egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py create mode 120000 egs/xbmu_amdo31/ASR/local/train_bpe_model.py create mode 120000 egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py create mode 100755 egs/xbmu_amdo31/ASR/prepare.sh create mode 100644 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py create mode 100644 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py create mode 100644 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py create mode 120000 
egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py
 create mode 100755 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py
 create mode 120000 egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py
 create mode 120000 egs/xbmu_amdo31/ASR/shared

diff --git a/egs/xbmu_amdo31/ASR/README.md b/egs/xbmu_amdo31/ASR/README.md
new file mode 100644
index 000000000..0a441d070
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/README.md
@@ -0,0 +1,16 @@
+# Introduction
+About the XBMU-AMDO31 corpus:
+XBMU-AMDO31 is an open-source Amdo Tibetan speech corpus published by Northwest Minzu University.
+It is publicly available at https://huggingface.co/datasets/syzym/xbmu_amdo31
+
+The XBMU-AMDO31 dataset is a speech recognition corpus of the Amdo Tibetan dialect.
+The open-source corpus contains 31 hours of speech data and resources for
+building speech recognition systems, including transcribed texts and a Tibetan
+pronunciation lexicon.
+(The lexicon is a Tibetan lexicon of the Lhasa dialect, which has been reused
+for the Amdo dialect because of the uniformity of the Tibetan language.)
+The dataset can be used to train a model for Amdo Tibetan Automatic Speech Recognition (ASR).
+
+This recipe includes several different ASR models trained with XBMU-AMDO31.
+
+[./RESULTS.md](./RESULTS.md) contains the latest results.
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/RESULTS.md b/egs/xbmu_amdo31/ASR/RESULTS.md
new file mode 100644
index 000000000..1bd9b2e2b
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/RESULTS.md
@@ -0,0 +1,92 @@
+## Results
+
+### XBMU-AMDO31 BPE training result (Stateless Transducer)
+
+#### Pruned transducer stateless 5
+
+[./pruned_transducer_stateless5](./pruned_transducer_stateless5)
+
+It uses pruned RNN-T.
+
+A pre-trained model and decoding logs can be found at
+
+You can use to deploy it.
+
+Number of model parameters: 87801200, i.e., 87.8 M
+
+|                      | test | dev  | comment                                |
+|----------------------|------|------|----------------------------------------|
+| greedy search        | 11.06| 11.73| --epoch 28 --avg 23 --max-duration 600 |
+| beam search          | 10.64| 11.42| --epoch 28 --avg 23 --max-duration 600 |
+| modified beam search | 10.57| 11.24| --epoch 28 --avg 23 --max-duration 600 |
+
+
+The training command is:
+
+```bash
+cd egs/xbmu_amdo31/ASR
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0"
+
+./pruned_transducer_stateless5/train.py
+```
+
+**Caution**: It uses `--context-size=1`.
+
+
+The decoding command is:
+```bash
+for method in greedy_search beam_search modified_beam_search;
+do
+./pruned_transducer_stateless5/decode.py \
+    --epoch 28 \
+    --avg 23 \
+    --exp-dir ./pruned_transducer_stateless5/exp \
+    --max-duration 600 \
+    --decoding-method $method
+done
+```
+
+### pruned_transducer_stateless7 (zipformer)
+
+See for more details.
+
+[pruned_transducer_stateless7](./pruned_transducer_stateless7)
+
+You can find a pretrained model, training logs, decoding logs, and decoding
+results at:
+
+
+You can use to deploy it.
+ +Number of model parameters: 70369391, i.e., 70.37 M + +| | test | dev | comment | +|----------------------|------|------|----------------------------------------| +| greedy search | 10.06| 10.59| --epoch 23 --avg 11 --max-duration 600 | +| beam search | 9.77 | 10.11| --epoch 23 --avg 11 --max-duration 600 | +| modified beam search | 9.7 | 10.12| --epoch 23 --avg 11 --max-duration 600 | + +The training commands are: +```bash +export CUDA_VISIBLE_DEVICES="0" + +./pruned_transducer_stateless7/train.py +``` + +The decoding commands are: +```bash +for m in greedy_search beam_search modified_beam_search; do + for epoch in 23; do + for avg in 11; do + ./pruned_transducer_stateless7/decode.py \ + --epoch $epoch \ + --avg $avg \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method $m + done + done +done +``` diff --git a/egs/xbmu_amdo31/ASR/local/compile_hlg.py b/egs/xbmu_amdo31/ASR/local/compile_hlg.py new file mode 120000 index 000000000..471aa7fb4 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/compile_hlg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_hlg.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/compile_lg.py b/egs/xbmu_amdo31/ASR/local/compile_lg.py new file mode 120000 index 000000000..462d6d3fb --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/compile_lg.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compile_lg.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py b/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py new file mode 120000 index 000000000..5833f2484 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/compute_fbank_musan.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/compute_fbank_musan.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py b/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py new file mode 100755 index 000000000..a593e7be3 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/compute_fbank_xbmu_amdo31.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the XBMU-AMDO31 dataset. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" + +import argparse +import logging +import os +from pathlib import Path +from typing import Optional + +import sentencepiece as spm +import torch +from filter_cuts import filter_cuts +from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter +from lhotse.recipes.utils import read_manifests_if_cached + +from icefall.utils import get_executor + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). 
+torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to the bpe.model. If not None, we will remove short and + long utterances before extracting features""", + ) + return parser.parse_args() + + +def compute_fbank_xbmu_amdo31(bpe_model: Optional[str] = None): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + num_jobs = min(15, os.cpu_count()) + num_mel_bins = 80 + + if bpe_model: + logging.info(f"Loading {bpe_model}") + sp = spm.SentencePieceProcessor() + sp.load(bpe_model) + + dataset_parts = ( + "train", + "dev", + "test", + ) + prefix = "xbmu_amdo31" + suffix = "jsonl.gz" + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=src_dir, + prefix=prefix, + suffix=suffix, + ) + assert manifests is not None + + assert len(manifests) == len(dataset_parts), ( + len(manifests), + len(dataset_parts), + list(manifests.keys()), + dataset_parts, + ) + + extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) + + with get_executor() as ex: # Initialize the executor only once. + for partition, m in manifests.items(): + cuts_filename = f"{prefix}_cuts_{partition}.{suffix}" + if (output_dir / cuts_filename).is_file(): + logging.info(f"{partition} already exists - skipping.") + continue + logging.info(f"Processing {partition}") + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + if bpe_model: + cut_set = filter_cuts(cut_set, sp) + + if "train" in partition: + cut_set = ( + cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + ) + cut_set = cut_set.compute_and_store_features( + extractor=extractor, + storage_path=f"{output_dir}/{prefix}_feats_{partition}", + # when an executor is specified, make more partitions + num_jobs=num_jobs if ex is None else 80, + executor=ex, + storage_type=LilcomChunkyWriter, + ) + cut_set.to_file(output_dir / cuts_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + args = get_args() + logging.info(vars(args)) + compute_fbank_xbmu_amdo31(bpe_model=args.bpe_model) diff --git a/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py b/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py new file mode 120000 index 000000000..2ce13fd69 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/convert_transcript_words_to_tokens.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/convert_transcript_words_to_tokens.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/filter_cuts.py b/egs/xbmu_amdo31/ASR/local/filter_cuts.py new file mode 120000 index 000000000..27aca1729 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/filter_cuts.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/filter_cuts.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py b/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py new file mode 120000 index 000000000..c0aea1403 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/generate_unique_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/generate_unique_lexicon.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lang.py b/egs/xbmu_amdo31/ASR/local/prepare_lang.py new file mode 120000 index 000000000..747f2ab39 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/prepare_lang.py @@ -0,0 +1 @@ 
+../../../librispeech/ASR/local/prepare_lang.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py b/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py new file mode 120000 index 000000000..36b40e7fc --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/prepare_lang_bpe.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lang_bpe.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py b/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py new file mode 120000 index 000000000..abc00d421 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/prepare_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py b/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py new file mode 120000 index 000000000..1d6ccbe33 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/sort_lm_training_data.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/sort_lm_training_data.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/train_bpe_model.py b/egs/xbmu_amdo31/ASR/local/train_bpe_model.py new file mode 120000 index 000000000..6fad36421 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/train_bpe_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py b/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py new file mode 120000 index 000000000..721bb48e7 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/local/validate_bpe_lexicon.py @@ -0,0 +1 @@ +../../../librispeech/ASR/local/validate_bpe_lexicon.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/prepare.sh b/egs/xbmu_amdo31/ASR/prepare.sh new file mode 100755 index 000000000..32ae440f7 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/prepare.sh @@ -0,0 +1,357 @@ +#!/usr/bin/env bash + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 + +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/xbmu_amdo31 +# You can find data, resource, etc, inside it. +# You can download them from https://huggingface.co/datasets/syzym/xbmu_amdo31 +# +# - $dl_dir/lm +# This directory contains the following files downloaded from +# git lfs install +# https://huggingface.co/syzym/xbmu_amdo31_lm +# +# - tibetan.3-gram.arpa +# - tibetan.4-gram.arpa +# +# - $dl_dir/musan +# This directory contains the following directories downloaded from +# http://www.openslr.org/17/ +# +# - music +# - noise +# - speech + +dl_dir=$PWD/download + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/lang_bpe_xxx, +# data/lang_bpe_yyy if the array contains xxx, yyy +vocab_sizes=( + 1000 + 500 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. 
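+# Usage sketch (an illustration, not part of the staged pipeline below):
+# shared/parse_options.sh, sourced above, lets you override the stage
+# variables from the command line, so you can rerun only part of the
+# pipeline, e.g. only stages 1 through 3:
+#
+#   ./prepare.sh --stage 1 --stop-stage 3
+#
+# The stage numbers refer to the stages defined later in this script.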
+mkdir -p data
+
+log() {
+  # This function is from espnet
+  local fname=${BASH_SOURCE[1]##*/}
+  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
+  log "Stage -1: Download LM"
+  # We assume that you have installed git-lfs; if not, you can install it
+  # using: `sudo apt-get install git-lfs && git-lfs install`
+  git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
+
+  if [ ! -f $dl_dir/lm/tibetan.3-gram.arpa ]; then
+    git clone https://huggingface.co/syzym/xbmu_amdo31_lm $dl_dir/lm
+    pushd $dl_dir/lm
+    git lfs pull --include "tibetan.3-gram.arpa"
+    git lfs pull --include "tibetan.4-gram.arpa"
+    popd
+  fi
+fi
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+  log "Stage 0: Download data"
+
+  # If you have pre-downloaded it to /path/to/xbmu_amdo31,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/xbmu_amdo31 $dl_dir/xbmu_amdo31
+  #
+
+  if [ ! -d $dl_dir/xbmu_amdo31 ]; then
+    git lfs 1>/dev/null 2>&1 || (echo "please install git-lfs, consider using: sudo apt-get install git-lfs && git-lfs install" && exit 1)
+    lhotse download xbmu-amdo31 $dl_dir
+  fi
+
+  # If you have pre-downloaded it to /path/to/musan,
+  # you can create a symlink
+  #
+  #   ln -sfv /path/to/musan $dl_dir/
+  #
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
+  fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+  log "Stage 1: Prepare xbmu_amdo31 manifest"
+  # We assume that you have downloaded the xbmu_amdo31 corpus
+  # to $dl_dir/xbmu_amdo31
+  if [ ! -f data/manifests/.xbmu_amdo31_manifests.done ]; then
+    mkdir -p data/manifests
+    lhotse prepare xbmu-amdo31 $dl_dir/xbmu_amdo31 data/manifests
+    touch data/manifests/.xbmu_amdo31_manifests.done
+  fi
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+  log "Stage 2: Prepare musan manifest"
+  # We assume that you have downloaded the musan corpus
+  # to $dl_dir/musan
+  if [ ! -f data/manifests/.musan_manifests.done ]; then
+    log "It may take 6 minutes"
+    mkdir -p data/manifests
+    lhotse prepare musan $dl_dir/musan data/manifests
+    touch data/manifests/.musan_manifests.done
+  fi
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+  log "Stage 3: Compute fbank for xbmu_amdo31"
+  if [ ! -f data/fbank/.xbmu_amdo31.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_xbmu_amdo31.py
+    touch data/fbank/.xbmu_amdo31.done
+  fi
+fi
+
+
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+  log "Stage 4: Compute fbank for musan"
+  if [ ! -f data/fbank/.musan.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.musan.done
+  fi
+fi
+
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Prepare phone based lang"
+  lang_dir=data/lang_phone
+  mkdir -p $lang_dir
+
+  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
+    cat - $dl_dir/xbmu_amdo31/resource/lexicon.txt |
+    sort | uniq > $lang_dir/lexicon.txt
+
+  ./local/generate_unique_lexicon.py --lang-dir $lang_dir
+
+  if [ ! -f $lang_dir/L_disambig.pt ]; then
+    ./local/prepare_lang.py --lang-dir $lang_dir
+  fi
+fi
+
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+  log "Stage 6: Prepare BPE based lang"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    mkdir -p $lang_dir
+    # We reuse words.txt from phone based lexicon
+    # so that the two can share G.pt later.
+    cp data/lang_phone/words.txt $lang_dir
+
+    if [ ! -f $lang_dir/transcript_words.txt ]; then
+      log "Generate data to train phone based bigram P"
+      xbmu_amdo31_text=$dl_dir/xbmu_amdo31/data/transcript/transcript_clean.txt
+      xbmu_amdo31_train_uid=$dl_dir/xbmu_amdo31/data/transcript/xbmu_amdo31_train_uid
+      find $dl_dir/xbmu_amdo31/data/wav/train -name "*.wav" | sed 's/\.wav//g' | awk -F '-' '{print $NF}' > $xbmu_amdo31_train_uid
+      awk 'NR==FNR{uid[$1]=$1} NR!=FNR{if($1 in uid) print $0}' $xbmu_amdo31_train_uid $xbmu_amdo31_text |
+        cut -d " " -f 2- > $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/bpe.model ]; then
+      ./local/train_bpe_model.py \
+        --lang-dir $lang_dir \
+        --vocab-size $vocab_size \
+        --transcript $lang_dir/transcript_words.txt
+    fi
+
+    if [ ! -f $lang_dir/L_disambig.pt ]; then
+      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
+
+      log "Validating $lang_dir/lexicon.txt"
+      ./local/validate_bpe_lexicon.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --bpe-model $lang_dir/bpe.model
+    fi
+  done
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+  log "Stage 7: Prepare bigram P"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+
+    if [ ! -f $lang_dir/transcript_tokens.txt ]; then
+      ./local/convert_transcript_words_to_tokens.py \
+        --lexicon $lang_dir/lexicon.txt \
+        --transcript $lang_dir/transcript_words.txt \
+        --oov "<UNK>" \
+        > $lang_dir/transcript_tokens.txt
+    fi
+
+    if [ ! -f $lang_dir/P.arpa ]; then
+      ./shared/make_kn_lm.py \
+        -ngram-order 2 \
+        -text $lang_dir/transcript_tokens.txt \
+        -lm $lang_dir/P.arpa
+    fi
+
+    if [ ! -f $lang_dir/P.fst.txt ]; then
+      python3 -m kaldilm \
+        --read-symbol-table="$lang_dir/tokens.txt" \
+        --disambig-symbol='#0' \
+        --max-order=2 \
+        $lang_dir/P.arpa > $lang_dir/P.fst.txt
+    fi
+  done
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+  log "Stage 8: Prepare G"
+  # We assume you have installed kaldilm; if not, please install
+  # it using: pip install kaldilm
+
+  mkdir -p data/lm
+  if [ ! -f data/lm/G_3_gram.fst.txt ]; then
+    # It is used in building HLG
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang_phone/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=3 \
+      $dl_dir/lm/tibetan.3-gram.arpa > data/lm/G_3_gram.fst.txt
+  fi
+
+  if [ ! -f data/lm/G_4_gram.fst.txt ]; then
+    # It is used for LM rescoring
+    python3 -m kaldilm \
+      --read-symbol-table="data/lang_phone/words.txt" \
+      --disambig-symbol='#0' \
+      --max-order=4 \
+      $dl_dir/lm/tibetan.4-gram.arpa > data/lm/G_4_gram.fst.txt
+  fi
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+  log "Stage 9: Compile HLG"
+  ./local/compile_hlg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_hlg.py --lang-dir $lang_dir
+  done
+fi
+
+# Compile LG for RNN-T fast_beam_search decoding
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+  log "Stage 10: Compile LG"
+  ./local/compile_lg.py --lang-dir data/lang_phone
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/compile_lg.py --lang-dir $lang_dir
+  done
+fi
+
+if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then
+  log "Stage 11: Generate LM training data"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    log "Processing vocab_size == ${vocab_size}"
+    lang_dir=data/lang_bpe_${vocab_size}
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $lang_dir/bpe.model \
+      --lm-data $dl_dir/lm/lm_train.txt \
+      --lm-archive $out_dir/lm_data.pt
+  done
+fi
+
+if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then
+  log "Stage 12: Generate LM validation data"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    log "Processing vocab_size == ${vocab_size}"
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+
+    if [ ! -f $out_dir/valid.txt ]; then
+      files=$dl_dir/xbmu_amdo31/data/transcript/dev_text
+      for f in ${files[@]}; do
+        cat $f | cut -d " " -f 2-
+      done > $out_dir/valid.txt
+    fi
+
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $lang_dir/bpe.model \
+      --lm-data $out_dir/valid.txt \
+      --lm-archive $out_dir/lm_data-valid.pt
+  done
+fi
+
+if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then
+  log "Stage 13: Generate LM test data"
+
+  for vocab_size in ${vocab_sizes[@]}; do
+    log "Processing vocab_size == ${vocab_size}"
+    out_dir=data/lm_training_bpe_${vocab_size}
+    mkdir -p $out_dir
+
+    if [ ! -f $out_dir/test.txt ]; then
+      files=$dl_dir/xbmu_amdo31/data/transcript/test_text
+      cat $files | cut -d " " -f 2- > $out_dir/test.txt
+    fi
+
+    lang_dir=data/lang_bpe_${vocab_size}
+    ./local/prepare_lm_training_data.py \
+      --bpe-model $lang_dir/bpe.model \
+      --lm-data $out_dir/test.txt \
+      --lm-archive $out_dir/lm_data-test.pt
+  done
+fi
+
+if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
+  log "Stage 14: Sort LM training data"
+  # Sort LM training data by sentence length in descending order
+  # for ease of training.
+  #
+  # Sentence length equals the number of BPE tokens
+  # in a sentence.
+ + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt + done +fi diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py new file mode 100644 index 000000000..55d5f4636 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -0,0 +1,408 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2022 Northwest Minzu University (Author: Senyan Li) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import inspect +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy +from lhotse.dataset import CutConcatenate # noqa F401 for PrecomputedFeatures +from lhotse.dataset import ( + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SingleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import AudioSamples # noqa F401 For AudioSamples +from lhotse.dataset.input_strategies import OnTheFlyFeatures +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class Xbmu_AmdoAsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. 
+ """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/fbank"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + group.add_argument( + "--concatenate-cuts", + type=str2bool, + default=False, + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", + ) + group.add_argument( + "--duration-factor", + type=float, + default=1.0, + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", + ) + group.add_argument( + "--gap", + type=float, + default=1.0, + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", + ) + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + + group.add_argument( + "--num-workers", + type=int, + default=2, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=True, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--spec-aug-time-warp-factor", + type=int, + default=80, + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", + ) + + group.add_argument( + "--enable-musan", + type=str2bool, + default=True, + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. 
", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + transforms = [] + if self.args.enable_musan: + logging.info("Enable MUSAN") + logging.info("About to get Musan cuts") + cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + transforms.append( + CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + ) + else: + logging.info("Disable MUSAN") + + if self.args.concatenate_cuts: + logging.info( + f"Using cut concatenation with duration factor " + f"{self.args.duration_factor} and gap {self.args.gap}." + ) + # Cut concatenation should be the first transform in the list, + # so that if we e.g. mix noise in, it will fill the gaps between + # different utterances. + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + input_transforms = [] + if self.args.enable_spec_aug: + logging.info("Enable SpecAugment") + logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + # Set the value of num_frame_masks according to Lhotse's version. + # In different Lhotse's versions, the default of num_frame_masks is + # different. + num_frame_masks = 10 + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] + if num_frame_masks_parameter.default == 1: + num_frame_masks = 2 + logging.info(f"Num frame mask: {num_frame_masks}") + input_transforms.append( + SpecAugment( + time_warp_factor=self.args.spec_aug_time_warp_factor, + num_frame_masks=num_frame_masks, + features_mask_size=27, + num_feature_masks=2, + frames_mask_size=100, + ) + ) + else: + logging.info("Disable SpecAugment") + + logging.info("About to create train dataset") + train = K2SpeechRecognitionDataset( + input_strategy=eval(self.args.input_strategy)(), + cut_transforms=transforms, + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + # NOTE: the PerturbSpeed transform should be added only if we + # remove it from data prep stage. + # Add on-the-fly speed perturbation; since originally it would + # have increased epoch size by 3, we will apply prob 2/3 and use + # 3x more epochs. + # Speed perturbation probably should come first before + # concatenation, but in principle the transforms order doesn't have + # to be strict (e.g. could be randomized) + # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa + # Drop feats to be on the safe side. 
+ train = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_transforms=input_transforms, + return_cuts=self.args.return_cuts, + ) + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SingleCutSampler.") + train_sampler = SingleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + transforms = [] + if self.args.concatenate_cuts: + transforms = [ + CutConcatenate( + duration_factor=self.args.duration_factor, gap=self.args.gap + ) + ] + transforms + + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + else: + validate = K2SpeechRecognitionDataset( + cut_transforms=transforms, + return_cuts=self.args.return_cuts, + ) + valid_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create dev dataloader") + valid_dl = DataLoader( + validate, + sampler=valid_sampler, + batch_size=None, + num_workers=2, + persistent_workers=False, + ) + + return valid_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = K2SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + f = self.args.manifest_dir / "xbmu_amdo31_cuts_train.jsonl.gz" + logging.info(f"About to get train cuts from {f}") + cuts_train = load_manifest_lazy(f) + return cuts_train + + @lru_cache() + def valid_cuts(self) -> CutSet: + f = self.args.manifest_dir / "xbmu_amdo31_cuts_dev.jsonl.gz" + logging.info(f"About to get valid cuts from {f}") + cuts_valid = load_manifest_lazy(f) + return cuts_valid + + @lru_cache() + def test_cuts(self) -> CutSet: + f = self.args.manifest_dir / "xbmu_amdo31_cuts_test.jsonl.gz" + logging.info(f"About to get test cuts from {f}") + cuts_test = load_manifest_lazy(f) + return cuts_test diff --git 
a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py new file mode 120000 index 000000000..c7c1a4b6e --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py new file mode 100755 index 000000000..6a67e26f8 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode.py @@ -0,0 +1,970 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao, +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method greedy_search +(2) beam search (not recommended) +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 +(3) modified beam search +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 +(4) fast beam search (one best) +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +(5) fast beam search (nbest) +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 +(7) fast beam search (with LG) +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(8) modified beam search with RNNLM shallow fusion (with LG) +./pruned_transducer_stateless5/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 4 \ + --max-contexts 4 \ + --rnn-lm-scale 0.4 \ + --rnn-lm-exp-dir /path/to/RNNLM/exp \ + --rnn-lm-epoch 99 \ + --rnn-lm-avg 1 \ + --rnn-lm-num-layers 3 \ + --rnn-lm-tie-weights 1 + + +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import Xbmu_AmdoAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, + modified_beam_search_rnnlm_shallow_fusion, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint 
to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_LG + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + - modified_beam_search_rnnlm_shallow_fusion # for rnn lm shallow fusion + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, fast_beam_search_LG, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG and fast_beam_search_LG. + It specifies the scale for n-gram LM scores. 
+ """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is fast_beam_search_LG, + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is fast_beam_search_LG, + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--rnn-lm-scale", + type=float, + default=0.0, + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is modified_beam_search_rnnlm_shallow_fusion. + It specifies the number of checkpoints to average. 
+ """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + rnnlm: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if ( + params.decoding_method == "fast_beam_search" + or params.decoding_method == "fast_beam_search_LG" + ): + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + if params.decoding_method == "fast_beam_search": + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": + hyp_tokens = modified_beam_search_rnnlm_shallow_fusion( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + sp=sp, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, + ) + for hyp in sp.decode(hyp_tokens): + 
hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, + rnnlm: Optional[RnnLmModel] = None, + rnnlm_scale: float = 1.0, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" 
+ + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + logging.info(f"Decoding {batch_idx}-th batch") + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + rnnlm=rnnlm, + rnnlm_scale=rnnlm_scale, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + Xbmu_AmdoAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_LG", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + "modified_beam_search_rnnlm_shallow_fusion", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += 
f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + params.suffix += f"-rnnlm-lm-scale-{params.rnn_lm_scale}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + 
device=device, + ) + ) + + model.to(device) + model.eval() + + rnn_lm_model = None + rnn_lm_scale = params.rnn_lm_scale + if params.decoding_method == "modified_beam_search_rnnlm_shallow_fusion": + rnn_lm_model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + assert params.rnn_lm_avg == 1 + + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + rnn_lm_model.eval() + + if "fast_beam_search" in params.decoding_method: + if "LG" in params.decoding_method: + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + xbmu_amdo = Xbmu_AmdoAsrDataModule(args) + + test_cuts = xbmu_amdo.test_cuts() + + test_dl = xbmu_amdo.test_dataloaders(test_cuts) + + test_sets = ["test"] + test_dl = [test_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + rnnlm=rnn_lm_model, + rnnlm_scale=rnn_lm_scale, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py new file mode 120000 index 000000000..d59ef95f7 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decode_stream.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/decode_stream.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py new file mode 120000 index 000000000..722e1c894 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py new file mode 120000 index 000000000..f58253127 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py new file mode 100755 index 000000000..54f656859 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/export.py @@ -0,0 +1,287 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./pruned_transducer_stateless5/export.py \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +It will generate a file exp_dir/pretrained.pt + +To use the generated file with `pruned_transducer_stateless5/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./pruned_transducer_stateless5/decode.py \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--streaming-model", + type=str2bool, + default=False, + help="""Whether to export a streaming model, if the models in exp-dir + are streaming model, this should be True. + """, + ) + + add_model_arguments(parser) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.streaming_model: + assert params.causal_convolution + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. 
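+        # convert_scaled_to_non_scaled() rewrites the Scaled* modules
+        # (e.g. ScaledLinear) into their plain torch counterparts, which
+        # torch.jit.script() can handle, while torch.jit.ignore() keeps
+        # the scripted module from trying to compile forward().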
+ convert_scaled_to_non_scaled(model, inplace=True) + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py new file mode 120000 index 000000000..9052f3cbb --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py new file mode 120000 index 000000000..b82e115fc --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py new file mode 120000 index 000000000..a99e74334 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/model.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py new file mode 120000 index 000000000..0a2f285aa --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py new file mode 100755 index 000000000..74a2210c3 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/pretrained.py @@ -0,0 +1,344 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +(1) greedy search +./pruned_transducer_stateless5/pretrained.py \ + --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless5/pretrained.py \ + --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless5/pretrained.py \ + --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless5/pretrained.py \ + --checkpoint ./pruned_transducer_stateless5/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless5/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless5/exp/pretrained.pt is generated by +./pruned_transducer_stateless5/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. 
+ Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + 
beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py new file mode 120000 index 000000000..c10cdfe12 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py new file mode 120000 index 000000000..db93d155b --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py new file mode 120000 index 000000000..1199a61d6 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py new file mode 120000 index 000000000..f29284163 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py new file mode 100755 index 000000000..9aad32014 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/test_model.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. 
(authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +To run this file, do: + + cd icefall/egs/librispeech/ASR + python ./pruned_transducer_stateless4/test_model.py +""" + +from train import get_params, get_transducer_model + + +def test_model_1(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = 24 + params.dim_feedforward = 1536 # 384 * 4 + params.encoder_dim = 384 + model = get_transducer_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf +def test_model_M(): + params = get_params() + params.vocab_size = 500 + params.blank_id = 0 + params.context_size = 2 + params.num_encoder_layers = 18 + params.dim_feedforward = 1024 + params.encoder_dim = 256 + params.nhead = 4 + params.decoder_dim = 512 + params.joiner_dim = 512 + model = get_transducer_model(params) + num_param = sum([p.numel() for p in model.parameters()]) + print(f"Number of model parameters: {num_param}") + + +def main(): + # test_model_1() + test_model_M() + + +if __name__ == "__main__": + main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py new file mode 100755 index 000000000..5b5ac17be --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless5/train.py @@ -0,0 +1,1187 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless5/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless5/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless5/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless5/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import Xbmu_AmdoAsrDataModule +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + MetricsTracker, + display_and_save_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=int, + default=24, + help="Number of conformer encoder layers..", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=1536, + help="Feedforward dimension of the conformer encoder layer.", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads in the conformer encoder layer.", + ) + + parser.add_argument( + "--encoder-dim", + type=int, + default=384, + help="Attention dimension in the conformer encoder layer.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + parser.add_argument( + "--dynamic-chunk-training", + type=str2bool, + default=False, + help="""Whether to use dynamic_chunk_training, if you want a streaming + model, this requires to be True. + """, + ) + + parser.add_argument( + "--causal-convolution", + type=str2bool, + default=False, + help="""Whether to use causal convolution, this requires to be True when + using dynamic_chunk_training. + """, + ) + + parser.add_argument( + "--short-chunk-size", + type=int, + default=25, + help="""Chunk length of dynamic training, the chunk size would be either + max sequence length of current batch or uniformly sampled from (1, short_chunk_size). 
+ """, + ) + + parser.add_argument( + "--num-left-chunks", + type=int, + default=4, + help="How many left context can be seen in chunks when calculating attention.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless5/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="The initial learning rate. This value should not need to be changed.", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. 
We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value used to penalize symbol delay, + to encourage streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. 
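+    - env_info: A dict with information about the runtime environment,
+      filled in by get_env_info().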
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + dynamic_chunk_training=params.dynamic_chunk_training, + short_chunk_size=params.short_chunk_size, + num_left_chunks=params.num_left_chunks, + causal=params.causal_convolution, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
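+    # load_checkpoint() restores the states of `model`, `model_avg`, the
+    # optimizer and the scheduler in place, and returns the remaining saved
+    # entries (e.g. best_train_loss) as a dict, which is merged back into
+    # `params` below.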
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
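+    Returns:
+      Return a tuple of two elements: the total loss of the batch (a scalar
+      tensor, summed over utterances) and a MetricsTracker with statistics
+      such as the number of frames and utterances.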
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + reduction="none", + delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, + ) + simple_loss_is_finite = torch.isfinite(simple_loss) + pruned_loss_is_finite = torch.isfinite(pruned_loss) + is_finite = simple_loss_is_finite & pruned_loss_is_finite + if not torch.all(is_finite): + logging.info( + "Not all losses are finite!\n" + f"simple_loss: {simple_loss}\n" + f"pruned_loss: {pruned_loss}" + ) + display_and_save_batch(batch, params=params, sp=sp) + simple_loss = simple_loss[simple_loss_is_finite] + pruned_loss = pruned_loss[pruned_loss_is_finite] + + # If the batch contains more than 10 utterances AND + # if either all simple_loss or pruned_loss is inf or nan, + # we stop the training process by raising an exception + if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + raise ValueError( + "There are too many utterances in this batch " + "leading to inf or nan losses." + ) + + simple_loss = simple_loss.sum() + pruned_loss = pruned_loss.sum() + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # info["frames"] is an approximate number for two reasons: + # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 + # (2) If some utterances in the batch lead to inf/nan loss, they + # are filtered out. + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. 
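+    # The per-frame loss reported during training/validation is recovered
+    # later as tot_loss["loss"] / tot_loss["frames"]; see train_one_epoch()
+    # and compute_validation_loss().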
+ info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. 
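+            # scaler.scale() multiplies the loss by the current loss scale
+            # before backward(); scaler.step() unscales the gradients and
+            # skips the optimizer update if any gradient overflowed, and
+            # scaler.update() adjusts the scale for the next iteration.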
+ scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.dynamic_chunk_training: + assert ( + params.causal_convolution + ), "dynamic_chunk_training requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + xbmu_amdo = Xbmu_AmdoAsrDataModule(args) + + train_cuts = xbmu_amdo.train_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./conformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 1) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. 
" + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = xbmu_amdo.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = xbmu_amdo.valid_cuts() + valid_dl = xbmu_amdo.valid_dataloaders(valid_cuts) + + if params.start_batch <= 0 and not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + warmup: float, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=warmup, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + display_and_save_batch(batch, params=params, sp=sp) + raise + + +def main(): + parser = get_parser() + Xbmu_AmdoAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py new file mode 120000 index 000000000..c473a600a --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless5/asr_datamodule.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py new file mode 100755 index 000000000..ace792e13 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decode.py @@ -0,0 +1,843 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +(1) greedy search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless7/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import Xbmu_AmdoAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=9, + help="Number of checkpoints to average. 
Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--simulate-streaming", + type=str2bool, + default=False, + help="""Whether to simulate streaming in decoding, this is a good way to + test a streaming model. + """, + ) + + parser.add_argument( + "--decode-chunk-size", + type=int, + default=16, + help="The chunk size for decoding (in frames after subsampling)", + ) + + parser.add_argument( + "--left-context", + type=int, + default=64, + help="left context can be seen during decoding (in frames after subsampling)", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. 
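+
+      Illustrative example of the returned value (hypothetical hypotheses,
+      shown only to make the format concrete): with greedy search and a
+      batch of two utterances the dict could be
+        {"greedy_search": [["HELLO", "WORLD"], ["GOOD", "MORNING"]]}
+      where each inner list is the decoded word sequence of one utterance.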
+ """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.simulate_streaming: + feature_lens += params.left_context + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, params.left_context), + value=LOG_EPS, + ) + encoder_out, encoder_out_lens, _ = model.encoder.streaming_forward( + x=feature, + x_lens=feature_lens, + chunk_size=params.decode_chunk_size, + left_context=params.left_context, + simulate_streaming=True, + ) + else: + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: 
{params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
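+ # write_error_stats returns the WER for this decoding setting; the WERs of
+ # all settings are collected in test_set_wers and sorted below so that the
+ # best-performing setting is reported first in the summary file and the log.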
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + Xbmu_AmdoAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.simulate_streaming: + params.suffix += f"-streaming-chunk-size-{params.decode_chunk_size}" + params.suffix += f"-left-context-{params.left_context}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + if params.simulate_streaming: + assert ( + params.causal_convolution + ), "Decoding in streaming requires causal convolution" + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" 
+ f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ + : params.avg + 1 + ] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
+ args.return_cuts = True + xbmu_amdo = Xbmu_AmdoAsrDataModule(args) + + test_cuts = xbmu_amdo.test_cuts() + + test_dl = xbmu_amdo.test_dataloaders(test_cuts) + + test_sets = [ + "test", + ] + test_dl = [ + test_dl, + ] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py new file mode 120000 index 000000000..8283d8c5a --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py new file mode 120000 index 000000000..f58253127 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py new file mode 120000 index 000000000..2713792e6 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/export.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/export.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py new file mode 120000 index 000000000..a44034e34 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py new file mode 120000 index 000000000..0f0c3c90a --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py new file mode 120000 index 000000000..0d8bc665b --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/model.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py new file mode 120000 index 000000000..8a05abb5f --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/optim.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py new file mode 100755 index 000000000..d05bafcfb --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/pretrained.py @@ -0,0 +1,355 @@ +#!/usr/bin/env python3 +# Copyright 
2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads a checkpoint and uses it to decode waves. +You can generate the checkpoint with the following command: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +(1) greedy search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless7/pretrained.py \ + --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./pruned_transducer_stateless7/exp/epoch-xx.pt`. + +Note: ./pruned_transducer_stateless7/exp/pretrained.pt is generated by +./pruned_transducer_stateless7/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. 
Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in 
zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py new file mode 120000 index 000000000..5f9be9fe0 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py new file mode 120000 index 000000000..f9960e5c6 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py new file mode 120000 index 000000000..7ceac5d10 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/test_model.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless7/test_model.py \ No newline at end of file diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py new file mode 100755 index 000000000..1332bafd8 --- /dev/null +++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/train.py @@ -0,0 +1,1224 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./pruned_transducer_stateless7/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless7/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import Xbmu_AmdoAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from zipformer import Zipformer + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,4,3,2,4", + help="Number of zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--feedforward-dims", + type=str, + default="1024,1024,2048,2048,1024", + help="Feedforward dimension of the zipformer encoder layers, comma separated.", + ) + + parser.add_argument( + "--nhead", + type=str, + default="8,8,8,8,8", + help="Number of attention heads in the zipformer encoder layers.", + ) + + parser.add_argument( + "--encoder-dims", + type=str, + default="384,384,384,384,384", + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", + ) + + parser.add_argument( + "--attention-dims", + type=str, + default="192,192,192,192,192", + help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; + not the same as embedding dimension.""", + ) + + parser.add_argument( + "--encoder-unmasked-dims", + type=str, + default="256,256,256,256,256", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. 
Empirically, less than 256 seems to make performance " + " worse.", + ) + + parser.add_argument( + "--zipformer-downsampling-factors", + type=str, + default="1,2,4,8,2", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--cnn-module-kernels", + type=str, + default="31,31,31,31,31", + help="Sizes of kernels in convolution modules", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless7/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=2000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. 
+ + - warm_step: The warmup period that dictates the decay of the + scale on "simple" (un-pruned) loss. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for zipformer + "feature_dim": 80, + "subsampling_factor": 4, # not passed in, this is fixed. + "warm_step": 2000, + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Zipformer and Transformer + def to_int_tuple(s: str): + return tuple(map(int, s.split(","))) + + encoder = Zipformer( + num_features=params.feature_dim, + output_downsampling_factor=2, + zipformer_downsampling_factors=to_int_tuple( + params.zipformer_downsampling_factors + ), + encoder_dims=to_int_tuple(params.encoder_dims), + attention_dim=to_int_tuple(params.attention_dims), + encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), + nhead=to_int_tuple(params.nhead), + feedforward_dim=to_int_tuple(params.feedforward_dims), + cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), + num_encoder_layers=to_int_tuple(params.num_encoder_layers), + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(params.encoder_dims.split(",")[-1]), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. 
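+
+      For illustration, judging from the keys read back below, the returned
+      dict typically contains "best_train_loss", "best_valid_loss",
+      "best_train_epoch", "best_valid_epoch" and "batch_idx_train", plus
+      "cur_epoch"/"cur_batch_idx" for checkpoints saved in the middle of an
+      epoch, and optimizer/scheduler/sampler/grad-scaler state when present.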
+ """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s + if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 + if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. 
+ world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + set_batch_count(model, params.batch_idx_train) + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
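+ # Illustrative numbers: a scale that has collapsed to 0.25 is doubled on
+ # every 100-batch check until it climbs back above 1.0; between 1.0 and 8.0
+ # it is only doubled every 400 batches. A scale below 0.01 triggers a
+ # warning, and below 1e-5 training is aborted.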
+ cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+ # <blk> is defined in local/train_bpe_model.py
+ params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0)
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2**22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ xbmu_amdo = Xbmu_AmdoAsrDataModule(args)
+
+ train_cuts = xbmu_amdo.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ T = ((c.num_frames - 7) // 2 + 1) // 2
+ tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ if T < len(tokens):
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. "
+ f"Number of frames (before subsampling): {c.num_frames}. "
+ f"Number of frames (after subsampling): {T}. "
+ f"Text: {c.supervisions[0].text}. "
+ f"Tokens: {tokens}. 
" + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = xbmu_amdo.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = xbmu_amdo.valid_cuts() + valid_dl = xbmu_amdo.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." 
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ Xbmu_AmdoAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py
new file mode 120000
index 000000000..f2f66041e
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/pruned_transducer_stateless7/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/zipformer.py
\ No newline at end of file
diff --git a/egs/xbmu_amdo31/ASR/shared b/egs/xbmu_amdo31/ASR/shared
new file mode 120000
index 000000000..4c5e91438
--- /dev/null
+++ b/egs/xbmu_amdo31/ASR/shared
@@ -0,0 +1 @@
+../../../icefall/shared/
\ No newline at end of file

From c25c8c6ad18b8a3d5de2f093947f7b2293eec35a Mon Sep 17 00:00:00 2001
From: Wei Kang
Date: Sun, 4 Dec 2022 17:20:17 +0800
Subject: [PATCH 19/27] Add need_repeat_flag in phone based ctc graph compiler
 (#727)

* Fix is_repeat_token in icefall

* Fix phone based recipe

* Update egs/librispeech/ASR/conformer_ctc3/train.py

Co-authored-by: Fangjun Kuang

* Fix black

Co-authored-by: Fangjun Kuang
---
 egs/librispeech/ASR/conformer_ctc3/train.py | 1 +
 icefall/graph_compiler.py | 18 ++++++++++++++----
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc3/train.py b/egs/librispeech/ASR/conformer_ctc3/train.py
index fb3b740c1..ac489af9e 100755
--- a/egs/librispeech/ASR/conformer_ctc3/train.py
+++ b/egs/librispeech/ASR/conformer_ctc3/train.py
@@ -890,6 +890,7 @@ def run(rank, world_size, args):
 graph_compiler = CtcTrainingGraphCompiler(
 lexicon,
 device=device,
+ need_repeat_flag=params.delay_penalty > 0,
 )
 # Manually add the sos/eos ID with their default values
 # from the BPE recipe which we're adapting here.
diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py
index 0dcd777ad..d26ddbbd1 100644
--- a/icefall/graph_compiler.py
+++ b/icefall/graph_compiler.py
@@ -29,6 +29,7 @@ class CtcTrainingGraphCompiler(object):
 lexicon: Lexicon,
 device: torch.device,
 oov: str = "<UNK>",
+ need_repeat_flag: bool = False,
 ):
 """
 Args:
@@ -39,6 +40,13 @@ class CtcTrainingGraphCompiler(object):
 oov:
 Out of vocabulary word. When a word in the transcript does not
 exist in the lexicon, it is replaced with `oov`.
+ need_repeat_flag: + If True, will add an attribute named `_is_repeat_token_` to ctc_topo + indicating whether this token is a repeat token in ctc graph. + This attribute is needed to implement delay-penalty for phone-based + ctc loss. See https://github.com/k2-fsa/k2/pull/1086 for more + details. Note: The above change MUST be included in k2 to open this + flag. """ L_inv = lexicon.L_inv.to(device) assert L_inv.requires_grad is False @@ -53,6 +61,12 @@ class CtcTrainingGraphCompiler(object): ctc_topo = k2.ctc_topo(max_token_id, modified=False) self.ctc_topo = ctc_topo.to(device) + + if need_repeat_flag: + self.ctc_topo._is_repeat_token_ = ( + self.ctc_topo.labels != self.ctc_topo.aux_labels + ) + self.device = device def compile(self, texts: List[str]) -> k2.Fsa: @@ -79,10 +93,6 @@ class CtcTrainingGraphCompiler(object): fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops) - self.ctc_topo._is_repeat_token_ = ( - self.ctc_topo.labels != self.ctc_topo.aux_labels - ).int() - decoding_graph = k2.compose( self.ctc_topo, fsa_with_self_loops, treat_epsilons_specially=False ) From bd7fa2253dab9f627edc914b3289fb2f6c0e5bb6 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Sun, 4 Dec 2022 20:27:45 +0800 Subject: [PATCH 20/27] Update the manifest statistics of the L subset of wenetspeech (#731) --- .../ASR/local/display_manifest_statistics.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/egs/wenetspeech/ASR/local/display_manifest_statistics.py b/egs/wenetspeech/ASR/local/display_manifest_statistics.py index c41445b8d..36e4ac5c3 100644 --- a/egs/wenetspeech/ASR/local/display_manifest_statistics.py +++ b/egs/wenetspeech/ASR/local/display_manifest_statistics.py @@ -33,6 +33,7 @@ def main(): paths = [ "./data/fbank/cuts_S.jsonl.gz", "./data/fbank/cuts_M.jsonl.gz", + "./data/fbank/cuts_L.jsonl.gz", "./data/fbank/cuts_DEV.jsonl.gz", "./data/fbank/cuts_TEST_NET.jsonl.gz", "./data/fbank/cuts_TEST_MEETING.jsonl.gz", @@ -48,6 +49,24 @@ if __name__ == "__main__": main() """ +Starting display the statistics for ./data/fbank/cuts_L.jsonl.gz + +Cuts count: 43874235 +Total duration (hours): 30217.3 +Speech duration (hours): 30217.3 (100.0%) +*** +Duration statistics (seconds): +mean 2.5 +std 1.7 +min 0.2 +25% 1.4 +50% 2.0 +75% 3.0 +99% 8.4 +99.5% 9.1 +99.9% 15.4 +max 405.1 + Starting display the statistics for ./data/fbank/cuts_S.jsonl.gz Duration statistics (seconds): mean 2.4 From be6e08f69a9384de27c28115a299d4fe64bb5de1 Mon Sep 17 00:00:00 2001 From: Cesc Date: Mon, 5 Dec 2022 23:35:10 +0800 Subject: [PATCH 21/27] fix wenet stateless5 jit export error (#735) --- egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py | 2 ++ egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py | 1 + .../ASR/pruned_transducer_stateless5/scaling_converter.py | 1 + 3 files changed, 4 insertions(+) mode change 100644 => 100755 egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py create mode 120000 egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py create mode 120000 egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py old mode 100644 new mode 100755 index 35577c327..cb541070e --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py @@ -74,6 +74,7 @@ import logging from pathlib import Path import torch +from scaling_converter import convert_scaled_to_non_scaled from 
train import add_model_arguments, get_params, get_transducer_model from icefall.checkpoint import average_checkpoints, load_checkpoint @@ -184,6 +185,7 @@ def main(): # it here. # Otherwise, one of its arguments is a ragged tensor and is not # torch scriptabe. + convert_scaled_to_non_scaled(model, inplace=True) model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py new file mode 120000 index 000000000..d13a1e063 --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/lstmp.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/lstmp.py \ No newline at end of file diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py new file mode 120000 index 000000000..e58473a04 --- /dev/null +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/scaling_converter.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/scaling_converter.py \ No newline at end of file From f13cf61b05432a989e6a42c95b843a56639bcbde Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 6 Dec 2022 16:34:27 +0800 Subject: [PATCH 22/27] Convert conv-emformer to ncnn (#717) * Export conv-emformer via torch.jit.trace() --- ...former-transducer-stateless2-2022-12-05.sh | 79 + ...-lstm-transducer-stateless2-2022-09-03.sh} | 0 ...ormer-transducer-stateless2-2022-12-05.yml | 77 + ...-lstm-transducer-stateless2-2022-09-03.yml | 2 +- .../emformer2.py | 1798 +++++++++++++++++ .../export-for-ncnn.py | 335 +++ .../jit_pretrained.py | 292 +++ .../lstmp.py | 1 + .../scaling_converter.py | 1 + .../streaming-ncnn-decode.py | 387 ++++ .../train2.py | 1128 +++++++++++ 11 files changed, 4099 insertions(+), 1 deletion(-) create mode 100755 .github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh rename .github/scripts/{run-librispeech-lstm-transducer-stateless2-2022-09-03.yml => run-librispeech-lstm-transducer-stateless2-2022-09-03.sh} (100%) create mode 100644 .github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml create mode 100644 egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py create mode 120000 egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py create mode 120000 egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py create mode 100755 egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py diff --git a/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh b/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh new file mode 100755 index 000000000..32c939206 --- /dev/null +++ b/.github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh @@ -0,0 +1,79 @@ +#!/usr/bin/env bash +# +set -e + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/librispeech/ASR + 
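+# GIT_LFS_SKIP_SMUDGE=1 below makes `git clone` skip downloading LFS
+# files; the subsequent `git lfs pull --include ...` commands then fetch
+# only the checkpoint and the BPE model that this test needs.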
+repo_url=https://huggingface.co/Zengwei/icefall-asr-librispeech-conv-emformer-transducer-stateless2-2022-07-05 + +log "Downloading pre-trained model from $repo_url" +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +repo=$(basename $repo_url) +pushd $repo +git lfs pull --include "exp/pretrained-epoch-30-avg-10-averaged.pt" +git lfs pull --include "data/lang_bpe_500/bpe.model" +cd exp +ln -s pretrained-epoch-30-avg-10-averaged.pt epoch-99.pt +popd + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +log "Install ncnn and pnnx" + +# We are using a modified ncnn here. Will try to merge it to the official repo +# of ncnn +git clone https://github.com/csukuangfj/ncnn +pushd ncnn +git submodule init +git submodule update python/pybind11 +python3 setup.py bdist_wheel +ls -lh dist/ +pip install dist/*.whl +cd tools/pnnx +mkdir build +cd build +cmake -D Python3_EXECUTABLE=/opt/hostedtoolcache/Python/3.8.14/x64/bin/python3 .. +make -j4 pnnx + +./src/pnnx || echo "pass" + +popd + +log "Test exporting to pnnx format" + +./conv_emformer_transducer_stateless2/export-for-ncnn.py \ + --exp-dir $repo/exp \ + --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --epoch 99 \ + --avg 1 \ + --use-averaged-model 0 \ + \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/encoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/decoder_jit_trace-pnnx.pt +./ncnn/tools/pnnx/build/src/pnnx $repo/exp/joiner_jit_trace-pnnx.pt + +./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + --encoder-param-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.param \ + --encoder-bin-filename $repo/exp/encoder_jit_trace-pnnx.ncnn.bin \ + --decoder-param-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.param \ + --decoder-bin-filename $repo/exp/decoder_jit_trace-pnnx.ncnn.bin \ + --joiner-param-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.param \ + --joiner-bin-filename $repo/exp/joiner_jit_trace-pnnx.ncnn.bin \ + $repo/test_wavs/1089-134686-0001.wav diff --git a/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh similarity index 100% rename from .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml rename to .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh diff --git a/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml b/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml new file mode 100644 index 000000000..b9a1582c4 --- /dev/null +++ b/.github/workflows/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.yml @@ -0,0 +1,77 @@ +name: run-librispeech-conv-emformer-transducer-stateless2-2022-12-05 + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_librispeech_conv_emformer_transducer_stateless2_2022_12_05: + if: github.event.label.name == 'ready' || github.event.label.name == 'ncnn' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python-version: [3.8] 
+ + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | grep -v kaldifst | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }}-2022-09-25 + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + mkdir -p egs/librispeech/ASR/data + ln -sfv ~/tmp/fbank-libri egs/librispeech/ASR/data/fbank + ls -lh egs/librispeech/ASR/data/* + + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-librispeech-conv-emformer-transducer-stateless2-2022-12-05.sh diff --git a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml index 59f116fde..f5ee09e16 100644 --- a/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml +++ b/.github/workflows/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml @@ -111,7 +111,7 @@ jobs: export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH - .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.yml + .github/scripts/run-librispeech-lstm-transducer-stateless2-2022-09-03.sh - name: Display decoding results for lstm_transducer_stateless2 if: github.event_name == 'schedule' diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py new file mode 100644 index 000000000..65a7efa77 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer2.py @@ -0,0 +1,1798 @@ +# Copyright 2022 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+#
+# It is modified based on
+# 1) https://github.com/pytorch/audio/blob/main/torchaudio/models/emformer.py # noqa
+# 2) https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py # noqa
+
+import math
+from typing import List, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+from scaling import (
+ ActivationBalancer,
+ BasicNorm,
+ DoubleSwish,
+ ScaledConv1d,
+ ScaledConv2d,
+ ScaledLinear,
+)
+
+from icefall.utils import make_pad_mask
+
+LOG_EPSILON = math.log(1e-10)
+
+
+def unstack_states(
+ states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]]
+) -> List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]:
+ """Unstack the emformer state corresponding to a batch of utterances
+ into a list of states, where the i-th entry is the state from the i-th
+ utterance in the batch.
+
+ Args:
+ states:
+ A tuple of 2 elements.
+ ``states[0]`` is the attention caches of a batch of utterances.
+ ``states[1]`` is the convolution caches of a batch of utterances.
+ ``len(states[0])`` and ``len(states[1])`` both equal the number of layers. # noqa
+
+ Returns:
+ A list of states.
+ ``states[i]`` is a tuple of 2 elements of i-th utterance.
+ ``states[i][0]`` is the attention caches of i-th utterance.
+ ``states[i][1]`` is the convolution caches of i-th utterance.
+ ``len(states[i][0])`` and ``len(states[i][1])`` both equal the number of layers. # noqa
+ """
+
+ attn_caches, conv_caches = states
+ batch_size = conv_caches[0].size(0)
+ num_layers = len(attn_caches)
+
+ list_attn_caches = [None] * batch_size
+ for i in range(batch_size):
+ list_attn_caches[i] = [[] for _ in range(num_layers)]
+ for li, layer in enumerate(attn_caches):
+ for s in layer:
+ s_list = s.unbind(dim=1)
+ for bi, b in enumerate(list_attn_caches):
+ b[li].append(s_list[bi])
+
+ list_conv_caches = [None] * batch_size
+ for i in range(batch_size):
+ list_conv_caches[i] = [None] * num_layers
+ for li, layer in enumerate(conv_caches):
+ c_list = layer.unbind(dim=0)
+ for bi, b in enumerate(list_conv_caches):
+ b[li] = c_list[bi]
+
+ ans = [None] * batch_size
+ for i in range(batch_size):
+ ans[i] = [list_attn_caches[i], list_conv_caches[i]]
+
+ return ans
+
+
+def stack_states(
+ state_list: List[Tuple[List[List[torch.Tensor]], List[torch.Tensor]]]
+) -> Tuple[List[List[torch.Tensor]], List[torch.Tensor]]:
+ """Stack list of emformer states that correspond to separate utterances
+ into a single emformer state so that it can be used as an input for
+ emformer when those utterances are formed into a batch.
+
+ Note:
+ It is the inverse of :func:`unstack_states`.
+
+ Args:
+ state_list:
+ Each element in state_list corresponds to the internal state
+ of the emformer model for a single utterance.
+ ``states[i]`` is a tuple of 2 elements of i-th utterance.
+ ``states[i][0]`` is the attention caches of i-th utterance.
+ ``states[i][1]`` is the convolution caches of i-th utterance.
+ ``len(states[i][0])`` and ``len(states[i][1])`` both equal the number of layers. # noqa
+
+ Returns:
+ A new state corresponding to a batch of utterances.
+ See the input argument of :func:`unstack_states` for the meaning
+ of the returned tensor.
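+
+ Example (``state_a`` and ``state_b`` are hypothetical states of two
+ single-utterance streams)::
+
+ batched = stack_states([state_a, state_b])
+ state_a2, state_b2 = unstack_states(batched)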
+ """ + batch_size = len(state_list) + + attn_caches = [] + for layer in state_list[0][0]: + if batch_size > 1: + # Note: We will stack attn_caches[layer][s][] later to get attn_caches[layer][s] # noqa + attn_caches.append([[s] for s in layer]) + else: + attn_caches.append([s.unsqueeze(1) for s in layer]) + for b, states in enumerate(state_list[1:], 1): + for li, layer in enumerate(states[0]): + for si, s in enumerate(layer): + attn_caches[li][si].append(s) + if b == batch_size - 1: + attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) + + conv_caches = [] + for layer in state_list[0][1]: + if batch_size > 1: + # Note: We will stack conv_caches[layer][] later to get conv_caches[layer] # noqa + conv_caches.append([layer]) + else: + conv_caches.append(layer.unsqueeze(0)) + for b, states in enumerate(state_list[1:], 1): + for li, layer in enumerate(states[1]): + conv_caches[li].append(layer) + if b == batch_size - 1: + conv_caches[li] = torch.stack(conv_caches[li], dim=0) + + return [attn_caches, conv_caches] + + +class ConvolutionModule(nn.Module): + """ConvolutionModule. + + Modified from https://github.com/pytorch/audio/blob/main/torchaudio/prototype/models/conv_emformer.py # noqa + + Args: + chunk_length (int): + Length of each chunk. + right_context_length (int): + Length of right context. + channels (int): + The number of input channels and output channels of conv layers. + kernel_size (int): + Kernerl size of conv layers. + bias (bool): + Whether to use bias in conv layers (default=True). + """ + + def __init__( + self, + chunk_length: int, + right_context_length: int, + channels: int, + kernel_size: int, + bias: bool = True, + ) -> None: + """Construct an ConvolutionModule object.""" + super().__init__() + # kernerl_size should be an odd number for 'SAME' padding + assert (kernel_size - 1) % 2 == 0, kernel_size + + self.chunk_length = chunk_length + self.right_context_length = right_context_length + self.channels = channels + + self.pointwise_conv1 = ScaledConv1d( + channels, + 2 * channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + ) + # After pointwise_conv1 we put x through a gated linear unit + # (nn.functional.glu). + # For most layers the normal rms value of channels of x seems to be in + # the range 1 to 4, but sometimes, for some reason, for layer 0 the rms + # ends up being very large, between 50 and 100 for different channels. + # This will cause very peaky and sparse derivatives for the sigmoid + # gating function, which will tend to make the loss function not learn + # effectively. (for most layers the average absolute values are in the + # range 0.5..9.0, and the average p(x>0), i.e. positive proportion, + # at the output of pointwise_conv1.output is around 0.35 to 0.45 for + # different layers, which likely breaks down as 0.5 for the "linear" + # half and 0.2 to 0.3 for the part that goes into the sigmoid. + # The idea is that if we constrain the rms values to a reasonable range + # via a constraint of max_abs=10.0, it will be in a better position to + # start learning something, i.e. to latch onto the correct range. 
+ self.deriv_balancer1 = ActivationBalancer( + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 + ) + + # make it causal by padding cached (kernel_size - 1) frames on the left + self.cache_size = kernel_size - 1 + self.depthwise_conv = ScaledConv1d( + channels, + channels, + kernel_size, + stride=1, + padding=0, + groups=channels, + bias=bias, + ) + + self.deriv_balancer2 = ActivationBalancer( + channel_dim=1, min_positive=0.05, max_positive=1.0 + ) + + self.activation = DoubleSwish() + + self.pointwise_conv2 = ScaledConv1d( + channels, + channels, + kernel_size=1, + stride=1, + padding=0, + bias=bias, + initial_scale=0.25, + ) + + def _split_right_context( + self, + pad_utterance: torch.Tensor, + right_context: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + pad_utterance: + Its shape is (cache_size + U, B, D). + right_context: + Its shape is (R, B, D). + + Returns: + Right context segments padding with corresponding context. + Its shape is (num_segs * B, D, cache_size + right_context_length). + """ + U_, B, D = pad_utterance.size() + R = right_context.size(0) + assert self.right_context_length != 0 + assert R % self.right_context_length == 0 + num_chunks = R // self.right_context_length + right_context = right_context.reshape( + num_chunks, self.right_context_length, B, D + ) + right_context = right_context.permute(0, 2, 1, 3).reshape( + num_chunks * B, self.right_context_length, D + ) + + intervals = torch.arange( + 0, self.chunk_length * (num_chunks - 1), self.chunk_length + ) + first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) + indexes = intervals.unsqueeze(1) + first.unsqueeze(0) + indexes = torch.cat( + [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] + ) + padding = pad_utterance[indexes] # (num_chunks, cache_size, B, D) + padding = padding.permute(0, 2, 1, 3).reshape( + num_chunks * B, self.cache_size, D + ) + + pad_right_context = torch.cat([padding, right_context], dim=1) + # (num_chunks * B, cache_size + right_context_length, D) + return pad_right_context.permute(0, 2, 1) + + def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: + """ + Args: + right_context: + Right context segments. + It shape is (num_segs * B, D, right_context_length). + B: + Batch size. + + Returns: + A tensor of shape (B, D, R), where + R = num_segs * right_context_length. + """ + right_context = right_context.reshape( + -1, B, self.channels, self.right_context_length + ) + right_context = right_context.permute(1, 2, 0, 3) + right_context = right_context.reshape(B, self.channels, -1) + return right_context + + def forward( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Causal convolution module. + + Args: + utterance (torch.Tensor): + Utterance tensor of shape (U, B, D). + right_context (torch.Tensor): + Right context tensor of shape (R, B, D). + + Returns: + A tuple of 2 tensors: + - output utterance of shape (U, B, D). + - output right_context of shape (R, B, D). 
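+
+ For example, with chunk_length = 32, right_context_length = 8 and
+ channels = 512, a one-chunk batch of a single utterance has
+ utterance of shape (32, 1, 512) and right_context of shape
+ (8, 1, 512); both keep their shapes on return.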
+ """ + U, B, D = utterance.size() + R, _, _ = right_context.size() + + # point-wise conv and GLU mechanism + x = torch.cat([right_context, utterance], dim=0) # (R + U, B, D) + x = x.permute(1, 2, 0) # (B, D, R + U) + x = self.pointwise_conv1(x) # (B, 2 * D, R + U) + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (B, D, R + U) + utterance = x[:, :, R:] # (B, D, U) + right_context = x[:, :, :R] # (B, D, R) + + # make causal convolution + cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) + pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) + + # depth-wise conv on utterance + utterance = self.depthwise_conv(pad_utterance) # (B, D, U) + + if self.right_context_length > 0: + # depth-wise conv on right_context + pad_right_context = self._split_right_context( + pad_utterance.permute(2, 0, 1), right_context.permute(2, 0, 1) + ) # (num_segs * B, D, cache_size + right_context_length) + right_context = self.depthwise_conv( + pad_right_context + ) # (num_segs * B, D, right_context_length) + right_context = self._merge_right_context(right_context, B) # (B, D, R) + + x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) + x = self.deriv_balancer2(x) + x = self.activation(x) + + # point-wise conv + x = self.pointwise_conv2(x) # (B, D, R + U) + + right_context = x[:, :, :R] # (B, D, R) + utterance = x[:, :, R:] # (B, D, U) + return ( + utterance.permute(2, 0, 1), + right_context.permute(2, 0, 1), + ) + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + cache: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Causal convolution module applied on both utterance and right_context. + + Args: + utterance (torch.Tensor): + Utterance tensor of shape (U, B, D). + right_context (torch.Tensor): + Right context tensor of shape (R, B, D). + cache (torch.Tensor, optional): + Cached tensor for left padding of shape (B, D, cache_size). + + Returns: + A tuple of 3 tensors: + - output utterance of shape (U, B, D). + - output right_context of shape (R, B, D). + - updated cache tensor of shape (B, D, cache_size). + """ + # U, B, D = utterance.size() + # R, _, _ = right_context.size() + U = self.chunk_length + B = 1 + D = self.channels + R = self.right_context_length + + # point-wise conv + x = torch.cat([utterance, right_context], dim=0) # (U + R, B, D) + x = x.permute(1, 2, 0) # (B, D, U + R) + x = self.pointwise_conv1(x) # (B, 2 * D, U + R) + x = self.deriv_balancer1(x) + x = nn.functional.glu(x, dim=1) # (B, D, U + R) + + # make causal convolution + assert cache.shape == (B, D, self.cache_size), cache.shape + x = torch.cat([cache, x], dim=2) # (B, D, cache_size + U + R) + # update cache + new_cache = x[:, :, -R - self.cache_size : -R] + + # 1-D depth-wise conv + x = self.depthwise_conv(x) # (B, D, U + R) + + x = self.deriv_balancer2(x) + x = self.activation(x) + + # point-wise conv + x = self.pointwise_conv2(x) # (B, D, U + R) + + utterance = x[:, :, :U] # (B, D, U) + right_context = x[:, :, U:] # (B, D, R) + return ( + utterance.permute(2, 0, 1), + right_context.permute(2, 0, 1), + new_cache, + ) + + +class EmformerAttention(nn.Module): + r"""Emformer layer attention module. + + Args: + embed_dim (int): + Embedding dimension. + nhead (int): + Number of attention heads in each Emformer layer. + dropout (float, optional): + Dropout probability. (Default: 0.0) + tanh_on_mem (bool, optional): + If ``True``, applies tanh to memory elements. 
(Default: ``False``)
+ negative_inf (float, optional):
+ Value to use for negative infinity in attention weights. (Default: -1e8)
+ """
+
+ def __init__(
+ self,
+ embed_dim: int,
+ nhead: int,
+ left_context_length: int,
+ chunk_length: int,
+ right_context_length: int,
+ memory_size: int,
+ dropout: float = 0.0,
+ tanh_on_mem: bool = False,
+ negative_inf: float = -1e8,
+ ):
+ super().__init__()
+
+ if embed_dim % nhead != 0:
+ raise ValueError(
+ f"embed_dim ({embed_dim}) is not a multiple of nhead ({nhead})."
+ )
+
+ self.embed_dim = embed_dim
+ self.nhead = nhead
+ self.tanh_on_mem = tanh_on_mem
+ self.negative_inf = negative_inf
+ self.head_dim = embed_dim // nhead
+ self.dropout = dropout
+
+ self.left_context_length = left_context_length
+ self.right_context_length = right_context_length
+ self.chunk_length = chunk_length
+ self.memory_size = memory_size
+
+ self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True)
+ self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True)
+ self.out_proj = ScaledLinear(
+ embed_dim, embed_dim, bias=True, initial_scale=0.25
+ )
+
+ def _gen_attention_probs(
+ self,
+ attention_weights: torch.Tensor,
+ attention_mask: torch.Tensor,
+ padding_mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Given the entire attention weights, mask out unnecessary connections
+ and, optionally, padding positions, to obtain the underlying chunk-wise
+ attention probabilities.
+
+ B: batch size;
+ Q: length of query;
+ KV: length of key and value.
+
+ Args:
+ attention_weights (torch.Tensor):
+ Attention weights computed on the entire concatenated tensor
+ with shape (B * nhead, Q, KV).
+ attention_mask (torch.Tensor):
+ Mask tensor where chunk-wise connections are filled with `False`,
+ and other unnecessary connections are filled with `True`,
+ with shape (Q, KV).
+ padding_mask (torch.Tensor, optional):
+ Mask tensor where the padding positions are filled with `True`,
+ and other positions are filled with `False`, with shape `(B, KV)`.
+
+ Returns:
+ A tensor of shape (B * nhead, Q, KV).
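+ After the softmax, each row of each head's (Q, KV) matrix sums
+ to 1, with masked positions receiving (numerically) zero
+ probability.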
+ """ + attention_weights_float = attention_weights.float() + attention_weights_float = attention_weights_float.masked_fill( + attention_mask.unsqueeze(0), self.negative_inf + ) + if padding_mask is not None: + Q = attention_weights.size(1) + B = attention_weights.size(0) // self.nhead + attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) + attention_weights_float = attention_weights_float.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), + self.negative_inf, + ) + attention_weights_float = attention_weights_float.view( + B * self.nhead, Q, -1 + ) + + attention_probs = nn.functional.softmax( + attention_weights_float, dim=-1 + ).type_as(attention_weights) + + attention_probs = nn.functional.dropout( + attention_probs, p=self.dropout, training=self.training + ) + return attention_probs + + def _forward_impl( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + memory: torch.Tensor, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + left_context_key: Optional[torch.Tensor] = None, + left_context_val: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Underlying chunk-wise attention implementation.""" + # U, B, _ = utterance.size() + # R = right_context.size(0) + # M = memory.size(0) + + U = self.chunk_length + B = 1 + R = self.right_context_length + M = self.memory_size + L = self.left_context_length + + scaling = float(self.head_dim) ** -0.5 + + # compute query with [right_context, utterance]. + query = self.emb_to_query(torch.cat([right_context, utterance])) + # compute key and value with [memory, right_context, utterance]. + key, value = self.emb_to_key_value( + torch.cat([memory, right_context, utterance]) + ).chunk(chunks=2, dim=2) + + if left_context_key is not None and left_context_val is not None: + # now compute key and value with + # [memory, right context, left context, uttrance] + # this is used in inference mode + key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) + value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) + + # Q = query.size(0) + Q = U + R + + # KV = key.size(0) + + reshaped_query = query.view(Q, self.nhead, self.head_dim).permute(1, 0, 2) + reshaped_key = key.view(M + R + U + L, self.nhead, self.head_dim).permute( + 1, 0, 2 + ) + reshaped_value = value.view(M + R + U + L, self.nhead, self.head_dim).permute( + 1, 0, 2 + ) + + # reshaped_query, reshaped_key, reshaped_value = [ + # tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) + # for tensor in [query, key, value] + # ] # (B * nhead, Q or KV, head_dim) + attention_weights = torch.bmm( + reshaped_query * scaling, reshaped_key.permute(0, 2, 1) + ) # (B * nhead, Q, KV) + + # compute attention probabilities + if False: + attention_probs = self._gen_attention_probs( + attention_weights, attention_mask, padding_mask + ) + else: + attention_probs = nn.functional.softmax(attention_weights, dim=-1) + + # compute attention outputs + attention = torch.bmm(attention_probs, reshaped_value) + assert attention.shape == (B * self.nhead, Q, self.head_dim) + attention = attention.permute(1, 0, 2).reshape(-1, self.embed_dim) + # TODO(fangjun): ncnn does not support reshape(-1, 1, self.embed_dim) + # We have to change InnerProduct in ncnn to ignore the extra dim below + attention = attention.unsqueeze(1) + + # apply output projection + output_right_context_utterance = self.out_proj(attention) + # The return shape of output_right_context_utterance is 
(10, 1, 512) + + return output_right_context_utterance, key, value + + def forward( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + memory: torch.Tensor, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # TODO: Modify docs. + """Forward pass for training and validation mode. + + B: batch size; + D: embedding dimension; + R: length of the hard-copied right contexts; + U: length of full utterance; + M: length of memory vectors. + + It computes a `big` attention matrix on full utterance and + then utilizes a pre-computed mask to simulate chunk-wise attention. + + It concatenates three blocks: hard-copied right contexts, + and full utterance, as a `big` block, + to compute the query tensor: + query = [right_context, utterance], + with length Q = R + U. + It concatenates the three blocks: memory vectors, + hard-copied right contexts, and full utterance as another `big` block, + to compute the key and value tensors: + key & value = [memory, right_context, utterance], + with length KV = M + R + U. + Attention scores is computed with above `big` query and key. + + Then the underlying chunk-wise attention is obtained by applying + the attention mask. Suppose + c_i: chunk at index i; + r_i: right context that c_i can use; + l_i: left context that c_i can use; + m_i: past memory vectors from previous layer that c_i can use; + The target chunk-wise attention is: + c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key) + + Args: + utterance (torch.Tensor): + Full utterance frames, with shape (U, B, D). + right_context (torch.Tensor): + Hard-copied right context frames, with shape (R, B, D), + where R = num_chunks * right_context_length + memory (torch.Tensor): + Memory elements, with shape (M, B, D), where M = num_chunks - 1. + It is an empty tensor without using memory. + attention_mask (torch.Tensor): + Pre-computed attention mask to simulate underlying chunk-wise + attention, with shape (Q, KV). + padding_mask (torch.Tensor): + Padding mask of key tensor, with shape (B, KV). + + Returns: + Output of right context and utterance, with shape (R + U, B, D). + """ + output_right_context_utterance, _, _ = self._forward_impl( + utterance, + right_context, + memory, + attention_mask, + padding_mask=padding_mask, + ) + return output_right_context_utterance + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + memory: torch.Tensor, + left_context_key: torch.Tensor, + left_context_val: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass for inference. + + B: batch size; + D: embedding dimension; + R: length of right context; + U: length of utterance, i.e., current chunk; + L: length of cached left context; + M: length of cached memory vectors. + + It concatenates the right context and utterance (i.e., current chunk) + of current chunk, to compute the query tensor: + query = [right_context, utterance], + with length Q = R + U. + It concatenates the memory vectors, right context, left context, and + current chunk, to compute the key and value tensors: + key & value = [memory, right_context, left_context, utterance], + with length KV = M + R + L + U. + + The chunk-wise attention is: + chunk, right context (in query) -> + left context, chunk, right context, memory vectors (in key). + + Args: + utterance (torch.Tensor): + Current chunk frames, with shape (U, B, D), where U = chunk_length. 
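+ (In this ncnn-oriented variant the batch size B is fixed to 1;
+ see the hard-coded shape constants in ``_forward_impl``.)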
+ right_context (torch.Tensor):
+ Right context frames, with shape (R, B, D),
+ where R = right_context_length.
+ memory (torch.Tensor):
+ Memory vectors, with shape (M, B, D), or empty tensor.
+ left_context_key (torch.Tensor):
+ Cached attention key of left context from preceding computation,
+ with shape (L, B, D).
+ left_context_val (torch.Tensor):
+ Cached attention value of left context from preceding computation,
+ with shape (L, B, D).
+ padding_mask (torch.Tensor):
+ Padding mask of key tensor, with shape (B, KV).
+
+ Returns:
+ A tuple containing 3 tensors:
+ - output of right context and utterance, with shape (R + U, B, D).
+ - attention key of left context and utterance, which would be cached
+ for next computation, with shape (L + U, B, D).
+ - attention value of left context and utterance, which would be
+ cached for next computation, with shape (L + U, B, D).
+ """
+ # U = utterance.size(0)
+ # R = right_context.size(0)
+ # L = left_context_key.size(0)
+ # M = memory.size(0)
+
+ U = self.chunk_length
+ R = self.right_context_length
+ L = self.left_context_length
+ M = self.memory_size
+
+ # query = [right context, utterance]
+ Q = R + U
+ # key, value = [memory, right context, left context, utterance]
+ KV = M + R + L + U
+ attention_mask = torch.zeros(Q, KV).to(
+ dtype=torch.bool, device=utterance.device
+ )
+
+ output_right_context_utterance, key, value = self._forward_impl(
+ utterance,
+ right_context,
+ memory,
+ attention_mask,
+ padding_mask=padding_mask,
+ left_context_key=left_context_key,
+ left_context_val=left_context_val,
+ )
+ return (
+ output_right_context_utterance,
+ key[M + R :],
+ value[M + R :],
+ )
+
+
+class EmformerEncoderLayer(nn.Module):
+ """Emformer layer that constitutes Emformer.
+
+ Args:
+ d_model (int):
+ Input dimension.
+ nhead (int):
+ Number of attention heads.
+ dim_feedforward (int):
+ Hidden layer dimension of feedforward network.
+ chunk_length (int):
+ Length of each input segment.
+ dropout (float, optional):
+ Dropout probability. (Default: 0.1)
+ layer_dropout (float, optional):
+ Layer dropout probability. (Default: 0.075)
+ cnn_module_kernel (int):
+ Kernel size of convolution module.
+ left_context_length (int, optional):
+ Length of left context. (Default: 0)
+ right_context_length (int, optional):
+ Length of right context. (Default: 0)
+ memory_size (int, optional):
+ Number of memory elements to use. (Default: 0)
+ tanh_on_mem (bool, optional):
+ If ``True``, applies tanh to memory elements. (Default: ``False``)
+ negative_inf (float, optional):
+ Value to use for negative infinity in attention weights.
(Default: -1e8) + """ + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int, + chunk_length: int, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + left_context_length: int = 0, + right_context_length: int = 0, + memory_size: int = 0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + self.attention = EmformerAttention( + embed_dim=d_model, + nhead=nhead, + left_context_length=left_context_length, + chunk_length=chunk_length, + memory_size=memory_size, + right_context_length=right_context_length, + dropout=dropout, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + self.summary_op = nn.AvgPool1d( + kernel_size=chunk_length, stride=chunk_length, ceil_mode=True + ) + + self.feed_forward_macaron = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + + self.conv_module = ConvolutionModule( + chunk_length, + right_context_length, + d_model, + cnn_module_kernel, + ) + + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean + # (or at least, zero-median). + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + + self.dropout = nn.Dropout(dropout) + + self.layer_dropout = layer_dropout + self.left_context_length = left_context_length + self.right_context_length = right_context_length + self.chunk_length = chunk_length + self.memory_size = memory_size + self.d_model = d_model + self.use_memory = memory_size > 0 + + def _update_attn_cache( + self, + next_key: torch.Tensor, + next_val: torch.Tensor, + memory: torch.Tensor, + attn_cache: List[torch.Tensor], + ) -> List[torch.Tensor]: + """Update cached attention state: + 1) output memory of current chunk in the lower layer; + 2) attention key and value in current chunk's computation, which would + be reused in next chunk's computation. 
+ """ + # attn_cache[0].shape (self.memory_size, 1, 512) + # memory.shape (1, 1, 512) + # attn_cache[1].shape (self.left_context_length, 1, 512) + # attn_cache[2].shape (self.left_context_length, 1, 512) + # next_key.shape (self.left_context_length + self.right_context_utterance, 1, 512) + # next_value.shape (self.left_context_length + self.right_context_utterance, 1, 512) + new_memory = torch.cat([attn_cache[0], memory]) + # TODO(fangjun): Remove torch.cat + # new_key = torch.cat([attn_cache[1], next_key]) + # new_val = torch.cat([attn_cache[2], next_val]) + attn_cache[0] = new_memory[1:] + attn_cache[1] = next_key[-self.left_context_length :] + attn_cache[2] = next_val[-self.left_context_length :] + return attn_cache + + def _apply_conv_module_forward( + self, + right_context_utterance: torch.Tensor, + R: int, + ) -> torch.Tensor: + """Apply convolution module in training and validation mode.""" + utterance = right_context_utterance[R:] + right_context = right_context_utterance[:R] + utterance, right_context = self.conv_module(utterance, right_context) + right_context_utterance = torch.cat([right_context, utterance]) + return right_context_utterance + + def _apply_conv_module_infer( + self, + right_context_utterance: torch.Tensor, + R: int, + conv_cache: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply convolution module on utterance in inference mode.""" + utterance = right_context_utterance[R:] + right_context = right_context_utterance[:R] + utterance, right_context, conv_cache = self.conv_module.infer( + utterance, right_context, conv_cache + ) + right_context_utterance = torch.cat([right_context, utterance]) + return right_context_utterance, conv_cache + + def _apply_attention_module_forward( + self, + right_context_utterance: torch.Tensor, + R: int, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Apply attention module in training and validation mode.""" + utterance = right_context_utterance[R:] + right_context = right_context_utterance[:R] + + if self.use_memory: + memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1, :, : + ] + else: + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + output_right_context_utterance = self.attention( + utterance=utterance, + right_context=right_context, + memory=memory, + attention_mask=attention_mask, + padding_mask=padding_mask, + ) + + return output_right_context_utterance + + def _apply_attention_module_infer( + self, + right_context_utterance: torch.Tensor, + R: int, + attn_cache: List[torch.Tensor], + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """Apply attention module in inference mode. + 1) Unpack cached states including: + - memory from previous chunks; + - attention key and value of left context from preceding + chunk's compuation; + 2) Apply attention computation; + 3) Update cached attention states including: + - memory of current chunk; + - attention key and value in current chunk's computation, which would + be resued in next chunk's computation. 
+ """ + utterance = right_context_utterance[R:] + right_context = right_context_utterance[:R] + + pre_memory = attn_cache[0] + left_context_key = attn_cache[1] + left_context_val = attn_cache[2] + + if self.use_memory: + memory = torch.mean(utterance, dim=0, keepdim=True) + + # memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + # :1, :, : + # ] + else: + memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + (output_right_context_utterance, next_key, next_val) = self.attention.infer( + utterance=utterance, + right_context=right_context, + memory=pre_memory, + left_context_key=left_context_key, + left_context_val=left_context_val, + padding_mask=padding_mask, + ) + attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) + return output_right_context_utterance, attn_cache + + def forward( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + attention_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Forward pass for training and validation mode. + + B: batch size; + D: embedding dimension; + R: length of hard-copied right contexts; + U: length of full utterance; + M: length of memory vectors. + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + attention_mask (torch.Tensor): + Attention mask for underlying attention module, + with shape (Q, KV), where Q = R + U, KV = M + R + U. + padding_mask (torch.Tensor): + Padding mask of ker tensor, with shape (B, KV). + + Returns: + A tuple containing 2 tensors: + - output utterance, with shape (U, B, D). + - output right context, with shape (R, B, D). + """ + R = right_context.size(0) + src = torch.cat([right_context, utterance]) + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # emformer attention module + src_att = self._apply_attention_module_forward( + src, R, attention_mask, padding_mask=padding_mask + ) + src = src + self.dropout(src_att) + + # convolution module + src_conv = self._apply_conv_module_forward(src, R) + src = src + self.dropout(src_conv) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + output_utterance = src[R:] + output_right_context = src[:R] + return output_utterance, output_right_context + + @torch.jit.export + def infer( + self, + utterance: torch.Tensor, + right_context: torch.Tensor, + cache: List[torch.Tensor], + padding_mask: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: + """Forward pass for inference. + + B: batch size; + D: embedding dimension; + R: length of right_context; + U: length of utterance; + M: length of memory. + + Args: + utterance (torch.Tensor): + Utterance frames, with shape (U, B, D). + right_context (torch.Tensor): + Right context frames, with shape (R, B, D). + attn_cache (List[torch.Tensor]): + Cached attention tensors generated in preceding computation, + including memory, key and value of left context. 
+ conv_cache (torch.Tensor, optional): + Cache tensor of left context for causal convolution. + padding_mask (torch.Tensor): + Padding mask of ker tensor. + + Returns: + (Tensor, Tensor, List[torch.Tensor], Tensor): + - output utterance, with shape (U, B, D); + - output right_context, with shape (R, B, D); + - output attention cache; + - output convolution cache. + """ + R = self.right_context_length + src = torch.cat([right_context, utterance]) + attn_cache = cache[:3] + conv_cache = cache[3] + + # macaron style feed forward module + src = src + self.dropout(self.feed_forward_macaron(src)) + + # emformer attention module + src_att, attn_cache = self._apply_attention_module_infer( + src, R, attn_cache, padding_mask=padding_mask + ) + src = src + self.dropout(src_att) + + # convolution module + src_conv, conv_cache = self._apply_conv_module_infer(src, R, conv_cache) + src = src + self.dropout(src_conv) + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + output_utterance = src[R:] + output_right_context = src[:R] + return (output_utterance, output_right_context, attn_cache + [conv_cache]) + + +def _gen_attention_mask_block( + col_widths: List[int], + col_mask: List[bool], + num_rows: int, + device: torch.device, +) -> torch.Tensor: + assert len(col_widths) == len( + col_mask + ), "Length of col_widths must match that of col_mask" + + mask_block = [ + torch.ones(num_rows, col_width, device=device) + if is_ones_col + else torch.zeros(num_rows, col_width, device=device) + for col_width, is_ones_col in zip(col_widths, col_mask) + ] + return torch.cat(mask_block, dim=1) + + +class EmformerEncoder(nn.Module): + """Implements the Emformer architecture introduced in + *Emformer: Efficient Memory Transformer Based Acoustic Model for Low Latency + Streaming Speech Recognition* + [:footcite:`shi2021emformer`]. + + In this model, the memory bank computation is simplifed, using the averaged + value of each chunk as its memory vector. + + Args: + d_model (int): + Input dimension. + nhead (int): + Number of attention heads in each emformer layer. + dim_feedforward (int): + Hidden layer dimension of each emformer layer's feedforward network. + num_encoder_layers (int): + Number of emformer layers to instantiate. + chunk_length (int): + Length of each input segment. + dropout (float, optional): + Dropout probability. (default: 0.0) + layer_dropout (float, optional): + Layer dropout probability. (default: 0.0) + cnn_module_kernel (int): + Kernel size of convolution module. + left_context_length (int, optional): + Length of left context. (default: 0) + right_context_length (int, optional): + Length of right context. (default: 0) + memory_size (int, optional): + Number of memory elements to use. (default: 0) + tanh_on_mem (bool, optional): + If ``true``, applies tanh to memory elements. (default: ``false``) + negative_inf (float, optional): + Value to use for negative infinity in attention weights. (default: -1e8) + """ + + def __init__( + self, + chunk_length: int, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 31, + left_context_length: int = 0, + right_context_length: int = 0, + memory_size: int = 0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + ): + super().__init__() + + assert ( + chunk_length - 1 + ) & chunk_length == 0, "chunk_length should be a power of 2." 
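+        # (chunk_length - 1) & chunk_length == 0 is the usual bit trick for
+        # testing that chunk_length is a power of 2, e.g. 32 = 0b100000 and
+        # 31 = 0b011111 share no set bits; log2(chunk_length) below is then
+        # an integer.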
+ self.shift = int(math.log(chunk_length, 2)) + + self.use_memory = memory_size > 0 + + self.emformer_layers = nn.ModuleList( + [ + EmformerEncoderLayer( + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + chunk_length=chunk_length, + dropout=dropout, + layer_dropout=layer_dropout, + cnn_module_kernel=cnn_module_kernel, + left_context_length=left_context_length, + right_context_length=right_context_length, + memory_size=memory_size, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + for layer_idx in range(num_encoder_layers) + ] + ) + + self.num_encoder_layers = num_encoder_layers + self.d_model = d_model + self.left_context_length = left_context_length + self.right_context_length = right_context_length + self.chunk_length = chunk_length + self.memory_size = memory_size + self.cnn_module_kernel = cnn_module_kernel + + def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: + """Hard copy each chunk's right context and concat them.""" + T = x.shape[0] + num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) + # first (num_chunks - 1) right context block + intervals = torch.arange( + 0, self.chunk_length * (num_chunks - 1), self.chunk_length + ) + first = torch.arange( + self.chunk_length, self.chunk_length + self.right_context_length + ) + indexes = intervals.unsqueeze(1) + first.unsqueeze(0) + # cat last right context block + indexes = torch.cat( + [ + indexes, + torch.arange(T - self.right_context_length, T).unsqueeze(0), + ] + ) + right_context_blocks = x[indexes.reshape(-1)] + return right_context_blocks + + def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: + """Calculate column widths (key, value) in attention mask for the + chunk_idx chunk.""" + num_chunks = math.ceil(U / self.chunk_length) + rc = self.right_context_length + lc = self.left_context_length + rc_start = chunk_idx * rc + rc_end = rc_start + rc + chunk_start = max(chunk_idx * self.chunk_length - lc, 0) + chunk_end = min((chunk_idx + 1) * self.chunk_length, U) + R = rc * num_chunks + + if self.use_memory: + m_start = max(chunk_idx - self.memory_size, 0) + M = num_chunks - 1 + col_widths = [ + m_start, # before memory + chunk_idx - m_start, # memory + M - chunk_idx, # after memory + rc_start, # before right context + rc, # right context + R - rc_end, # after right context + chunk_start, # before chunk + chunk_end - chunk_start, # chunk + U - chunk_end, # after chunk + ] + else: + col_widths = [ + rc_start, # before right context + rc, # right context + R - rc_end, # after right context + chunk_start, # before chunk + chunk_end - chunk_start, # chunk + U - chunk_end, # after chunk + ] + + return col_widths + + def _gen_attention_mask(self, utterance: torch.Tensor) -> torch.Tensor: + """Generate attention mask to simulate underlying chunk-wise attention + computation, where chunk-wise connections are filled with `False`, + and other unnecessary connections beyond chunk are filled with `True`. + + R: length of hard-copied right contexts; + U: length of full utterance; + M: length of memory vectors; + Q: length of attention query; + KV: length of attention key and value. + + The shape of attention mask is (Q, KV). + If self.use_memory is `True`: + query = [right_context, utterance]; + key, value = [memory, right_context, utterance]; + Q = R + U, KV = M + R + U. + Otherwise: + query = [right_context, utterance] + key, value = [right_context, utterance] + Q = R + U, KV = R + U. 
+ + Suppose: + c_i: chunk at index i; + r_i: right context that c_i can use; + l_i: left context that c_i can use; + m_i: past memory vectors from previous layer that c_i can use; + The target chunk-wise attention is: + c_i, r_i (in query) -> l_i, c_i, r_i, m_i (in key). + """ + U = utterance.size(0) + num_chunks = math.ceil(U / self.chunk_length) + + right_context_mask = [] + utterance_mask = [] + + if self.use_memory: + num_cols = 9 + # right context and utterance both attend to memory, right context, + # utterance + right_context_utterance_cols_mask = [ + idx in [1, 4, 7] for idx in range(num_cols) + ] + else: + num_cols = 6 + # right context and utterance both attend to right context and + # utterance + right_context_utterance_cols_mask = [ + idx in [1, 4] for idx in range(num_cols) + ] + masks_to_concat = [right_context_mask, utterance_mask] + + for chunk_idx in range(num_chunks): + col_widths = self._gen_attention_mask_col_widths(chunk_idx, U) + + right_context_mask_block = _gen_attention_mask_block( + col_widths, + right_context_utterance_cols_mask, + self.right_context_length, + utterance.device, + ) + right_context_mask.append(right_context_mask_block) + + utterance_mask_block = _gen_attention_mask_block( + col_widths, + right_context_utterance_cols_mask, + min( + self.chunk_length, + U - chunk_idx * self.chunk_length, + ), + utterance.device, + ) + utterance_mask.append(utterance_mask_block) + + attention_mask = ( + 1 - torch.cat([torch.cat(mask) for mask in masks_to_concat]) + ).to(torch.bool) + return attention_mask + + def _forward( + self, x: torch.Tensor, lengths: torch.Tensor, warmup: float = 1.0 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for training and validation mode. + + B: batch size; + D: input dimension; + U: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (U + right_context_length, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, which contains the + right_context at the end. + + Returns: + A tuple of 2 tensors: + - output utterance frames, with shape (U, B, D). + - output_lengths, with shape (B,), without containing the + right_context at the end. + """ + U = x.size(0) - self.right_context_length + + right_context = self._gen_right_context(x) + utterance = x[:U] + output_lengths = torch.clamp(lengths - self.right_context_length, min=0) + attention_mask = self._gen_attention_mask(utterance) + + M = ( + right_context.size(0) // self.right_context_length - 1 + if self.use_memory + else 0 + ) + padding_mask = make_pad_mask(M + right_context.size(0) + output_lengths) + + output = utterance + for layer in self.emformer_layers: + output, right_context = layer( + output, + right_context, + attention_mask, + padding_mask=padding_mask, + warmup=warmup, + ) + + return output, output_lengths + + @torch.jit.export + def infer( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor],]: + """Forward pass for streaming inference. + + B: batch size; + D: input dimension; + U: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (U + right_context_length, B, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, which contains the + right_context at the end. 
+ states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa + Cached states containing: + - attn_caches: attention states from preceding chunk's computation, + where each element corresponds to each emformer layer + - conv_caches: left context for causal convolution, where each + element corresponds to each layer. + + Returns: + (Tensor, Tensor, List[List[torch.Tensor]], List[torch.Tensor]): + - output utterance frames, with shape (U, B, D). + - output lengths, with shape (B,), without containing the + right_context at the end. + - updated states from current chunk's computation. + """ + # lengths = chunk_length + right_context_length + utterance = x[: self.chunk_length] + right_context = x[self.chunk_length :] + # right_context_utterance = torch.cat([right_context, utterance]) + + output = utterance + output_states: List[torch.Tensor] = [] + for layer_idx, layer in enumerate(self.emformer_layers): + start = layer_idx * 4 + end = start + 4 + cache = states[start:end] + + (output, right_context, output_cache,) = layer.infer( + output, + right_context, + padding_mask=None, + cache=cache, + ) + output_states.extend(output_cache) + + return output, output_states + + @torch.jit.export + def init_states( + self, device: torch.device = torch.device("cpu") + ) -> List[torch.Tensor]: + """Create initial states.""" + # + states = [] + # layer0: attn cache, conv cache, 3 tensors + 1 tensor + # layer1: attn cache, conv cache, 3 tensors + 1 tensor + # layer2: attn cache, conv cache, 3 tensors + 1 tensor + # ... + # last layer: attn cache, conv cache, 3 tensors + 1 tensor + for i in range(self.num_encoder_layers): + states.append(torch.zeros(self.memory_size, 1, self.d_model, device=device)) + states.append( + torch.zeros(self.left_context_length, 1, self.d_model, device=device) + ) + states.append( + torch.zeros(self.left_context_length, 1, self.d_model, device=device) + ) + + states.append( + torch.zeros(1, self.d_model, self.cnn_module_kernel - 1, device=device) + ) + return states + + attn_caches = [ + [ + torch.zeros(self.memory_size, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros(self.left_context_length, self.d_model, device=device), + ] + for _ in range(self.num_encoder_layers) + ] + conv_caches = [ + torch.zeros(self.d_model, self.cnn_module_kernel - 1, device=device) + for _ in range(self.num_encoder_layers) + ] + states: Tuple[List[List[torch.Tensor]], List[torch.Tensor]] = ( + attn_caches, + conv_caches, + ) + return states + + +class Emformer(EncoderInterface): + def __init__( + self, + num_features: int, + chunk_length: int, + subsampling_factor: int = 4, + d_model: int = 256, + nhead: int = 4, + dim_feedforward: int = 2048, + num_encoder_layers: int = 12, + dropout: float = 0.1, + layer_dropout: float = 0.075, + cnn_module_kernel: int = 3, + left_context_length: int = 0, + right_context_length: int = 0, + memory_size: int = 0, + tanh_on_mem: bool = False, + negative_inf: float = -1e8, + is_pnnx: bool = True, + ): + super().__init__() + + self.subsampling_factor = subsampling_factor + self.right_context_length = right_context_length + self.chunk_length = chunk_length + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + if chunk_length % subsampling_factor != 0: + raise NotImplementedError( + "chunk_length must be a mutiple of subsampling_factor." 
+ ) + if left_context_length != 0 and left_context_length % subsampling_factor != 0: + raise NotImplementedError( + "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa + ) + if right_context_length != 0 and right_context_length % subsampling_factor != 0: + raise NotImplementedError( + "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa + ) + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling(num_features, d_model, is_pnnx=is_pnnx) + self.is_pnnx = is_pnnx + + self.encoder = EmformerEncoder( + chunk_length=chunk_length // subsampling_factor, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + num_encoder_layers=num_encoder_layers, + dropout=dropout, + layer_dropout=layer_dropout, + cnn_module_kernel=cnn_module_kernel, + left_context_length=left_context_length // subsampling_factor, + right_context_length=right_context_length // subsampling_factor, + memory_size=memory_size, + tanh_on_mem=tanh_on_mem, + negative_inf=negative_inf, + ) + + def _forward( + self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for training and non-streaming inference. + + B: batch size; + D: feature dimension; + T: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (B, T, D). + x_lens (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, containing the + right_context at the end. + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + (Tensor, Tensor): + - output embedding, with shape (B, T', D), where + T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4. + - output lengths, with shape (B,), without containing the + right_context at the end. + """ + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + x_lens = (((x_lens - 1) >> 1) - 1) >> 1 + assert x.size(0) == x_lens.max().item() + + output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) + + output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + + return output, output_lengths + + def forward( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor],]: + """Forward pass for streaming inference. + + B: batch size; + D: feature dimension; + T: length of utterance. + + Args: + x (torch.Tensor): + Utterance frames right-padded with right context frames, + with shape (B, T, D). + lengths (torch.Tensor): + With shape (B,) and i-th element representing number of valid + utterance frames for i-th batch element in x, containing the + right_context at the end. + states (List[torch.Tensor, List[List[torch.Tensor]], List[torch.Tensor]]: # noqa + Cached states containing: + - past_lens: number of past frames for each sample in batch + - attn_caches: attention states from preceding chunk's computation, + where each element corresponds to each emformer layer + - conv_caches: left context for causal convolution, where each + element corresponds to each layer. 
+ Returns: + (Tensor, Tensor): + - output embedding, with shape (B, T', D), where + T' = ((T - 1) // 2 - 1) // 2 - self.right_context_length // 4. + - output lengths, with shape (B,), without containing the + right_context at the end. + - updated states from current chunk's computation. + """ + x = self.encoder_embed(x) + # drop the first and last frames + x = x[:, 1:-1, :] + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # Caution: We assume the subsampling factor is 4! + + output, output_states = self.encoder.infer(x, states) + + output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + + return output, output_states + + @torch.jit.export + def init_states( + self, device: torch.device = torch.device("cpu") + ) -> List[torch.Tensor]: + """Create initial states.""" + return self.encoder.init_states(device) + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + is_pnnx: bool = False, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >=7, in_channels >=7 + out_channels + Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. + """ + assert in_channels >= 7 + super().__init__() + + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=1, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). 
+ + Returns: + Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, ((T-1)//2-1)//2, ((idim-1)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py new file mode 100755 index 000000000..716de5734 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export-for-ncnn.py @@ -0,0 +1,335 @@ +#!/usr/bin/env python3 + +""" +Usage: +./conv_emformer_transducer_stateless2/export-for-ncnn.py \ + --exp-dir ./conv_emformer_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 30 \ + --avg 10 \ + --use-averaged-model=True \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 \ + +cd ./conv_emformer_transducer_stateless2/exp +pnnx encoder_jit_trace-pnnx.pt +pnnx decoder_jit_trace-pnnx.pt +pnnx joiner_jit_trace-pnnx.pt + +You can find converted models at +https://huggingface.co/csukuangfj/sherpa-ncnn-conv-emformer-transducer-2022-12-04 + +See ./streaming-ncnn-decode.py +and +https://github.com/k2-fsa/sherpa-ncnn +for usage. +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +from scaling_converter import convert_scaled_to_non_scaled +from train2 import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: torch.nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + chunk_length = encoder_model.chunk_length # before subsampling + right_context_length = encoder_model.right_context_length # before subsampling + pad_length = right_context_length + 2 * 4 + 3 + s = f"chunk_length: {chunk_length}, " + s += f"right_context_length: {right_context_length}\n" + logging.info(s) + + T = chunk_length + pad_length + + x = torch.zeros(1, T, 80, dtype=torch.float32) + states = encoder_model.init_states() + states = encoder_model.init_states() + + traced_model = torch.jit.trace(encoder_model, (x, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: torch.nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: torch.nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. 
+
+    """
+    encoder_out_dim = joiner_model.encoder_proj.weight.shape[1]
+    decoder_out_dim = joiner_model.decoder_proj.weight.shape[1]
+    encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
+    decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
+
+    traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out))
+    traced_model.save(joiner_filename)
+    logging.info(f"Saved to {joiner_filename}")
+
+
+@torch.no_grad()
+def main():
+    args = get_parser().parse_args()
+    args.exp_dir = Path(args.exp_dir)
+
+    params = get_params()
+    params.update(vars(args))
+
+    device = torch.device("cpu")
+
+    logging.info(f"device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+        elif params.avg == 1:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        else:
+            start = params.epoch - params.avg + 1
+            filenames = []
+            for i in range(start, params.epoch + 1):
+                if i >= 1:
+                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+            logging.info(f"averaging {filenames}")
+            model.to(device)
+            model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+                : params.avg + 1
+            ]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+
+    model.to("cpu")
+    model.eval()
+
+    convert_scaled_to_non_scaled(model, inplace=True)
+    logging.info("Using torch.jit.trace()")
+
+    logging.info("Exporting encoder")
+    encoder_filename = params.exp_dir / "encoder_jit_trace-pnnx.pt"
+    export_encoder_model_jit_trace(model.encoder, encoder_filename)
+
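+    # export_encoder_model_jit_trace() above traces the streaming encoder with
+    # a dummy (1, chunk_length + pad_length, 80) feature tensor and the flat
+    # list of per-layer states from init_states(); the decoder and joiner
+    # below are traced with dummy inputs in the same way before the traced
+    # modules are saved for pnnx conversion.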
logging.info("Exporting decoder") + decoder_filename = params.exp_dir / "decoder_jit_trace-pnnx.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + logging.info("Exporting joiner") + joiner_filename = params.exp_dir / "joiner_jit_trace-pnnx.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py new file mode 100755 index 000000000..1fe358c79 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/jit_pretrained.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python3 +# flake8: noqa +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models exported by `torch.jit.trace()` +and uses them to decode waves. +You can use the following command to get the exported models: + +./conv_emformer_transducer_stateless2/export-for-ncnn.py \ + --exp-dir ./conv_emformer_transducer_stateless2/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 + +Usage of this script: + +./conv_emformer_transducer_stateless2/jit_pretrained.py \ + --encoder-model-filename ./conv_emformer_transducer_stateless2/exp/encoder_jit_trace-pnnx.pt \ + --decoder-model-filename ./conv_emformer_transducer_stateless2/exp/decoder_jit_trace-pnnx.pt \ + --joiner-model-filename ./conv_emformer_transducer_stateless2/exp/joiner_jit_trace-pnnx.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature +from torch.nn.utils.rnn import pad_sequence +from typing import Optional, List + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_file", + type=str, + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. 
" + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + assert encoder_out.ndim == 2 + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32).unsqueeze(0) + # decoder_input.shape (1,, 1 context_size) + decoder_out = decoder(decoder_input, torch.tensor([0])).squeeze(1) + else: + assert decoder_out.ndim == 2 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for i in range(T): + cur_encoder_out = encoder_out[i : i + 1] + joiner_out = joiner(cur_encoder_out, decoder_out).squeeze(0) + y = joiner_out.argmax(dim=0).item() + + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + + decoder_input = torch.tensor(decoder_input, dtype=torch.int32).unsqueeze(0) + decoder_out = decoder(decoder_input, torch.tensor([0])).squeeze(1) + + return hyp, decoder_out + + +def create_streaming_feature_extractor(sample_rate) -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = sample_rate + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor(args.sample_rate) + + logging.info(f"Reading sound files: {args.sound_file}") + wave_samples = read_sound_files( + filenames=[args.sound_file], + expected_sample_rate=args.sample_rate, + )[0] + logging.info(wave_samples.shape) + + logging.info("Decoding started") + chunk_length = encoder.chunk_length + right_context_length = encoder.right_context_length + + # Assume the subsampling factor is 4 + pad_length = right_context_length + 2 * 4 + 3 + T = chunk_length + pad_length + + logging.info(f"chunk_length: {chunk_length}") + logging.info(f"right_context_length: {right_context_length}") + + states = encoder.init_states(device) + logging.info(f"num layers: {len(states)//4}") + + tail_padding = torch.zeros(int(0.3 * args.sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + chunk = int(0.25 * args.sample_rate) # 0.2 second + num_processed_frames = 0 + + hyp = None + decoder_out = None + + start = 0 + while start < wave_samples.numel(): + logging.info(f"{start}/{wave_samples.numel()}") + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + online_fbank.accept_waveform( + sampling_rate=args.sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= T: + frames = [] + for i in range(T): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += chunk_length + frames = torch.cat(frames, dim=0).unsqueeze(0) + # TODO(fangjun): remove x_lens + x_lens = torch.tensor([T]) + encoder_out, _, states = encoder(frames, x_lens, states) + + hyp, decoder_out = greedy_search( + decoder, joiner, encoder_out.squeeze(0), decoder_out, hyp + ) + + context_size = 2 + + logging.info(args.sound_file) + logging.info(sp.decode(hyp[context_size:])) + + logging.info("Decoding Done") + + +torch.set_num_threads(4) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py new file mode 120000 index 000000000..4f377cd01 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/lstmp.py @@ -0,0 +1 @@ +../lstm_transducer_stateless2/lstmp.py \ No newline at end of file diff --git 
a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py new file mode 120000 index 000000000..3b667058d --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/scaling_converter.py @@ -0,0 +1 @@ +../pruned_transducer_stateless3/scaling_converter.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py new file mode 100755 index 000000000..b21fe5c7e --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming-ncnn-decode.py @@ -0,0 +1,387 @@ +#!/usr/bin/env python3 +# +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +./conv_emformer_transducer_stateless2/streaming-ncnn-decode.py \ + --tokens ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/tokens.txt \ + --encoder-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \ + --encoder-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/encoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \ + --decoder-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \ + --decoder-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/decoder_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \ + --joiner-param-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.param \ + --joiner-bin-filename ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/joiner_jit_trace-epoch-30-avg-10-pnnx.ncnn.bin \ + ./sherpa-ncnn-conv-emformer-transducer-2022-12-04/test_wavs/1089-134686-0001.wav + +You can find pretrained models at +https://huggingface.co/csukuangfj/sherpa-ncnn-conv-emformer-transducer-2022-12-04 +""" + +import argparse +import logging +from typing import List, Optional, Tuple + +import k2 +import ncnn +import torch +import torchaudio +from kaldifeat import FbankOptions, OnlineFbank, OnlineFeature + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--tokens", + type=str, + help="Path to tokens.txt", + ) + + parser.add_argument( + "--encoder-param-filename", + type=str, + help="Path to encoder.ncnn.param", + ) + + parser.add_argument( + "--encoder-bin-filename", + type=str, + help="Path to encoder.ncnn.bin", + ) + + parser.add_argument( + "--decoder-param-filename", + type=str, + help="Path to decoder.ncnn.param", + ) + + parser.add_argument( + "--decoder-bin-filename", + type=str, + help="Path to decoder.ncnn.bin", + ) + + parser.add_argument( + "--joiner-param-filename", + type=str, + help="Path to joiner.ncnn.param", + ) + + parser.add_argument( + "--joiner-bin-filename", + type=str, + help="Path to 
joiner.ncnn.bin", + ) + + parser.add_argument( + "sound_filename", + type=str, + help="Path to foo.wav", + ) + + return parser.parse_args() + + +class Model: + def __init__(self, args): + self.init_encoder(args) + self.init_decoder(args) + self.init_joiner(args) + + self.num_layers = 12 + self.memory_size = 32 + self.d_model = 512 + self.cnn_module_kernel = 31 + + self.left_context_length = 32 // 4 # after subsampling + self.chunk_length = 32 # before subsampling + right_context_length = 8 # before subsampling + pad_length = right_context_length + 2 * 4 + 3 + self.T = self.chunk_length + pad_length + print("T", self.T, self.chunk_length) + + def get_init_states(self) -> List[torch.Tensor]: + states = [] + + for i in range(self.num_layers): + s0 = torch.zeros(self.memory_size, self.d_model) + s1 = torch.zeros(self.left_context_length, self.d_model) + s2 = torch.zeros(self.left_context_length, self.d_model) + s3 = torch.zeros(self.d_model, self.cnn_module_kernel - 1) + states.extend([s0, s1, s2, s3]) + + return states + + def init_encoder(self, args): + encoder_net = ncnn.Net() + encoder_net.opt.use_packing_layout = False + encoder_net.opt.use_fp16_storage = False + encoder_param = args.encoder_param_filename + encoder_model = args.encoder_bin_filename + + encoder_net.load_param(encoder_param) + encoder_net.load_model(encoder_model) + + self.encoder_net = encoder_net + + def init_decoder(self, args): + decoder_param = args.decoder_param_filename + decoder_model = args.decoder_bin_filename + + decoder_net = ncnn.Net() + + decoder_net.load_param(decoder_param) + decoder_net.load_model(decoder_model) + + self.decoder_net = decoder_net + + def init_joiner(self, args): + joiner_param = args.joiner_param_filename + joiner_model = args.joiner_bin_filename + joiner_net = ncnn.Net() + joiner_net.load_param(joiner_param) + joiner_net.load_model(joiner_model) + + self.joiner_net = joiner_net + + def run_encoder( + self, + x: torch.Tensor, + states: List[torch.Tensor], + ) -> Tuple[torch.Tensor, List[torch.Tensor]]: + """ + Args: + x: + A tensor of shape (T, C) + states: + A list of tensors. len(states) == self.num_layers * 4 + Returns: + Return a tuple containing: + - encoder_out, a tensor of shape (T, encoder_dim). 
+ - next_states, a list of tensors containing the next states + """ + with self.encoder_net.create_extractor() as ex: + ex.set_num_threads(4) + ex.input("in0", ncnn.Mat(x.numpy()).clone()) + + # layer0 in2-in5 + # layer1 in6-in9 + for i in range(self.num_layers): + offset = 1 + i * 4 + name = f"in{offset}" + # (32, 1, 512) -> (32, 512) + ex.input(name, ncnn.Mat(states[i * 4 + 0].numpy()).clone()) + + name = f"in{offset+1}" + # (8, 1, 512) -> (8, 512) + ex.input(name, ncnn.Mat(states[i * 4 + 1].numpy()).clone()) + + name = f"in{offset+2}" + # (8, 1, 512) -> (8, 512) + ex.input(name, ncnn.Mat(states[i * 4 + 2].numpy()).clone()) + + name = f"in{offset+3}" + # (1, 512, 2) -> (512, 2) + ex.input(name, ncnn.Mat(states[i * 4 + 3].numpy()).clone()) + + import pdb + + # pdb.set_trace() + ret, ncnn_out0 = ex.extract("out0") + # assert ret == 0, ret + encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + + out_states: List[torch.Tensor] = [] + for i in range(4 * self.num_layers): + name = f"out{i+1}" + ret, ncnn_out_state = ex.extract(name) + assert ret == 0, ret + ncnn_out_state = torch.from_numpy(ncnn_out_state.numpy()) + out_states.append(ncnn_out_state) + + return encoder_out, out_states + + def run_decoder(self, decoder_input): + assert decoder_input.dtype == torch.int32 + + with self.decoder_net.create_extractor() as ex: + ex.set_num_threads(4) + ex.input("in0", ncnn.Mat(decoder_input.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + decoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return decoder_out + + def run_joiner(self, encoder_out, decoder_out): + with self.joiner_net.create_extractor() as ex: + ex.set_num_threads(4) + ex.input("in0", ncnn.Mat(encoder_out.numpy()).clone()) + ex.input("in1", ncnn.Mat(decoder_out.numpy()).clone()) + ret, ncnn_out0 = ex.extract("out0") + assert ret == 0, ret + joiner_out = torch.from_numpy(ncnn_out0.numpy()).clone() + return joiner_out + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert ( + sample_rate == expected_sample_rate + ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + # We use only the first channel + ans.append(wave[0]) + return ans + + +def create_streaming_feature_extractor() -> OnlineFeature: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. 
+ """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return OnlineFbank(opts) + + +def greedy_search( + model: Model, + encoder_out: torch.Tensor, + decoder_out: Optional[torch.Tensor] = None, + hyp: Optional[List[int]] = None, +): + context_size = 2 + blank_id = 0 + + if decoder_out is None: + assert hyp is None, hyp + hyp = [blank_id] * context_size + decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + else: + assert decoder_out.ndim == 1 + assert hyp is not None, hyp + + T = encoder_out.size(0) + for t in range(T): + cur_encoder_out = encoder_out[t] + + joiner_out = model.run_joiner(cur_encoder_out, decoder_out) + y = joiner_out.argmax(dim=0).item() + if y != blank_id: + hyp.append(y) + decoder_input = hyp[-context_size:] + decoder_input = torch.tensor(decoder_input, dtype=torch.int32) + decoder_out = model.run_decoder(decoder_input).squeeze(0) + + return hyp, decoder_out + + +def main(): + args = get_args() + logging.info(vars(args)) + + model = Model(args) + + sound_file = args.sound_filename + + sample_rate = 16000 + + logging.info("Constructing Fbank computer") + online_fbank = create_streaming_feature_extractor() + + logging.info(f"Reading sound files: {sound_file}") + wave_samples = read_sound_files( + filenames=[sound_file], + expected_sample_rate=sample_rate, + )[0] + logging.info(wave_samples.shape) + + tail_padding = torch.zeros(int(0.3 * sample_rate), dtype=torch.float32) + + wave_samples = torch.cat([wave_samples, tail_padding]) + + states = model.get_init_states() + + hyp = None + decoder_out = None + + num_processed_frames = 0 + segment = model.T + offset = model.chunk_length + + chunk = int(1 * sample_rate) # 0.2 second + + start = 0 + while start < wave_samples.numel(): + end = min(start + chunk, wave_samples.numel()) + samples = wave_samples[start:end] + start += chunk + + online_fbank.accept_waveform( + sampling_rate=sample_rate, + waveform=samples, + ) + while online_fbank.num_frames_ready - num_processed_frames >= segment: + frames = [] + for i in range(segment): + frames.append(online_fbank.get_frame(num_processed_frames + i)) + num_processed_frames += offset + frames = torch.cat(frames, dim=0) + encoder_out, states = model.run_encoder(frames, states) + hyp, decoder_out = greedy_search(model, encoder_out, decoder_out, hyp) + + symbol_table = k2.SymbolTable.from_file(args.tokens) + + context_size = 2 + text = "" + for i in hyp[context_size:]: + text += symbol_table[i] + text = text.replace("▁", " ").strip() + + logging.info(sound_file) + logging.info(text) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py new file mode 100755 index 000000000..c91f94876 --- /dev/null +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train2.py @@ -0,0 +1,1128 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./conv_emformer_transducer_stateless2/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir conv_emformer_transducer_stateless2/exp \ + --full-libri 1 \ + --max-duration 280 \ + --master-port 12321 \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 + +# For mix precision training: +./conv_emformer_transducer_stateless2/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir conv_emformer_transducer_stateless2/exp \ + --full-libri 1 \ + --max-duration 300 \ + --master-port 12321 \ + --num-encoder-layers 12 \ + --chunk-length 32 \ + --cnn-module-kernel 31 \ + --left-context-length 32 \ + --right-context-length 8 \ + --memory-size 32 +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from emformer2 import Emformer +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--encoder-dim", + type=int, + default=512, + help="Attention dim for the Emformer", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads for the Emformer", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=2048, + help="Feed-forward dimension for the Emformer", + ) + + parser.add_argument( + "--num-encoder-layers", + type=int, + default=12, + help="Number of encoder layers for the Emformer", + ) + + parser.add_argument( + "--cnn-module-kernel", + type=int, + default=31, 
+ help="Kernel size for the convolution module.", + ) + + parser.add_argument( + "--left-context-length", + type=int, + default=32, + help="""Number of frames before subsampling for left context + in the Emformer.""", + ) + + parser.add_argument( + "--chunk-length", + type=int, + default=32, + help="""Number of frames before subsampling for each chunk + in the Emformer.""", + ) + + parser.add_argument( + "--right-context-length", + type=int, + default=8, + help="""Number of frames before subsampling for right context + in the Emformer.""", + ) + + parser.add_argument( + "--memory-size", + type=int, + default=0, + help="Number of entries in the memory for the Emformer", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless2/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network) part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). 
We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=8000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. 
+ """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for Emformer + "feature_dim": 80, + "subsampling_factor": 4, + # parameters for decoder + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Emformer( + num_features=params.feature_dim, + chunk_length=params.chunk_length, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + cnn_module_kernel=params.cnn_module_kernel, + left_context_length=params.left_context_length, + right_context_length=params.right_context_length, + memory_size=params.memory_size, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
+ + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + if "cur_batch_idx" in saved_params: + params["cur_batch_idx"] = saved_params["cur_batch_idx"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. 
+ """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + ) + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. 
+ train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. 
+ The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 1600 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # <blk> is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("<blk>") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2**22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + librispeech = LibriSpeechAsrDataModule(args) + + if params.full_libri: + train_cuts = librispeech.train_all_shuf_cuts() + else: + train_cuts = librispeech.train_clean_100_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here.
Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 20.0 + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=0.0, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." 
+ ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() From 10472e7ffc8bd3f8a096eb7cc62c86a4b861a9a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ali=20Haznedaro=C4=9Flu?= <53865510+ahazned@users.noreply.github.com> Date: Wed, 7 Dec 2022 03:22:50 +0300 Subject: [PATCH 23/27] Update prepare.sh (#737) --- egs/spgispeech/ASR/prepare.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/spgispeech/ASR/prepare.sh b/egs/spgispeech/ASR/prepare.sh index 4842f52d0..8331f94d5 100755 --- a/egs/spgispeech/ASR/prepare.sh +++ b/egs/spgispeech/ASR/prepare.sh @@ -108,7 +108,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then pieces=$(find data/manifests -name "cuts_train_[0-9]*.jsonl.gz") lhotse combine $pieces data/manifests/cuts_train.jsonl.gz fi - gunzip -c data/manifests/train_cuts.jsonl.gz | shuf | gzip -c > data/manifests/train_cuts_shuf.jsonl.gz + gunzip -c data/manifests/cuts_train.jsonl.gz | shuf | gzip -c > data/manifests/cuts_train_shuf.jsonl.gz fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then @@ -136,7 +136,7 @@ if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then # Add special words to words.txt echo " 0" > $lang_dir/words.txt echo "!SIL 1" >> $lang_dir/words.txt - echo "[UNK] 2" >> $lang_dir/words.txt + echo " 2" >> $lang_dir/words.txt # Add regular words to words.txt gunzip -c data/manifests/cuts_train_raw.jsonl.gz \ From 0e325c8782c8b9178cf0f2b030e49ae64f2b091d Mon Sep 17 00:00:00 2001 From: huangruizhe Date: Wed, 7 Dec 2022 02:43:26 -0500 Subject: [PATCH 24/27] Fixed rnn_lm model.py (#738) --- icefall/rnn_lm/model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 9eef88840..3598a4857 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -159,10 +159,10 @@ class RnnLmModel(torch.nn.Module): if state: h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to( + h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to( device ) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to( + c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size).to( device ) @@ -179,8 +179,8 @@ class RnnLmModel(torch.nn.Module): if state: h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) + h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size) + c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.hidden_size) device = next(self.parameters()).device From d65fe17d2766e34adbb4080f9691ea829ac0ae05 Mon Sep 17 00:00:00 2001 From: armusc <46787089+armusc@users.noreply.github.com> Date: Thu, 8 Dec 2022 13:21:51 +0100 Subject: [PATCH 25/27] Update train.py with parameters_names as required by optimizer initialization (#742) * Update train.py --- egs/ami/ASR/pruned_transducer_stateless7/train.py | 11 ++++++++++- .../ASR/pruned_transducer_stateless7_ctc/train.py | 11 ++++++++++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git 
a/egs/ami/ASR/pruned_transducer_stateless7/train.py b/egs/ami/ASR/pruned_transducer_stateless7/train.py index b5efb3405..81823ced2 100755 --- a/egs/ami/ASR/pruned_transducer_stateless7/train.py +++ b/egs/ami/ASR/pruned_transducer_stateless7/train.py @@ -972,7 +972,16 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py index abfd56e5a..162ad8412 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7_ctc/train.py @@ -1036,7 +1036,16 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + parameters_names = [] + parameters_names.append( + [name_param_pair[0] for name_param_pair in model.named_parameters()] + ) + optimizer = ScaledAdam( + model.parameters(), + lr=params.base_lr, + clipping_scale=2.0, + parameters_names=parameters_names, + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) From 4501821fd98821a6cf3a238c6dc5c01422643fdb Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 9 Dec 2022 16:46:44 +0800 Subject: [PATCH 26/27] Support using OpenFst to compile HLG. (#606) * Support using OpenFst to compile HLG. * Fix style issues --- .../ASR/local/compile_hlg_using_openfst.py | 184 ++++++++++++++++++ egs/librispeech/ASR/prepare.sh | 41 +++- icefall/shared/convert-k2-to-openfst.py | 102 ++++++++++ requirements.txt | 1 + 4 files changed, 325 insertions(+), 3 deletions(-) create mode 100755 egs/librispeech/ASR/local/compile_hlg_using_openfst.py create mode 100755 icefall/shared/convert-k2-to-openfst.py diff --git a/egs/librispeech/ASR/local/compile_hlg_using_openfst.py b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py new file mode 100755 index 000000000..9e5e3df69 --- /dev/null +++ b/egs/librispeech/ASR/local/compile_hlg_using_openfst.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +""" +This script takes as input lang_dir and generates HLG from + + - H, the ctc topology, built from tokens contained in lang_dir/lexicon.txt + - L, the lexicon, built from lang_dir/L_disambig.fst + + Caution: We use a lexicon that contains disambiguation symbols + + - G, the LM, built from data/lm/G_3_gram.fst.txt + +The generated HLG is saved in $lang_dir/HLG_fst.pt + +So when to use this script instead of ./local/compile_hlg.py ? +If you have a very large G, ./local/compile_hlg.py may throw OOM for +determinization. In that case, you can use this script to compile HLG. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import kaldifst +import torch + +from icefall.lexicon import Lexicon + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--lang-dir", + type=str, + help="""Input and output directory. + """, + ) + + return parser.parse_args() + + +def compile_HLG(lang_dir: str) -> kaldifst.StdVectorFst: + """ + Args: + lang_dir: + The language directory, e.g., data/lang_phone or data/lang_bpe_5000. + + Return: + An FST representing HLG. + """ + + L = kaldifst.StdVectorFst.read(f"{lang_dir}/L_disambig.fst") + logging.info("Arc sort L") + kaldifst.arcsort(L, sort_type="olabel") + logging.info(f"L: #states {L.num_states}") + + G_filename_txt = "data/lm/G_3_gram.fst.txt" + G_filename_binary = "data/lm/G_3_gram.fst" + if Path(G_filename_binary).is_file(): + logging.info(f"Loading {G_filename_binary}") + G = kaldifst.StdVectorFst.read(G_filename_binary) + else: + logging.info(f"Loading {G_filename_txt}") + with open(G_filename_txt) as f: + G = kaldifst.compile(s=f.read(), acceptor=False) + logging.info(f"Saving G to {G_filename_binary}") + G.write(G_filename_binary) + + logging.info("Arc sort G") + kaldifst.arcsort(G, sort_type="ilabel") + + logging.info(f"G: #states {G.num_states}") + + logging.info("Compose L and G and connect LG") + LG = kaldifst.compose(L, G, connect=True) + logging.info(f"LG: #states {LG.num_states}") + + logging.info("Determinizestar LG") + kaldifst.determinize_star(LG) + logging.info(f"LG after determinize_star: #states {LG.num_states}") + + logging.info("Minimize encoded LG") + kaldifst.minimize_encoded(LG) + logging.info(f"LG after minimize_encoded: #states {LG.num_states}") + + logging.info("Converting LG to k2 format") + LG = k2.Fsa.from_openfst(LG.to_str(is_acceptor=False), acceptor=False) + logging.info(f"LG in k2: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}") + + lexicon = Lexicon(lang_dir) + + first_token_disambig_id = lexicon.token_table["#0"] + first_word_disambig_id = lexicon.word_table["#0"] + logging.info(f"token id for #0: {first_token_disambig_id}") + logging.info(f"word id for #0: {first_word_disambig_id}") + + max_token_id = max(lexicon.tokens) + modified = False + logging.info( + f"Building ctc_topo. 
modified: {modified}, max_token_id: {max_token_id}" + ) + + H = k2.ctc_topo(max_token_id, modified=modified) + logging.info(f"H: #states: {H.shape[0]}, #arcs: {H.num_arcs}") + + logging.info("Removing disambiguation symbols on LG") + LG.labels[LG.labels >= first_token_disambig_id] = 0 + LG.aux_labels[LG.aux_labels >= first_word_disambig_id] = 0 + + # See https://github.com/k2-fsa/k2/issues/874 + # for why we need to set LG.properties to None + LG.__dict__["_properties"] = None + + logging.info("Removing epsilons from LG") + LG = k2.remove_epsilon(LG) + logging.info( + f"LG after k2.remove_epsilon: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}" + ) + + logging.info("Connecting LG after removing epsilons") + LG = k2.connect(LG) + LG.aux_labels = LG.aux_labels.remove_values_eq(0) + logging.info(f"LG after k2.connect: #states: {LG.shape[0]}, #arcs: {LG.num_arcs}") + + logging.info("Arc sorting LG") + LG = k2.arc_sort(LG) + + logging.info("Composing H and LG") + + HLG = k2.compose(H, LG, inner_labels="tokens") + logging.info( + f"HLG after k2.compose: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}" + ) + + logging.info("Connecting HLG") + HLG = k2.connect(HLG) + logging.info( + f"HLG after k2.connect: #states: {HLG.shape[0]}, #arcs: {HLG.num_arcs}" + ) + + logging.info("Arc sorting LG") + HLG = k2.arc_sort(HLG) + + return HLG + + +def main(): + args = get_args() + lang_dir = Path(args.lang_dir) + + filename = lang_dir / "HLG_fst.pt" + + if filename.is_file(): + logging.info(f"{filename} already exists - skipping") + return + + HLG = compile_HLG(lang_dir) + logging.info(f"Saving HLG to {filename}") + torch.save(HLG.as_dict(), filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 542bbcdd8..11c8e1066 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -44,9 +44,9 @@ dl_dir=$PWD/download # It will generate data/lang_bpe_xxx, # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( - 5000 - 2000 - 1000 + # 5000 + # 2000 + # 1000 500 ) @@ -168,6 +168,22 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then if [ ! -f $lang_dir/L_disambig.pt ]; then ./local/prepare_lang.py --lang-dir $lang_dir fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! -f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/disambig_L.fst + fi fi @@ -208,6 +224,22 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then --lexicon $lang_dir/lexicon.txt \ --bpe-model $lang_dir/bpe.model fi + + if [ ! -f $lang_dir/L.fst ]; then + log "Converting L.pt to L.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L.pt \ + $lang_dir/L.fst + fi + + if [ ! 
-f $lang_dir/L_disambig.fst ]; then + log "Converting L_disambig.pt to L_disambig.fst" + ./shared/convert-k2-to-openfst.py \ + --olabels aux_labels \ + $lang_dir/L_disambig.pt \ + $lang_dir/L_disambig.fst + fi done fi @@ -270,10 +302,13 @@ fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then log "Stage 9: Compile HLG" ./local/compile_hlg.py --lang-dir data/lang_phone + ./local/compile_hlg_using_openfst.py --lang-dir data/lang_phone for vocab_size in ${vocab_sizes[@]}; do lang_dir=data/lang_bpe_${vocab_size} ./local/compile_hlg.py --lang-dir $lang_dir + + ./local/compile_hlg_using_openfst.py --lang-dir $lang_dir done fi diff --git a/icefall/shared/convert-k2-to-openfst.py b/icefall/shared/convert-k2-to-openfst.py new file mode 100755 index 000000000..29a2cd7f7 --- /dev/null +++ b/icefall/shared/convert-k2-to-openfst.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes as input an FST in k2 format and convert it +to an FST in OpenFST format. + +The generated FST is saved into a binary file and its type is +StdVectorFst. + +Usage examples: +(1) Convert an acceptor + + ./convert-k2-to-openfst.py in.pt binary.fst + +(2) Convert a transducer + + ./convert-k2-to-openfst.py --olabels aux_labels in.pt binary.fst +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import kaldifst.utils +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--olabels", + type=str, + default=None, + help="""If not empty, the input FST is assumed to be a transducer + and we use its attribute specified by "olabels" as the output labels. 
+ """, + ) + parser.add_argument( + "input_filename", + type=str, + help="Path to the input FST in k2 format", + ) + + parser.add_argument( + "output_filename", + type=str, + help="Path to the output FST in OpenFst format", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + logging.info(f"{vars(args)}") + + input_filename = args.input_filename + output_filename = args.output_filename + olabels = args.olabels + + if Path(output_filename).is_file(): + logging.info(f"{output_filename} already exists - skipping") + return + + assert Path(input_filename).is_file(), f"{input_filename} does not exist" + logging.info(f"Loading {input_filename}") + k2_fst = k2.Fsa.from_dict(torch.load(input_filename)) + if olabels: + assert hasattr(k2_fst, olabels), f"No such attribute: {olabels}" + + p = Path(output_filename).parent + if not p.is_dir(): + logging.info(f"Creating {p}") + p.mkdir(parents=True) + + logging.info("Converting (May take some time if the input FST is large)") + fst = kaldifst.utils.k2_to_openfst(k2_fst, olabels=olabels) + logging.info(f"Saving to {output_filename}") + fst.write(output_filename) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/requirements.txt b/requirements.txt index 5e32af853..a07f6b7c7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +kaldifst kaldilm kaldialign sentencepiece>=0.1.96 From a0cf85343dad31a678ddaac7652f0bb2bbb4cac2 Mon Sep 17 00:00:00 2001 From: Yifan Yang <64255737+yfyeung@users.noreply.github.com> Date: Fri, 9 Dec 2022 19:23:11 +0800 Subject: [PATCH 27/27] fix for memory usage in pruned_transducer_stateless7/scaling.py (#752) Co-authored-by: yifanyang --- egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 6f63e0629..042c9c3e4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -562,7 +562,7 @@ class ActivationBalancer(torch.nn.Module): sign_factor = None scale_factor = _compute_scale_factor( - x, + x.detach(), self.channel_dim, min_abs=self.min_abs, max_abs=self.max_abs,