From a42d96dfe047e28a9cd5463b33246f4456e92cdb Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Mon, 20 Jun 2022 13:40:01 +0800 Subject: [PATCH 1/9] Fix warmup (#435) * fix warmup when scan_pessimistic_batches_for_oom * delete comments --- .../ASR/conv_emformer_transducer_stateless/train.py | 7 +++---- .../ASR/pruned_transducer_stateless2/train.py | 7 +++---- .../ASR/pruned_transducer_stateless3/train.py | 7 +++---- .../ASR/pruned_transducer_stateless4/train.py | 7 +++---- .../ASR/pruned_transducer_stateless5/train.py | 7 +++---- .../ASR/pruned_transducer_stateless6/train.py | 11 +++++------ 6 files changed, 20 insertions(+), 26 deletions(-) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index 106f3e511..acaf1397f 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -1018,6 +1018,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -1078,6 +1079,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float, ): from lhotse.dataset import find_pessimistic_batches @@ -1088,9 +1090,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -1098,7 +1097,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 36ee7ca74..55f32e119 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -883,6 +883,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 0 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -973,6 +974,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float, ): from lhotse.dataset import find_pessimistic_batches @@ -983,9 +985,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -993,7 +992,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 92eae78d1..be9fa8f8b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -1001,6 +1001,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 0 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -1061,6 +1062,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float, ): from lhotse.dataset import find_pessimistic_batches @@ -1071,9 +1073,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -1081,7 +1080,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 48c0e683d..0fece2464 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -932,6 +932,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -992,6 +993,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float, ): from lhotse.dataset import find_pessimistic_batches @@ -1002,9 +1004,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -1012,7 +1011,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index e77eb19ff..eaf893997 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -980,6 +980,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -1072,6 +1073,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float, ): from lhotse.dataset import find_pessimistic_batches @@ -1082,9 +1084,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -1092,7 +1091,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 315c01c8e..9e9fc1440 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -74,9 +74,9 @@ from conformer import Conformer from decoder import Decoder from joiner import Joiner from lhotse.cut import Cut, MonoCut +from lhotse.dataset.collation import collate_custom_field from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed -from lhotse.dataset.collation import collate_custom_field from model import Transducer from optim import Eden, Eve from torch import Tensor @@ -376,7 +376,7 @@ def get_params() -> AttributeDict: "distillation_layer": 5, # 0-based index # Since output rate of hubert is 50, while that of encoder is 8, # two successive codebook_index are concatenated together. - # Detailed in function Transducer::concat_sucessive_codebook_indexes. + # Detailed in function Transducer::concat_sucessive_codebook_indexes "num_codebooks": 16, # used to construct distillation loss } ) @@ -988,6 +988,7 @@ def run(rank, world_size, args): optimizer=optimizer, sp=sp, params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, ) scaler = GradScaler(enabled=params.use_fp16) @@ -1048,6 +1049,7 @@ def scan_pessimistic_batches_for_oom( optimizer: torch.optim.Optimizer, sp: spm.SentencePieceProcessor, params: AttributeDict, + warmup: float, ): from lhotse.dataset import find_pessimistic_batches @@ -1058,9 +1060,6 @@ def scan_pessimistic_batches_for_oom( for criterion, cuts in batches.items(): batch = train_dl.dataset[cuts] try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. with torch.cuda.amp.autocast(enabled=params.use_fp16): loss, _ = compute_loss( params=params, @@ -1068,7 +1067,7 @@ def scan_pessimistic_batches_for_oom( sp=sp, batch=batch, is_training=True, - warmup=0.0, + warmup=warmup, ) loss.backward() optimizer.step() From 998091ef52ceb46e4efe57dfc28ebce3c4edc10d Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Mon, 20 Jun 2022 14:57:08 +0800 Subject: [PATCH 2/9] do some changes for export.py (#437) --- .../ASR/pruned_transducer_stateless2/export.py | 7 +++++-- egs/aishell/ASR/transducer_stateless/export.py | 7 +++++-- .../ASR/transducer_stateless_modified-2/export.py | 7 +++++-- egs/aishell/ASR/transducer_stateless_modified/export.py | 7 +++++-- egs/aishell4/ASR/pruned_transducer_stateless5/export.py | 7 +++++-- .../ASR/pruned_transducer_stateless2/export.py | 7 +++++-- .../ASR/pruned_transducer_stateless2/export.py | 7 +++++-- .../ASR/pruned_transducer_stateless2/export.py | 7 +++++-- egs/tedlium3/ASR/pruned_transducer_stateless/export.py | 7 +++++-- egs/tedlium3/ASR/transducer_stateless/export.py | 9 ++++++--- .../ASR/pruned_transducer_stateless2/export.py | 7 +++++-- 11 files changed, 56 insertions(+), 23 deletions(-) diff --git a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py index 43033e517..00b54c39f 100644 --- a/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py +++ b/egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/export.py @@ -114,8 +114,6 @@ def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - assert args.jit is False, "Support torchscript will be added later" - params = get_params() params.update(vars(args)) @@ -155,6 +153,11 @@ def main(): model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/aishell/ASR/transducer_stateless/export.py b/egs/aishell/ASR/transducer_stateless/export.py index 591b333e0..4c6519b96 100755 --- a/egs/aishell/ASR/transducer_stateless/export.py +++ b/egs/aishell/ASR/transducer_stateless/export.py @@ -184,8 +184,6 @@ def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - assert args.jit is False, "Support torchscript will be added later" - params = get_params() params.update(vars(args)) @@ -225,6 +223,11 @@ def main(): model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/export.py b/egs/aishell/ASR/transducer_stateless_modified-2/export.py index d009de603..3bd2ceb11 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/export.py @@ -182,8 +182,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: def main(): args = get_parser().parse_args() - assert args.jit is False, "torchscript support will be added later" - params = get_params() params.update(vars(args)) @@ -223,6 +221,11 @@ def main(): model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/aishell/ASR/transducer_stateless_modified/export.py b/egs/aishell/ASR/transducer_stateless_modified/export.py index 9a20fab6f..11335a834 100755 --- a/egs/aishell/ASR/transducer_stateless_modified/export.py +++ b/egs/aishell/ASR/transducer_stateless_modified/export.py @@ -182,8 +182,6 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: def main(): args = get_parser().parse_args() - assert args.jit is False, "torchscript support will be added later" - params = get_params() params.update(vars(args)) @@ -223,6 +221,11 @@ def main(): model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py index f487a8ba5..f42a85373 100755 --- a/egs/aishell4/ASR/pruned_transducer_stateless5/export.py +++ b/egs/aishell4/ASR/pruned_transducer_stateless5/export.py @@ -149,8 +149,6 @@ def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - assert args.jit is False, "Support torchscript will be added later" - params = get_params() params.update(vars(args)) @@ -252,6 +250,11 @@ def main(): model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py index 0a69e0a57..8beec1b8a 100644 --- a/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py +++ b/egs/alimeeting/ASR/pruned_transducer_stateless2/export.py @@ -114,8 +114,6 @@ def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - assert args.jit is False, "Support torchscript will be added later" - params = get_params() params.update(vars(args)) @@ -155,6 +153,11 @@ def main(): model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index 6b3a7a9ff..cff9c7377 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -131,8 +131,6 @@ def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - assert args.jit is False, "Support torchscript will be added later" - params = get_params() params.update(vars(args)) @@ -191,6 +189,11 @@ def main(): model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py index 6119ecf2c..77faa3c0e 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py @@ -130,8 +130,6 @@ def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - assert args.jit is False, "Support torchscript will be added later" - params = get_params() params.update(vars(args)) @@ -178,6 +176,11 @@ def main(): model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py index 1e6edbb99..a1c3bcea3 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py @@ -117,8 +117,6 @@ def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - assert args.jit is False, "Support torchscript will be added later" - params = get_params() params.update(vars(args)) @@ -161,6 +159,11 @@ def main(): model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py index f2bfa2ec9..c32b1d002 100644 --- a/egs/tedlium3/ASR/transducer_stateless/export.py +++ b/egs/tedlium3/ASR/transducer_stateless/export.py @@ -1,7 +1,7 @@ #!/usr/bin/env python3 # # Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang -# Mingshuang Luo) +# Mingshuang Luo) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -185,8 +185,6 @@ def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - assert args.jit is False, "Support torchscript will be added later" - params = get_params() params.update(vars(args)) @@ -229,6 +227,11 @@ def main(): model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py index 8c4f92c81..345792a3c 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py @@ -114,8 +114,6 @@ def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - assert args.jit is False, "Support torchscript will be added later" - params = get_params() params.update(vars(args)) @@ -155,6 +153,11 @@ def main(): model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" From 91b2765cfd39fea1783cc07bb391e27bc2d5224e Mon Sep 17 00:00:00 2001 From: 2xwwx2 Date: Mon, 20 Jun 2022 16:41:04 +0800 Subject: [PATCH 3/9] Fixs spelling mistake (#438) --- egs/aishell/ASR/prepare.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh index 26324b0af..da0a1470e 100755 --- a/egs/aishell/ASR/prepare.sh +++ b/egs/aishell/ASR/prepare.sh @@ -18,7 +18,7 @@ stop_stage=10 # This directory contains the language model downloaded from # https://huggingface.co/pkufool/aishell_lm # -# - 3-gram.unpruned.apra +# - 3-gram.unpruned.arpa # # - $dl_dir/musan # This directory contains the following directories downloaded from From d3daeaf5cd6712fca63b98db960ee71ff211a59e Mon Sep 17 00:00:00 2001 From: Zengwei Yao Date: Tue, 21 Jun 2022 19:16:59 +0800 Subject: [PATCH 4/9] Upload extracted codebook indexes (#429) * save only vq-related info to manifest * support to join manifest files * support using extracted codebook indexes * fix doc * minor fix * add enable-distillation argument option, fix monir typos * fix style * fix typo --- .../ASR/distillation_with_hubert.sh | 119 +++++++++++++----- .../pruned_transducer_stateless6/decode.py | 9 +- .../extract_codebook_index.py | 19 ++- .../ASR/pruned_transducer_stateless6/train.py | 22 ++-- .../pruned_transducer_stateless6/vq_utils.py | 80 +++++++++--- 5 files changed, 189 insertions(+), 60 deletions(-) mode change 100644 => 100755 egs/librispeech/ASR/distillation_with_hubert.sh diff --git a/egs/librispeech/ASR/distillation_with_hubert.sh b/egs/librispeech/ASR/distillation_with_hubert.sh old mode 100644 new mode 100755 index e18ba8f55..3d4c4856a --- a/egs/librispeech/ASR/distillation_with_hubert.sh +++ b/egs/librispeech/ASR/distillation_with_hubert.sh @@ -1,3 +1,5 @@ +#!/usr/bin/env bash +# # A short introduction about distillation framework. # # A typical traditional distillation method is @@ -14,15 +16,15 @@ # teacher embeddings. # 3. a middle layer 6(1-based) out of total 6 layers is used to extract # student embeddings. - -# This is an example to do distillation with librispeech clean-100 subset. -# run with command: -# bash distillation_with_hubert.sh [0|1|2|3|4] # -# For example command -# bash distillation_with_hubert.sh 0 -# will download hubert model. -stage=$1 +# To directly download the extracted codebook indexes for model distillation, you can +# set stage=2, stop_stage=4, use_extracted_codebook=True +# +# To start from scratch, you can +# set stage=0, stop_stage=4, use_extracted_codebook=False + +stage=0 +stop_stage=4 # Set the GPUs available. # This script requires at least one GPU. @@ -33,10 +35,35 @@ stage=$1 # export CUDA_VISIBLE_DEVICES="0" # # Suppose GPU 2,3,4,5 are available. -export CUDA_VISIBLE_DEVICES="2,3,4,5" +export CUDA_VISIBLE_DEVICES="0,1,2,3" +exp_dir=./pruned_transducer_stateless6/exp +mkdir -p $exp_dir -if [ $stage -eq 0 ]; then +# full_libri can be "True" or "False" +# "True" -> use full librispeech dataset for distillation +# "False" -> use train-clean-100 subset for distillation +full_libri=False + +# use_extracted_codebook can be "True" or "False" +# "True" -> stage 0 and stage 1 would be skipped, +# and directly download the extracted codebook indexes for distillation +# "False" -> start from scratch +use_extracted_codebook=False + +# teacher_model_id can be one of +# "hubert_xtralarge_ll60k_finetune_ls960" -> fine-tuned model, it is the one we currently use. +# "hubert_xtralarge_ll60k" -> pretrained model without fintuing +teacher_model_id=hubert_xtralarge_ll60k_finetune_ls960 + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ] && [ ! "$use_extracted_codebook" == "True" ]; then + log "Stage 0: Download HuBERT model" # Preparation stage. # Install fairseq according to: @@ -45,7 +72,7 @@ if [ $stage -eq 0 ]; then # commit 806855bf660ea748ed7ffb42fe8dcc881ca3aca0 is used. has_fairseq=$(python3 -c "import importlib; print(importlib.util.find_spec('fairseq') is not None)") if [ $has_fairseq == 'False' ]; then - echo "Please install fairseq before running following stages" + log "Please install fairseq before running following stages" exit 1 fi @@ -56,42 +83,41 @@ if [ $stage -eq 0 ]; then has_quantization=$(python3 -c "import importlib; print(importlib.util.find_spec('quantization') is not None)") if [ $has_quantization == 'False' ]; then - echo "Please install quantization before running following stages" + log "Please install quantization before running following stages" exit 1 fi - echo "Download hubert model." + log "Download HuBERT model." # Parameters about model. - exp_dir=./pruned_transducer_stateless6/exp/ - model_id=hubert_xtralarge_ll60k_finetune_ls960 hubert_model_dir=${exp_dir}/hubert_models - hubert_model=${hubert_model_dir}/${model_id}.pt + hubert_model=${hubert_model_dir}/${teacher_model_id}.pt mkdir -p ${hubert_model_dir} # For more models refer to: https://github.com/pytorch/fairseq/tree/main/examples/hubert if [ -f ${hubert_model} ]; then - echo "hubert model alread exists." + log "HuBERT model alread exists." else - wget -c https://dl.fbaipublicfiles.com/hubert/${model_id} -P ${hubert_model} + wget -c https://dl.fbaipublicfiles.com/hubert/${teacher_model_id}.pt -P ${hubert_model_dir} wget -c wget https://dl.fbaipublicfiles.com/fairseq/wav2vec/dict.ltr.txt -P ${hubert_model_dir} fi fi if [ ! -d ./data/fbank ]; then - echo "This script assumes ./data/fbank is already generated by prepare.sh" + log "This script assumes ./data/fbank is already generated by prepare.sh" exit 1 fi -if [ $stage -eq 1 ]; then +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ] && [ ! "$use_extracted_codebook" == "True" ]; then + log "Stage 1: Verify that the downloaded HuBERT model is correct." # This stage is not directly used by codebook indexes extraction. # It is a method to "prove" that the downloaed hubert model # is inferenced in an correct way if WERs look like normal. # Expect WERs: # [test-clean-ctc_greedy_search] %WER 2.04% [1075 / 52576, 92 ins, 104 del, 879 sub ] # [test-other-ctc_greedy_search] %WER 3.71% [1942 / 52343, 152 ins, 126 del, 1664 sub ] - ./pruned_transducer_stateless6/hubert_decode.py + ./pruned_transducer_stateless6/hubert_decode.py --exp-dir $exp_dir fi -if [ $stage -eq 2 ]; then +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then # Analysis of disk usage: # With num_codebooks==8, each teacher embedding is quantized into # a sequence of eight 8-bit integers, i.e. only eight bytes are needed. @@ -113,25 +139,61 @@ if [ $stage -eq 2 ]; then # During quantizer's training data(teacher embedding) and it's training, # only the first ONE GPU is used. # During codebook indexes extraction, ALL GPUs set by CUDA_VISIBLE_DEVICES are used. + + if [ "$use_extracted_codebook" == "True" ]; then + if [ ! "$teacher_model_id" == "hubert_xtralarge_ll60k_finetune_ls960" ]; then + log "Currently we only uploaded codebook indexes from teacher model hubert_xtralarge_ll60k_finetune_ls960" + exit 1 + fi + mkdir -p $exp_dir/vq + codebook_dir=$exp_dir/vq/$teacher_model_id + mkdir -p codebook_dir + codebook_download_dir=$exp_dir/download_codebook + if [ -d $codebook_download_dir ]; then + log "$codebook_download_dir exists, you should remove it first." + exit 1 + fi + log "Downloading extracted codebook indexes to $codebook_download_dir" + # Make sure you have git-lfs installed (https://git-lfs.github.com) + git lfs install + git clone https://huggingface.co/Zengwei/pruned_transducer_stateless6_hubert_xtralarge_ll60k_finetune_ls960 $codebook_download_dir + + mkdir -p data/vq_fbank + mv $codebook_download_dir/*.jsonl.gz data/vq_fbank/ + mkdir -p $codebook_dir/splits4 + mv $codebook_download_dir/*.h5 $codebook_dir/splits4/ + log "Remove $codebook_download_dir" + rm -rf $codebook_download_dir + fi + ./pruned_transducer_stateless6/extract_codebook_index.py \ - --full-libri False + --full-libri $full_libri \ + --exp-dir $exp_dir \ + --embedding-layer 36 \ + --num-utts 1000 \ + --num-codebooks 8 \ + --max-duration 100 \ + --teacher-model-id $teacher_model_id \ + --use-extracted-codebook $use_extracted_codebook fi -if [ $stage -eq 3 ]; then +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then # Example training script. # Note: it's better to set spec-aug-time-warpi-factor=-1 WORLD_SIZE=$(echo ${CUDA_VISIBLE_DEVICES} | awk '{n=split($1, _, ","); print n}') ./pruned_transducer_stateless6/train.py \ --manifest-dir ./data/vq_fbank \ --master-port 12359 \ - --full-libri False \ + --full-libri $full_libri \ --spec-aug-time-warp-factor -1 \ --max-duration 300 \ --world-size ${WORLD_SIZE} \ - --num-epochs 20 + --num-epochs 20 \ + --exp-dir $exp_dir \ + --enable-distillation True fi -if [ $stage -eq 4 ]; then +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then # Results should be similar to: # errs-test-clean-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 5.67 # errs-test-other-beam_size_4-epoch-20-avg-10-beam-4.txt:%WER = 15.60 @@ -140,5 +202,6 @@ if [ $stage -eq 4 ]; then --epoch 20 \ --avg 10 \ --max-duration 200 \ - --exp-dir ./pruned_transducer_stateless6/exp + --exp-dir $exp_dir \ + --enable-distillation True fi diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 4739a6526..701cad73c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -128,7 +128,7 @@ def get_parser(): parser.add_argument( "--use-averaged-model", type=str2bool, - default=False, + default=True, help="Whether to load averaged model. Currently it only supports " "using --epoch. If True, it would decode with the averaged model " "over the epoch range from `epoch-avg` (excluded) to `epoch`." @@ -143,6 +143,13 @@ def get_parser(): help="The experiment dir", ) + parser.add_argument( + "--enable-distillation", + type=str2bool, + default=True, + help="Whether to eanble distillation.", + ) + parser.add_argument( "--bpe-model", type=str, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py index c5c172ff2..21409287c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py @@ -24,7 +24,7 @@ import torch from vq_utils import CodebookIndexExtractor from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned -from icefall.utils import AttributeDict +from icefall.utils import AttributeDict, str2bool def get_parser(): @@ -38,6 +38,13 @@ def get_parser(): help="The experiment dir", ) + parser.add_argument( + "--use-extracted-codebook", + type=str2bool, + default=False, + help="Whether to use the extracted codebook indexes.", + ) + return parser @@ -71,9 +78,13 @@ def main(): params.world_size = world_size extractor = CodebookIndexExtractor(params=params) - extractor.extract_and_save_embedding() - extractor.train_quantizer() - extractor.extract_codebook_indexes() + if not params.use_extracted_codebook: + extractor.extract_and_save_embedding() + extractor.train_quantizer() + extractor.extract_codebook_indexes() + + extractor.reuse_manifests() + extractor.join_manifests() if __name__ == "__main__": diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index 9e9fc1440..b904e1e59 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -41,7 +41,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3" --full-libri 1 \ --max-duration 550 -# For distiallation with codebook_indexes: +# For distillation with codebook_indexes: ./pruned_transducer_stateless6/train.py \ --manifest-dir ./data/vq_fbank \ @@ -300,6 +300,13 @@ def get_parser(): help="Whether to use half precision training.", ) + parser.add_argument( + "--enable-distillation", + type=str2bool, + default=True, + help="Whether to eanble distillation.", + ) + return parser @@ -372,7 +379,6 @@ def get_params() -> AttributeDict: "model_warm_step": 3000, # arg given to model, not for lrate "env_info": get_env_info(), # parameters for distillation with codebook indexes. - "enable_distiallation": True, "distillation_layer": 5, # 0-based index # Since output rate of hubert is 50, while that of encoder is 8, # two successive codebook_index are concatenated together. @@ -394,7 +400,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module: dim_feedforward=params.dim_feedforward, num_encoder_layers=params.num_encoder_layers, middle_output_layer=params.distillation_layer - if params.enable_distiallation + if params.enable_distillation else None, ) return encoder @@ -433,9 +439,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module: decoder_dim=params.decoder_dim, joiner_dim=params.joiner_dim, vocab_size=params.vocab_size, - num_codebooks=params.num_codebooks - if params.enable_distiallation - else 0, + num_codebooks=params.num_codebooks if params.enable_distillation else 0, ) return model @@ -615,7 +619,7 @@ def compute_loss( y = k2.RaggedTensor(y).to(device) info = MetricsTracker() - if is_training and params.enable_distiallation: + if is_training and params.enable_distillation: codebook_indexes, _ = extract_codebook_indexes(batch) codebook_indexes = codebook_indexes.to(device) else: @@ -645,7 +649,7 @@ def compute_loss( params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss ) - if is_training and params.enable_distiallation: + if is_training and params.enable_distillation: assert codebook_loss is not None loss += params.codebook_loss_scale * codebook_loss @@ -661,7 +665,7 @@ def compute_loss( info["loss"] = loss.detach().cpu().item() info["simple_loss"] = simple_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item() - if is_training and params.enable_distiallation: + if is_training and params.enable_distillation: info["codebook_loss"] = codebook_loss.detach().cpu().item() return loss, info diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index c4935f921..e3dcd039b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -37,6 +37,7 @@ from icefall.utils import ( setup_logger, ) from lhotse import CutSet, load_manifest +from lhotse.cut import MonoCut from lhotse.features.io import NumpyHdf5Writer @@ -62,16 +63,15 @@ class CodebookIndexExtractor: setup_logger(f"{self.vq_dir}/log-vq_extraction") def init_dirs(self): - # vq_dir is the root dir for quantizer: - # training data/ quantizer / extracted codebook indexes + # vq_dir is the root dir for quantization, containing: + # training data, trained quantizer, and extracted codebook indexes self.vq_dir = ( self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" ) self.vq_dir.mkdir(parents=True, exist_ok=True) - # manifest_dir for : - # splited original manifests, - # extracted codebook indexes and their related manifests + # manifest_dir contains: + # splited original manifests, extracted codebook indexes with related manifests # noqa self.manifest_dir = self.vq_dir / f"splits{self.params.world_size}" self.manifest_dir.mkdir(parents=True, exist_ok=True) @@ -135,6 +135,7 @@ class CodebookIndexExtractor: logging.warn(warn_message) return + logging.info("Start to extract embeddings for training the quantizer.") total_cuts = 0 with NumpyHdf5Writer(self.embedding_file_path) as writer: for batch_idx, batch in enumerate(self.quantizer_train_dl): @@ -187,14 +188,15 @@ class CodebookIndexExtractor: return assert self.embedding_file_path.exists() + logging.info("Start to train quantizer.") trainer = quantization.QuantizerTrainer( dim=self.params.embedding_dim, bytes_per_frame=self.params.num_codebooks, device=self.params.device, ) train, valid = quantization.read_hdf5_data(self.embedding_file_path) - B = 512 # Minibatch size, this is very arbitrary, it's close to what we used - # when we tuned this method. + B = 512 # Minibatch size, this is very arbitrary, + # it's close to what we used when we tuned this method. def minibatch_generator(data: torch.Tensor, repeat: bool): assert 3 * B < data.shape[0] @@ -222,18 +224,50 @@ class CodebookIndexExtractor: """ for subset in self.params.subsets: logging.info(f"About to split {subset}.") - ori_manifest = f"./data/fbank/cuts_train-{subset}.json.gz" + ori_manifest = ( + f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ) split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" os.system(f"{split_cmd}") + def join_manifests(self): + """ + Join the vq manifest to the original manifest according to cut id. + """ + logging.info("Start to join manifest files.") + for subset in self.params.subsets: + vq_manifest_path = ( + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + ) + ori_manifest_path = ( + self.ori_manifest_dir + / f"librispeech_cuts_train-{subset}.jsonl.gz" + ) + dst_vq_manifest_path = ( + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}.jsonl.gz" + ) + cuts_vq = load_manifest(vq_manifest_path) + cuts_ori = load_manifest(ori_manifest_path) + cuts_vq = cuts_vq.sort_like(cuts_ori) + for cut_idx, (cut_vq, cut_ori) in enumerate(zip(cuts_vq, cuts_ori)): + assert cut_vq.id == cut_ori.id + cut_ori.codebook_indexes = cut_vq.codebook_indexes + + CutSet.from_cuts(cuts_ori).to_jsonl(dst_vq_manifest_path) + logging.info(f"Processed {subset}.") + logging.info(f"Saved to {dst_vq_manifest_path}.") + def merge_vq_manifests(self): """ Merge generated vq included manfiests and storage to self.dst_manifest_dir. """ for subset in self.params.subsets: - vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-cuts_train-{subset}*.json.gz" + vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" dst_vq_manifest = ( - self.dst_manifest_dir / f"cuts_train-{subset}.json.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -273,7 +307,6 @@ class CodebookIndexExtractor: os.symlink(ori_manifest_path, dst_manifest_path) def create_vq_fbank(self): - self.reuse_manifests() self.merge_vq_manifests() @cached_property @@ -294,11 +327,13 @@ class CodebookIndexExtractor: def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = f"./data/fbank/cuts_train-{subset}.json.gz" + ori_manifest_path = ( + f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ) else: ori_manifest_path = ( self.manifest_dir - / f"cuts_train-{subset}.{self.params.manifest_index}.json.gz" + / f"librispeech_cuts_train-{subset}.{self.params.manifest_index}.jsonl.gz" # noqa ) cuts = load_manifest(ori_manifest_path) @@ -311,6 +346,7 @@ class CodebookIndexExtractor: torch.cuda.empty_cache() def extract_codebook_indexes(self): + logging.info("Start to extract codebook indexes.") if self.params.world_size == 1: self.extract_codebook_indexes_imp() else: @@ -333,7 +369,7 @@ class CodebookIndexExtractor: def extract_codebook_indexes_imp(self): for subset in self.params.subsets: num_cuts = 0 - cuts = [] + new_cuts = [] if self.params.world_size == 1: manifest_file_id = f"{subset}" else: @@ -356,15 +392,23 @@ class CodebookIndexExtractor: assert len(cut_list) == codebook_indexes.shape[0] assert all(c.start == 0 for c in supervisions["cut"]) + new_cut_list = [] for idx, cut in enumerate(cut_list): - cut.codebook_indexes = writer.store_array( + new_cut = MonoCut( + id=cut.id, + start=cut.start, + duration=cut.duration, + channel=cut.channel, + ) + new_cut.codebook_indexes = writer.store_array( key=cut.id, value=codebook_indexes[idx][: num_frames[idx]], frame_shift=0.02, temporal_dim=0, start=0, ) - cuts += cut_list + new_cut_list.append(new_cut) + new_cuts += new_cut_list num_cuts += len(cut_list) message = f"Processed {num_cuts} cuts from {subset}" if self.params.world_size > 1: @@ -373,9 +417,9 @@ class CodebookIndexExtractor: json_file_path = ( self.manifest_dir - / f"with_codebook_indexes-cuts_train-{manifest_file_id}.json.gz" + / f"with_codebook_indexes-librispeech-cuts_train-{manifest_file_id}.jsonl.gz" # noqa ) - CutSet.from_cuts(cuts).to_json(json_file_path) + CutSet.from_cuts(new_cuts).to_jsonl(json_file_path) @torch.no_grad() From 7100c33820c8c478e07d3435e25e4f1543b6eec7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Tue, 21 Jun 2022 21:17:22 +0800 Subject: [PATCH 5/9] Add pruned RNN-T for aishell. (#436) * Add pruned RNN-T for aishell. * support torch script. * Update CI. * Minor fixes. * Add links to sherpa. --- ...pruned-transducer-stateless3-2022-06-20.sh | 86 ++ .github/workflows/run-aishell-2022-06-20.yml | 119 ++ egs/aishell/ASR/README.md | 3 + egs/aishell/ASR/RESULTS.md | 85 +- .../aidatatang_200zh.py | 1 + .../pruned_transducer_stateless3/aishell.py | 1 + .../asr_datamodule.py | 1 + .../beam_search.py | 1 + .../pruned_transducer_stateless3/conformer.py | 1 + .../pruned_transducer_stateless3/decode.py | 637 +++++++++ .../pruned_transducer_stateless3/decoder.py | 1 + .../encoder_interface.py | 1 + .../exp-context-size-1 | 1 + .../pruned_transducer_stateless3/export.py | 277 ++++ .../pruned_transducer_stateless3/joiner.py | 1 + .../ASR/pruned_transducer_stateless3/model.py | 236 ++++ .../ASR/pruned_transducer_stateless3/optim.py | 1 + .../pretrained.py | 337 +++++ .../pruned_transducer_stateless3/scaling.py | 1 + .../ASR/pruned_transducer_stateless3/train.py | 1229 +++++++++++++++++ .../transducer_stateless_modified-2/train.py | 2 +- .../pruned_transducer_stateless2/decoder.py | 3 + .../pruned_transducer_stateless5/conformer.py | 39 +- .../pruned_transducer_stateless5/export.py | 9 +- 24 files changed, 3055 insertions(+), 18 deletions(-) create mode 100755 .github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh create mode 100644 .github/workflows/run-aishell-2022-06-20.yml create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/aidatatang_200zh.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/aishell.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/asr_datamodule.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/beam_search.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/conformer.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless3/decode.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/decoder.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/encoder_interface.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 create mode 100755 egs/aishell/ASR/pruned_transducer_stateless3/export.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/joiner.py create mode 100644 egs/aishell/ASR/pruned_transducer_stateless3/model.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/optim.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py create mode 120000 egs/aishell/ASR/pruned_transducer_stateless3/scaling.py create mode 100755 egs/aishell/ASR/pruned_transducer_stateless3/train.py diff --git a/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh new file mode 100755 index 000000000..cf35f711b --- /dev/null +++ b/.github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh @@ -0,0 +1,86 @@ +#!/usr/bin/env bash + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +cd egs/aishell/ASR + +git lfs install + +fbank_url=https://huggingface.co/csukuangfj/aishell-test-dev-manifests +log "Downloading pre-commputed fbank from $fbank_url" + +git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests +ln -s $PWD/aishell-test-dev-manifests/data . + +log "Downloading pre-trained model from $repo_url" +repo_url=https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 +git clone $repo_url +repo=$(basename $repo_url) + +log "Display test files" +tree $repo/ +soxi $repo/test_wavs/*.wav +ls -lh $repo/test_wavs/*.wav + +pushd $repo/exp +ln -s pretrained-epoch-29-avg-5-torch-1.10.pt pretrained.pt +popd + +for sym in 1 2 3; do + log "Greedy search with --max-sym-per-frame $sym" + + ./pruned_transducer_stateless3/pretrained.py \ + --method greedy_search \ + --max-sym-per-frame $sym \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $rep/test_wavs/BAC009S0764W0123.wav +done + +for method in modified_beam_search beam_search fast_beam_search; do + log "$method" + + ./pruned_transducer_stateless3/pretrained.py \ + --method $method \ + --beam-size 4 \ + --checkpoint $repo/exp/pretrained.pt \ + --lang-dir $repo/data/lang_char \ + $repo/test_wavs/BAC009S0764W0121.wav \ + $repo/test_wavs/BAC009S0764W0122.wav \ + $rep/test_wavs/BAC009S0764W0123.wav +done + +echo "GITHUB_EVENT_NAME: ${GITHUB_EVENT_NAME}" +echo "GITHUB_EVENT_LABEL_NAME: ${GITHUB_EVENT_LABEL_NAME}" +if [[ x"${GITHUB_EVENT_NAME}" == x"schedule" || x"${GITHUB_EVENT_LABEL_NAME}" == x"run-decode" ]]; then + mkdir -p pruned_transducer_stateless3/exp + ln -s $PWD/$repo/exp/pretrained.pt pruned_transducer_stateless3/exp/epoch-999.pt + ln -s $PWD/$repo/data/lang_char data/ + + ls -lh data + ls -lh pruned_transducer_stateless3/exp + + log "Decoding test and dev" + + # use a small value for decoding with CPU + max_duration=100 + + for method in greedy_search fast_beam_search modified_beam_search; do + log "Decoding with $method" + + ./pruned_transducer_stateless3/decode.py \ + --decoding-method $method \ + --epoch 999 \ + --avg 1 \ + --max-duration $max_duration \ + --exp-dir pruned_transducer_stateless3/exp + done + + rm pruned_transducer_stateless3/exp/*.pt +fi diff --git a/.github/workflows/run-aishell-2022-06-20.yml b/.github/workflows/run-aishell-2022-06-20.yml new file mode 100644 index 000000000..e684e598e --- /dev/null +++ b/.github/workflows/run-aishell-2022-06-20.yml @@ -0,0 +1,119 @@ +# Copyright 2022 Fangjun Kuang (csukuangfj@gmail.com) + +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +name: run-aishell-2022-06-20 +# pruned RNN-T + reworked model with random combiner +# https://huggingface.co/csukuangfj/icefall-aishell-pruned-transducer-stateless3-2022-06-20 + +on: + push: + branches: + - master + pull_request: + types: [labeled] + + schedule: + # minute (0-59) + # hour (0-23) + # day of the month (1-31) + # month (1-12) + # day of the week (0-6) + # nightly build at 15:50 UTC time every day + - cron: "50 15 * * *" + +jobs: + run_aishell_2022_06_20: + if: github.event.label.name == 'ready' || github.event.label.name == 'run-decode' || github.event_name == 'push' || github.event_name == 'schedule' + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-18.04] + python-version: [3.7, 3.8, 3.9] + + fail-fast: false + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 0 + + - name: Setup Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + cache-dependency-path: '**/requirements-ci.txt' + + - name: Install Python dependencies + run: | + grep -v '^#' ./requirements-ci.txt | xargs -n 1 -L 1 pip install + pip uninstall -y protobuf + pip install --no-binary protobuf protobuf + + - name: Cache kaldifeat + id: my-cache + uses: actions/cache@v2 + with: + path: | + ~/tmp/kaldifeat + key: cache-tmp-${{ matrix.python-version }} + + - name: Install kaldifeat + if: steps.my-cache.outputs.cache-hit != 'true' + shell: bash + run: | + .github/scripts/install-kaldifeat.sh + + - name: Inference with pre-trained model + shell: bash + env: + GITHUB_EVENT_NAME: ${{ github.event_name }} + GITHUB_EVENT_LABEL_NAME: ${{ github.event.label.name }} + run: | + sudo apt-get -qq install git-lfs tree sox + export PYTHONPATH=$PWD:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/kaldifeat/python:$PYTHONPATH + export PYTHONPATH=~/tmp/kaldifeat/build/lib:$PYTHONPATH + + .github/scripts/run-aishell-pruned-transducer-stateless3-2022-06-20.sh + + - name: Display decoding results for aishell pruned_transducer_stateless3 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + shell: bash + run: | + cd egs/aishell/ASR/ + tree ./pruned_transducer_stateless3/exp + + cd pruned_transducer_stateless3 + echo "results for pruned_transducer_stateless3" + echo "===greedy search===" + find exp/greedy_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2 + find exp/greedy_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2 + + echo "===fast_beam_search===" + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2 + find exp/fast_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2 + + echo "===modified beam search===" + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for test" {} + | sort -n -k2 + find exp/modified_beam_search -name "log-*" -exec grep -n --color "best for dev" {} + | sort -n -k2 + + - name: Upload decoding results for aishell pruned_transducer_stateless3 + uses: actions/upload-artifact@v2 + if: github.event_name == 'schedule' || github.event.label.name == 'run-decode' + with: + name: aishell-torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu-pruned_transducer_stateless3-2022-06-20 + path: egs/aishell/ASR/pruned_transducer_stateless3/exp/ diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md index d0a0c1829..75fc6326e 100644 --- a/egs/aishell/ASR/README.md +++ b/egs/aishell/ASR/README.md @@ -4,6 +4,8 @@ Please refer to for how to run models in this recipe. + + # Transducers There are various folders containing the name `transducer` in this folder. @@ -14,6 +16,7 @@ The following table lists the differences among them. | `transducer_stateless` | Conformer | Embedding + Conv1d | with `k2.rnnt_loss` | | `transducer_stateless_modified` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` | | `transducer_stateless_modified-2` | Conformer | Embedding + Conv1d | with modified transducer from `optimized_transducer` + extra data | +| `pruned_transducer_stateless3` | Conformer (reworked) | Embedding + Conv1d | pruned RNN-T + reworked model with random combiner + using aidatatang_20zh as extra data| The decoder in `transducer_stateless` is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/). diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md index ecc93c21b..b420a1982 100644 --- a/egs/aishell/ASR/RESULTS.md +++ b/egs/aishell/ASR/RESULTS.md @@ -1,10 +1,93 @@ ## Results -### Aishell training result(Transducer-stateless) + +### Aishell training result(Stateless Transducer) + +#### Pruned transducer stateless 3 + +See + + +[./pruned_transducer_stateless3](./pruned_transducer_stateless3) + +It uses pruned RNN-T. + +| | test | dev | comment | +|------------------------|------|------|---------------------------------------| +| greedy search | 5.39 | 5.09 | --epoch 29 --avg 5 --max-duration 600 | +| modified beam search | 5.05 | 4.79 | --epoch 29 --avg 5 --max-duration 600 | +| fast beam search | 5.13 | 4.91 | --epoch 29 --avg 5 --max-duration 600 | + +Training command is: + +```bash +./prepare.sh +./prepare_aidatatang_200zh.sh + +export CUDA_VISIBLE_DEVICES="4,5,6,7" + +./pruned_transducer_stateless3/train.py \ + --exp-dir ./pruned_transducer_stateless3/exp-context-size-1 \ + --world-size 4 \ + --max-duration 200 \ + --datatang-prob 0.5 \ + --start-epoch 1 \ + --num-epochs 30 \ + --use-fp16 1 \ + --num-encoder-layers 12 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --context-size 1 \ + --decoder-dim 512 \ + --joiner-dim 512 \ + --master-port 12356 +``` + +**Caution**: It uses `--context-size=1`. + +The tensorboard log is available at + + +The decoding command is: + +```bash +for epoch in 29; do + for avg in 5; do + for m in greedy_search modified_beam_search fast_beam_search; do + ./pruned_transducer_stateless3/decode.py \ + --exp-dir ./pruned_transducer_stateless3/exp-context-size-1 \ + --epoch $epoch \ + --avg $avg \ + --use-averaged-model 1 \ + --max-duration 600 \ + --decoding-method $m \ + --num-encoder-layers 12 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --context-size 1 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 + done + done +done +``` + +Pretrained models, training logs, decoding logs, and decoding results +are available at + + +We have a tutorial in [sherpa](https://github.com/k2-fsa/sherpa) about how +to use the pre-trained model for non-streaming ASR. See + #### 2022-03-01 [./transducer_stateless_modified-2](./transducer_stateless_modified-2) +It uses [optimized_transducer](https://github.com/csukuangfj/optimized_transducer) +for computing RNN-T loss. + Stateless transducer + modified transducer + using [aidatatang_200zh](http://www.openslr.org/62/) as extra training data. diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/aidatatang_200zh.py b/egs/aishell/ASR/pruned_transducer_stateless3/aidatatang_200zh.py new file mode 120000 index 000000000..9a799406b --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/aidatatang_200zh.py @@ -0,0 +1 @@ +../transducer_stateless_modified-2/aidatatang_200zh.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/aishell.py b/egs/aishell/ASR/pruned_transducer_stateless3/aishell.py new file mode 120000 index 000000000..1b5f38a54 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/aishell.py @@ -0,0 +1 @@ +../transducer_stateless_modified-2/aishell.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless3/asr_datamodule.py new file mode 120000 index 000000000..ae3bdd1e0 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -0,0 +1 @@ +../transducer_stateless_modified-2/asr_datamodule.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless3/beam_search.py new file mode 120000 index 000000000..e24eca39f --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/beam_search.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/conformer.py b/egs/aishell/ASR/pruned_transducer_stateless3/conformer.py new file mode 120000 index 000000000..c7c1a4b6e --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/conformer.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless5/conformer.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decode.py b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py new file mode 100755 index 000000000..f686174f3 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decode.py @@ -0,0 +1,637 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" + + +import argparse +import logging +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import torch +import torch.nn as nn +from aishell import AIShell +from asr_datamodule import AsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="The lang dir", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + token_table: k2.SymbolTable, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + token_table: + It maps token ID to a string. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + else: + hyp_tokens = [] + batch_size = encoder_out.size(0) + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyp_tokens.append(hyp) + + hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif params.decoding_method == "fast_beam_search": + return { + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps + } + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + token_table: k2.SymbolTable, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + token_table: + It maps a token ID to a string. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + + hyps_dict = decode_one_batch( + params=params, + model=model, + token_table=token_table, + decoding_graph=decoding_graph, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for hyp_words, ref_text in zip(hyps, texts): + ref_words = ref_text.split() + this_batch.append((ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[int], List[int]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + # we compute CER for aishell dataset. + results_char = [] + for res in results: + results_char.append((list("".join(res[0])), list("".join(res[1])))) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results_char, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tCER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, CER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict( + average_checkpoints(filenames, device=device), strict=False + ) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ), + strict=False, + ) + + model.to(device) + model.eval() + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + asr_datamodule = AsrDataModule(args) + aishell = AIShell(manifest_dir=args.manifest_dir) + test_cuts = aishell.test_cuts() + dev_cuts = aishell.valid_cuts() + test_dl = asr_datamodule.test_dataloaders(test_cuts) + dev_dl = asr_datamodule.test_dataloaders(dev_cuts) + + test_sets = ["test", "dev"] + test_dls = [test_dl, dev_dl] + + for test_set, test_dl in zip(test_sets, test_dls): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + token_table=lexicon.token_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless3/decoder.py new file mode 120000 index 000000000..722e1c894 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/decoder.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/decoder.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless3/encoder_interface.py new file mode 120000 index 000000000..f58253127 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/encoder_interface.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 b/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 new file mode 120000 index 000000000..bcd4abc2f --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 @@ -0,0 +1 @@ +/ceph-fj/fangjun/open-source/icefall-aishell/egs/aishell/ASR/pruned_transducer_stateless3/exp-context-size-1 \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/export.py b/egs/aishell/ASR/pruned_transducer_stateless3/export.py new file mode 100755 index 000000000..307895a76 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/export.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python3 +# +# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" +Usage: +./pruned_transducer_stateless3/export.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --jit 0 \ + --epoch 29 \ + --avg 5 + +It will generate a file exp_dir/pretrained-epoch-29-avg-5.pt + +To use the generated file with `pruned_transducer_stateless3/decode.py`, +you can do:: + + cd /path/to/exp_dir + ln -s pretrained-epoch-29-avg-5.pt epoch-9999.pt + + cd /path/to/egs/aishell/ASR + ./pruned_transducer_stateless3/decode.py \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 100 \ + --lang-dir data/lang_char +""" + +import argparse +import logging +from pathlib import Path + +import torch +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=29, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=Path, + default=Path("pruned_transducer_stateless3/exp"), + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.script. + """, + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default=Path("data/lang_char"), + help="The lang dir", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def main(): + args = get_parser().parse_args() + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = ( + params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" + ) + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = ( + params.exp_dir + / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" + ) + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless3/joiner.py new file mode 120000 index 000000000..9052f3cbb --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/joiner.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/joiner.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/model.py b/egs/aishell/ASR/pruned_transducer_stateless3/model.py new file mode 100644 index 000000000..e150e8230 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/model.py @@ -0,0 +1,236 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Optional + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + decoder_datatang: Optional[nn.Module] = None, + joiner_datatang: Optional[nn.Module] = None, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and + (N, U, decoder_dim). Its output shape is (N, T, U, vocab_size). + Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + encoder_dim: + Output dimension of the encoder network. + decoder_dim: + Output dimension of the decoder network. + joiner_dim: + Input dimension of the joiner network. + vocab_size: + Output dimension of the joiner network. + decoder_datatang: + Optional. The decoder network for the aidatatang_200zh dataset. + joiner_datatang: + Optional. The joiner network for the aidatatang_200zh dataset. + """ + super().__init__() + + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.decoder_datatang = decoder_datatang + self.joiner_datatang = joiner_datatang + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) + self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + + if decoder_datatang is not None: + self.simple_am_proj_datatang = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) + self.simple_lm_proj_datatang = ScaledLinear(decoder_dim, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + aishell: bool = True, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + warmup: float = 1.0, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + aishell: + True to use the decoder and joiner for the aishell dataset. + False to use the decoder and joiner for the aidatatang_200zh + dataset. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + warmup: + A value warmup >= 0 that determines which modules are active, values + warmup > 1 "are fully warmed up" and all modules will be active. + Returns: + Return the transducer loss. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup) + assert torch.all(encoder_out_lens > 0) + + if aishell: + decoder = self.decoder + simple_lm_proj = self.simple_lm_proj + simple_am_proj = self.simple_am_proj + joiner = self.joiner + else: + decoder = self.decoder_datatang + simple_lm_proj = self.simple_lm_proj_datatang + simple_am_proj = self.simple_am_proj_datatang + joiner = self.joiner_datatang + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = encoder_out_lens + + lm = simple_lm_proj(decoder_out) + am = simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=joiner.encoder_proj(encoder_out), + lm=joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/optim.py b/egs/aishell/ASR/pruned_transducer_stateless3/optim.py new file mode 120000 index 000000000..0a2f285aa --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/optim.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py new file mode 100755 index 000000000..5cda411bc --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/pretrained.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + +(1) greedy search +./pruned_transducer_stateless3/pretrained.py \ + --checkpoint /path/to/pretrained.pt \ + --lang-dir /path/to/lang_char \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./pruned_transducer_stateless3/pretrained.py \ + --checkpoint /path/to/pretrained.pt \ + --lang-dir /path/to/lang_char \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./pruned_transducer_stateless3/pretrained.py \ + --checkpoint /path/to/pretrained.pt \ + --lang-dir /path/to/lang_char \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./pruned_transducer_stateless3/pretrained.py \ + --checkpoint /path/to/pretrained.pt \ + --lang-dir /path/to/lang_char \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from pathlib import Path +from typing import List + +import k2 +import kaldifeat +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.lexicon import Lexicon + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default=Path("data/lang_char"), + help="The lang dir", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="Maximum number of symbols per frame. " + "Use only when --method is greedy_search", + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + lexicon = Lexicon(params.lang_dir) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lens = [f.size(0) for f in features] + feature_lens = torch.tensor(feature_lens, device=device) + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lens + ) + + num_waves = encoder_out.size(0) + hyp_list = [] + logging.info(f"Using {params.method}") + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_list = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_list = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + elif params.method == "modified_beam_search": + hyp_list = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.method}" + ) + hyp_list.append(hyp) + + hyps = [] + for hyp in hyp_list: + hyps.append([lexicon.token_table[i] for i in hyp]) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless3/scaling.py new file mode 120000 index 000000000..c10cdfe12 --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/scaling.py @@ -0,0 +1 @@ +../../../librispeech/ASR/pruned_transducer_stateless2/scaling.py \ No newline at end of file diff --git a/egs/aishell/ASR/pruned_transducer_stateless3/train.py b/egs/aishell/ASR/pruned_transducer_stateless3/train.py new file mode 100755 index 000000000..02efe94fe --- /dev/null +++ b/egs/aishell/ASR/pruned_transducer_stateless3/train.py @@ -0,0 +1,1229 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# Copyright 2021 (Pingfeng Luo) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +./prepare.sh +./prepare_aidatatang_200zh.sh + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + + +./pruned_transducer_stateless3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 0 \ + --exp-dir pruned_transducer_stateless3/exp \ + --max-duration 300 \ + --datatang-prob 0.2 + +# For mix precision training: + +./pruned_transducer_stateless3/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir pruned_transducer_stateless3/exp \ + --max-duration 550 +""" + + +import argparse +import copy +import logging +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from aidatatang_200zh import AIDatatang200zh +from aishell import AIShell +from asr_datamodule import AsrDataModule +from conformer import Conformer +from decoder import Decoder +from joiner import Joiner +from lhotse import CutSet, load_manifest +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.lexicon import Lexicon +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=int, + default=12, + help="Number of conformer encoder layers..", + ) + + parser.add_argument( + "--dim-feedforward", + type=int, + default=2048, + help="Feedforward dimension of the conformer encoder layer.", + ) + + parser.add_argument( + "--nhead", + type=int, + default=8, + help="Number of attention heads in the conformer encoder layer.", + ) + + parser.add_argument( + "--encoder-dim", + type=int, + default=512, + help="Attention dimension in the conformer encoder layer.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--lang-dir", + type=str, + default="data/lang_char", + help="""The lang dir + It contains language related input files such as + "lexicon.txt" + """, + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="The initial learning rate. This value should not need " + "to be changed.", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=6, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=1, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--datatang-prob", + type=float, + default=0.2, + help="The probability to select a batch from the " + "aidatatang_200zh dataset", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - encoder_dim: Hidden dim for multi-head attention model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 1000, + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + # TODO: We can add an option to switch between Conformer and Transformer + encoder = Conformer( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + nhead=params.nhead, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + decoder_datatang = get_decoder_model(params) + joiner_datatang = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + decoder_datatang=decoder_datatang, + joiner_datatang=joiner_datatang, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def is_aishell(c: Cut) -> bool: + """Return True if this cut is from the AIShell dataset. + + Note: + During data preparation, we set the custom field in + the supervision segment of aidatatang_200zh to + dict(origin='aidatatang_200zh') + See ../local/process_aidatatang_200zh.py. + """ + return c.supervisions[0].custom is None + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + aishell = is_aishell(supervisions["cut"][0]) + + texts = batch["supervisions"]["text"] + y = graph_compiler.texts_to_ids(texts) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + aishell=aishell, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + ) + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + graph_compiler: CharCtcTrainingGraphCompiler, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + graph_compiler: CharCtcTrainingGraphCompiler, + train_dl: torch.utils.data.DataLoader, + datatang_train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + rng: random.Random, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + aishell_tot_loss = MetricsTracker() + datatang_tot_loss = MetricsTracker() + tot_loss = MetricsTracker() + + # index 0: for LibriSpeech + # index 1: for GigaSpeech + # This sets the probabilities for choosing which datasets + dl_weights = [1 - params.datatang_prob, params.datatang_prob] + + iter_aishell = iter(train_dl) + iter_datatang = iter(datatang_train_dl) + + batch_idx = 0 + + while True: + idx = rng.choices((0, 1), weights=dl_weights, k=1)[0] + dl = iter_aishell if idx == 0 else iter_datatang + + try: + batch = next(dl) + except StopIteration: + break + batch_idx += 1 + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + aishell = is_aishell(batch["supervisions"]["cut"][0]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + if aishell: + aishell_tot_loss = ( + aishell_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "aishell" # for logging only + else: + datatang_tot_loss = ( + datatang_tot_loss * (1 - 1 / params.reset_interval) + ) + loss_info + prefix = "datatang" + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch( + batch, params=params, graph_compiler=graph_compiler + ) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, {prefix}_loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"aishell_tot_loss[{aishell_tot_loss}], " + f"datatang_tot_loss[{datatang_tot_loss}], " + f"batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, + f"train/current_{prefix}_", + params.batch_idx_train, + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + aishell_tot_loss.write_summary( + tb_writer, "train/aishell_tot_", params.batch_idx_train + ) + datatang_tot_loss.write_summary( + tb_writer, "train/datatang_tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def filter_short_and_long_utterances(cuts: CutSet) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 12 seconds + # + # Caution: There is a reason to select 12.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + return 1.0 <= c.duration <= 12.0 + + return cuts + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + rng = random.Random(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + lexicon = Lexicon(params.lang_dir) + graph_compiler = CharCtcTrainingGraphCompiler( + lexicon=lexicon, + device=device, + oov="", + ) + + params.blank_id = 0 + params.vocab_size = max(lexicon.tokens) + 1 + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.print_diagnostics: + opts = diagnostics.TensorDiagnosticOptions( + 2 ** 22 + ) # allow 4 megabytes per sub-module + diagnostic = diagnostics.attach_diagnostics(model, opts) + + aishell = AIShell(manifest_dir=args.manifest_dir) + train_cuts = aishell.train_cuts() + train_cuts = filter_short_and_long_utterances(train_cuts) + + datatang = AIDatatang200zh(manifest_dir=args.manifest_dir) + train_datatang_cuts = datatang.train_cuts() + train_datatang_cuts = filter_short_and_long_utterances(train_datatang_cuts) + train_datatang_cuts = train_datatang_cuts.repeat(times=None) + + if args.enable_musan: + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) + else: + cuts_musan = None + + asr_datamodule = AsrDataModule(args) + + train_dl = asr_datamodule.train_dataloaders( + train_cuts, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + datatang_train_dl = asr_datamodule.train_dataloaders( + train_datatang_cuts, + on_the_fly_feats=False, + cuts_musan=cuts_musan, + ) + + valid_cuts = aishell.valid_cuts() + valid_dl = asr_datamodule.valid_dataloaders(valid_cuts) + + for dl in [ + train_dl, + # datatang_train_dl + ]: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=dl, + optimizer=optimizer, + graph_compiler=graph_compiler, + params=params, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + logging.info(f"start training from epoch {params.start_epoch}") + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + datatang_train_dl.sampler.set_epoch(epoch) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + graph_compiler=graph_compiler, + train_dl=train_dl, + datatang_train_dl=datatang_train_dl, + valid_dl=valid_dl, + rng=rng, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + graph_compiler: CharCtcTrainingGraphCompiler, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = graph_compiler.texts_to_ids(supervisions["text"]) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + graph_compiler: CharCtcTrainingGraphCompiler, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + # warmup = 0.0 is so that the derivs for the pruned loss stay zero + # (i.e. are not remembered by the decaying-average in adam), because + # we want to avoid these params being subject to shrinkage in adam. + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + graph_compiler=graph_compiler, + batch=batch, + is_training=True, + warmup=0.0 if params.start_epoch == 1 else 1.0, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch( + batch, params=params, graph_compiler=graph_compiler + ) + raise + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lang_dir = Path(args.lang_dir) + + assert 0 <= args.datatang_prob < 1, args.datatang_prob + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/aishell/ASR/transducer_stateless_modified-2/train.py b/egs/aishell/ASR/transducer_stateless_modified-2/train.py index 962fffdf5..225d0d709 100755 --- a/egs/aishell/ASR/transducer_stateless_modified-2/train.py +++ b/egs/aishell/ASR/transducer_stateless_modified-2/train.py @@ -405,7 +405,7 @@ def compute_loss( is_training: bool, ) -> Tuple[Tensor, MetricsTracker]: """ - Compute CTC loss given the model and its inputs. + Compute RNN-T loss given the model and its inputs. Args: params: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index b6d94aaf1..1ddfce034 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -73,6 +73,9 @@ class Decoder(nn.Module): groups=decoder_dim, bias=False, ) + else: + # It is to support torch script + self.conv = nn.Identity() def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: """ diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 6f7231f4b..bf3917df0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -117,10 +117,7 @@ class Conformer(EncoderInterface): x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # Caution: We assume the subsampling factor is 4! - lengths = ((x_lens - 1) // 2 - 1) // 2 + lengths = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == lengths.max().item() mask = make_pad_mask(lengths) @@ -293,8 +290,10 @@ class ConformerEncoder(nn.Module): ) self.num_layers = num_layers + assert len(set(aux_layers)) == len(aux_layers) + assert num_layers - 1 not in aux_layers - self.aux_layers = set(aux_layers + [num_layers - 1]) + self.aux_layers = aux_layers + [num_layers - 1] num_channels = encoder_layer.norm_final.num_channels self.combiner = RandomCombine( @@ -1154,7 +1153,7 @@ class RandomCombine(nn.Module): """ num_inputs = self.num_inputs assert len(inputs) == num_inputs - if not self.training: + if not self.training or torch.jit.is_scripting(): return inputs[-1] # Shape of weights: (*, num_inputs) @@ -1162,8 +1161,22 @@ class RandomCombine(nn.Module): num_frames = inputs[0].numel() // num_channels mod_inputs = [] - for i in range(num_inputs - 1): - mod_inputs.append(self.linear[i](inputs[i])) + + if False: + # It throws the following error for torch 1.6.0 when using + # torch script. + # + # Expected integer literal for index. ModuleList/Sequential + # indexing is only supported with integer literals. Enumeration is + # supported, e.g. 'for index, v in enumerate(self): ...': + # for i in range(num_inputs - 1): + # mod_inputs.append(self.linear[i](inputs[i])) + assert False + else: + for i, linear in enumerate(self.linear): + if i < num_inputs - 1: + mod_inputs.append(linear(inputs[i])) + mod_inputs.append(inputs[num_inputs - 1]) ndim = inputs[0].ndim @@ -1181,11 +1194,13 @@ class RandomCombine(nn.Module): # ans: (num_frames, num_channels, 1) ans = torch.matmul(stacked_inputs, weights) # ans: (*, num_channels) - ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels) - if __name__ == "__main__": - # for testing only... - print("Weights = ", weights.reshape(num_frames, num_inputs)) + ans = ans.reshape(inputs[0].shape[:-1] + [num_channels]) + + # The following if causes errors for torch script in torch 1.6.0 + # if __name__ == "__main__": + # # for testing only... + # print("Weights = ", weights.reshape(num_frames, num_inputs)) return ans def _get_random_weights( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index f1269a4bd..936508900 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -146,8 +146,6 @@ def main(): args = get_parser().parse_args() args.exp_dir = Path(args.exp_dir) - assert args.jit is False, "Support torchscript will be added later" - params = get_params() params.update(vars(args)) @@ -246,12 +244,15 @@ def main(): ) ) - model.eval() - model.to("cpu") model.eval() if params.jit: + # We won't use the forward() method of the model in C++, so just ignore + # it here. + # Otherwise, one of its arguments is a ragged tensor and is not + # torch scriptabe. + model.__class__.forward = torch.jit.ignore(model.__class__.forward) logging.info("Using torch.jit.script") model = torch.jit.script(model) filename = params.exp_dir / "cpu_jit.pt" From dc89b61b808a22b64e7e38cb09771ea7a6bb64d2 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 22 Jun 2022 00:09:25 +0800 Subject: [PATCH 6/9] Add fast_beam_search_nbest. (#420) * Add fast_beam_search_nbest. * Fix CI errors. * Fix CI errors. * More fixes. * Small fixes. * Support using log_add in LG decoding with fast_beam_search. * Support LG decoding in pruned_transducer_stateless * Support LG for pruned_transducer_stateless2. * Support LG for fast beam search. * Minor fixes. --- ...pruned-transducer-stateless5-2022-05-13.sh | 6 + .github/workflows/test.yml | 12 +- .../beam_search.py | 196 ++++++++++++++ .../ASR/pruned_transducer_stateless/decode.py | 204 ++++++++++----- .../beam_search.py | 200 +++++++++++++- .../pruned_transducer_stateless2/decode.py | 202 ++++++++++++-- .../pruned_transducer_stateless3/decode.py | 246 +++++++++++++----- .../pruned_transducer_stateless4/decode.py | 210 +++++++++++++-- .../pruned_transducer_stateless5/decode.py | 210 +++++++++++++-- icefall/decode.py | 6 +- 10 files changed, 1298 insertions(+), 194 deletions(-) diff --git a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh index 3d0c4e2ef..61210ac6e 100755 --- a/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh +++ b/.github/scripts/run-librispeech-pruned-transducer-stateless5-2022-05-13.sh @@ -32,6 +32,12 @@ for sym in 1 2 3; do --max-sym-per-frame $sym \ --checkpoint $repo/exp/pretrained.pt \ --bpe-model $repo/data/lang_bpe_500/bpe.model \ + --num-encoder-layers 18 \ + --dim-feedforward 2048 \ + --nhead 8 \ + --encoder-dim 512 \ + --decoder-dim 512 \ + --joiner-dim 512 $repo/test_wavs/1089-134686-0001.wav \ $repo/test_wavs/1221-135766-0001.wav \ $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f9dab7afe..1583926ec 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,13 +33,13 @@ jobs: # disable macOS test for now. os: [ubuntu-18.04] python-version: [3.7, 3.8] - torch: ["1.8.0", "1.10.0"] - torchaudio: ["0.8.0", "0.10.0"] - k2-version: ["1.9.dev20211101"] + torch: ["1.8.0", "1.11.0"] + torchaudio: ["0.8.0", "0.11.0"] + k2-version: ["1.15.1.dev20220427"] exclude: - torch: "1.8.0" - torchaudio: "0.10.0" - - torch: "1.10.0" + torchaudio: "0.11.0" + - torch: "1.11.0" torchaudio: "0.8.0" fail-fast: false @@ -67,7 +67,7 @@ jobs: # numpy 1.20.x does not support python 3.6 pip install numpy==1.19 pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html - if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then + if [[ ${{ matrix.torchaudio }} == "0.11.0" ]]; then pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html else pip install torchaudio==${{ matrix.torchaudio }} diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index db23fd993..40c442e7a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -75,6 +75,202 @@ def fast_beam_search_one_best( return hyps +def fast_beam_search_nbest_LG( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + Returns: + Return the decoded result. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # The following code is modified from nbest.intersect() + word_fsa = k2.invert(nbest.fsa) + if hasattr(lattice, "aux_labels"): + # delete token IDs as it is not needed + del word_fsa.aux_labels + word_fsa.scores.zero_() + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) + path_to_utt_map = nbest.shape.row_ids(1) + + if hasattr(lattice, "aux_labels"): + # lattice has token IDs as labels and word IDs as aux_labels. + # inv_lattice has word IDs as labels and token IDs as aux_labels + inv_lattice = k2.invert(lattice) + inv_lattice = k2.arc_sort(inv_lattice) + else: + inv_lattice = k2.arc_sort(lattice) + + if inv_lattice.shape[0] == 1: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=torch.zeros_like(path_to_utt_map), + sorted_match_a=True, + ) + else: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=path_to_utt_map, + sorted_match_a=True, + ) + + # path_lattice has word IDs as labels and token IDs as aux_labels + path_lattice = k2.top_sort(k2.connect(path_lattice)) + tot_scores = path_lattice.get_tot_scores( + use_double_scores=use_double_scores, + log_semiring=True, # Note: we always use True + ) + # See https://github.com/k2-fsa/icefall/pull/420 for why + # we always use log_semiring=True + + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + best_hyp_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) + + hyps = get_texts(best_path) + + return hyps + + +def fast_beam_search_nbest( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + Returns: + Return the decoded result. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + max_indexes = nbest.tot_scores().argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + hyps = get_texts(best_path) + + return hyps + + def fast_beam_search_nbest_oracle( model: Transducer, decoding_graph: k2.Fsa, diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index ea43836bd..f39cc614c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -50,20 +50,44 @@ Usage: --exp-dir ./pruned_transducer_stateless/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 -(5) fast beam search using LG +(5) fast beam search (nbest) ./pruned_transducer_stateless/decode.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless/exp \ - --use-LG True \ - --use-max False \ --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 8 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ --max-contexts 8 \ --max-states 64 """ @@ -82,6 +106,9 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, greedy_search_batch, @@ -99,7 +126,6 @@ from icefall.utils import ( AttributeDict, setup_logger, store_transcripts, - str2bool, write_error_stats, ) @@ -153,7 +179,7 @@ def get_parser(): parser.add_argument( "--lang-dir", - type=str, + type=Path, default="data/lang_bpe_500", help="The lang dir containing word table and LG graph", ) @@ -167,6 +193,11 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. """, ) @@ -182,30 +213,13 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4, + default=20.0, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--use-LG", - type=str2bool, - default=False, - help="""Whether to use an LG graph for FSA-based beam search. - Used only when --decoding_method is fast_beam_search. If setting true, - it assumes there is an LG.pt file in lang_dir.""", - ) - - parser.add_argument( - "--use-max", - type=str2bool, - default=False, - help="""If True, use max-op to select the hypothesis that have the - max log_prob in case of duplicate hypotheses. - If False, use log_add. - Used only for beam_search, modified_beam_search, and fast_beam_search + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle """, ) @@ -214,7 +228,7 @@ def get_parser(): type=float, default=0.01, help=""" - Used only when --decoding_method is fast_beam_search. + Used only when --decoding_method is fast_beam_search_nbest_LG. It specifies the scale for n-gram LM scores. """, ) @@ -222,9 +236,10 @@ def get_parser(): parser.add_argument( "--max-contexts", type=int, - default=4, + default=8, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -232,7 +247,8 @@ def get_parser(): type=int, default=8, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -250,6 +266,24 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + return parser @@ -286,7 +320,8 @@ def decode_one_batch( The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -299,6 +334,7 @@ def decode_one_batch( # at entry, feature is (N, T, C) supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) encoder_out, encoder_out_lens = model.encoder( @@ -316,12 +352,51 @@ def decode_one_batch( max_contexts=params.max_contexts, max_states=params.max_states, ) - if params.use_LG: - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - else: - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -339,7 +414,6 @@ def decode_one_batch( encoder_out=encoder_out, encoder_out_lens=encoder_out_lens, beam=params.beam_size, - use_max=params.use_max, ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) @@ -361,7 +435,6 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size, - use_max=params.use_max, ) else: raise ValueError( @@ -371,14 +444,17 @@ def decode_one_batch( if params.decoding_method == "greedy_search": return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -406,7 +482,8 @@ def decode_dataset( The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -424,7 +501,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 50 else: - log_interval = 10 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -517,6 +594,9 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method @@ -527,16 +607,18 @@ def main(): params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" if "fast_beam_search" in params.decoding_method: - params.suffix += f"-use-LG-{params.use_LG}" params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" - params.suffix += f"-use-max-{params.use_max}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" ) - params.suffix += f"-use-max-{params.use_max}" else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -596,12 +678,14 @@ def main(): model.eval() model.device = device - if params.decoding_method == "fast_beam_search": - if params.use_LG: + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": lexicon = Lexicon(params.lang_dir) word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") decoding_graph = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/LG.pt", map_location=device) + torch.load(lg_filename, map_location=device) ) decoding_graph.scores *= params.ngram_lm_scale else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 7c936b257..6b6190a09 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -37,7 +37,7 @@ def fast_beam_search_one_best( ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using modified beam search, and then + A lattice is first obtained using fast beam search, and then the shortest path within the lattice is used as the final output. Args: @@ -74,6 +74,202 @@ def fast_beam_search_one_best( return hyps +def fast_beam_search_nbest_LG( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + Returns: + Return the decoded result. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # The following code is modified from nbest.intersect() + word_fsa = k2.invert(nbest.fsa) + if hasattr(lattice, "aux_labels"): + # delete token IDs as it is not needed + del word_fsa.aux_labels + word_fsa.scores.zero_() + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) + path_to_utt_map = nbest.shape.row_ids(1) + + if hasattr(lattice, "aux_labels"): + # lattice has token IDs as labels and word IDs as aux_labels. + # inv_lattice has word IDs as labels and token IDs as aux_labels + inv_lattice = k2.invert(lattice) + inv_lattice = k2.arc_sort(inv_lattice) + else: + inv_lattice = k2.arc_sort(lattice) + + if inv_lattice.shape[0] == 1: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=torch.zeros_like(path_to_utt_map), + sorted_match_a=True, + ) + else: + path_lattice = k2.intersect_device( + inv_lattice, + word_fsa_with_epsilon_loops, + b_to_a_map=path_to_utt_map, + sorted_match_a=True, + ) + + # path_lattice has word IDs as labels and token IDs as aux_labels + path_lattice = k2.top_sort(k2.connect(path_lattice)) + tot_scores = path_lattice.get_tot_scores( + use_double_scores=use_double_scores, + log_semiring=True, # Note: we always use True + ) + # See https://github.com/k2-fsa/icefall/pull/420 for why + # we always use log_semiring=True + + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + best_hyp_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) + + hyps = get_texts(best_path) + + return hyps + + +def fast_beam_search_nbest( + model: Transducer, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + num_paths: int, + nbest_scale: float = 0.5, + use_double_scores: bool = True, +) -> List[List[int]]: + """It limits the maximum number of symbols per frame to 1. + + The process to get the results is: + - (1) Use fast beam search to get a lattice + - (2) Select `num_paths` paths from the lattice using k2.random_paths() + - (3) Unique the selected paths + - (4) Intersect the selected paths with the lattice and compute the + shortest path from the intersection result + - (5) The path with the largest score is used as the decoding output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a HLG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + num_paths: + Number of paths to extract from the decoded lattice. + nbest_scale: + It's the scale applied to the lattice.scores. A smaller value + yields more unique paths. + use_double_scores: + True to use double precision for computation. False to use + single precision. + Returns: + Return the decoded result. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + ) + + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + + # at this point, nbest.fsa.scores are all zeros. + + nbest = nbest.intersect(lattice) + # Now nbest.fsa.scores contains acoustic scores + + max_indexes = nbest.tot_scores().argmax() + + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + hyps = get_texts(best_path) + + return hyps + + def fast_beam_search_nbest_oracle( model: Transducer, decoding_graph: k2.Fsa, @@ -89,7 +285,7 @@ def fast_beam_search_nbest_oracle( ) -> List[List[int]]: """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using modified beam search, and then + A lattice is first obtained using fast beam search, and then we select `num_paths` linear paths from the lattice. The path that has the minimum edit distance with the given reference transcript is used as the output. diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index d7d6b1202..ea368fb87 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -43,16 +43,53 @@ Usage: --decoding-method modified_beam_search \ --beam-size 4 -(4) fast beam search +(4) fast beam search (one best) ./pruned_transducer_stateless2/decode.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless2/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless2/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless2/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless2/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless2/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 """ @@ -69,6 +106,9 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, greedy_search_batch, @@ -81,6 +121,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, setup_logger, @@ -136,6 +177,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -145,6 +193,11 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. """, ) @@ -160,27 +213,42 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4, + default=20.0, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, ) parser.add_argument( "--max-contexts", type=int, - default=4, + default=8, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( "--max-states", type=int, - default=8, + default=64, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -198,6 +266,24 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + return parser @@ -206,6 +292,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -229,9 +316,12 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -263,6 +353,49 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -318,6 +451,17 @@ def decode_one_batch( f"max_states_{params.max_states}" ): hyps } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -327,6 +471,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -340,9 +485,12 @@ def decode_dataset( The neural model. sp: The BPE model. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -360,7 +508,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 50 else: - log_interval = 10 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -370,6 +518,7 @@ def decode_dataset( params=params, model=model, sp=sp, + word_table=word_table, decoding_graph=decoding_graph, batch=batch, ) @@ -452,6 +601,9 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method @@ -465,6 +617,11 @@ def main(): params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" @@ -528,10 +685,24 @@ def main(): model.eval() model.device = device - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None + word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -553,6 +724,7 @@ def main(): params=params, model=model, sp=sp, + word_table=word_table, decoding_graph=decoding_graph, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index 5b3dce853..8b1ddc930 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -19,40 +19,77 @@ Usage: (1) greedy search ./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method greedy_search + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method greedy_search (2) beam search (not recommended) ./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 (3) modified beam search ./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 -(4) fast beam search +(4) fast beam search (one best) ./pruned_transducer_stateless3/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless3/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless3/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 """ @@ -69,6 +106,8 @@ import torch.nn as nn from asr_datamodule import AsrDataModule from beam_search import ( beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, @@ -83,6 +122,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, setup_logger, @@ -138,6 +178,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -147,7 +194,11 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. """, ) @@ -163,28 +214,42 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4, + default=20.0, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is - fast_beam_search or fast_beam_search_nbest_oracle""", + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, ) parser.add_argument( "--max-contexts", type=int, - default=4, + default=8, help="""Used only when --decoding-method is - fast_beam_search or fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( "--max-states", type=int, - default=8, + default=64, help="""Used only when --decoding-method is - fast_beam_search or fast_beam_search_nbest_oracle""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -205,10 +270,10 @@ def get_parser(): parser.add_argument( "--num-paths", type=int, - default=100, - help="""Number of paths for computed nbest oracle WER - when the decoding method is fast_beam_search_nbest_oracle. - """, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -216,9 +281,10 @@ def get_parser(): type=float, default=0.5, help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding_method is fast_beam_search_nbest_oracle. - """, + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", ) + return parser @@ -227,6 +293,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -250,10 +317,12 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is - fast_beam_search or fast_beam_search_nbest_oracle. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -285,6 +354,34 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif params.decoding_method == "fast_beam_search_nbest_oracle": hyp_tokens = fast_beam_search_nbest_oracle( model=model, @@ -355,16 +452,25 @@ def decode_one_batch( f"max_states_{params.max_states}" ): hyps } - elif params.decoding_method == "fast_beam_search_nbest_oracle": + elif params.decoding_method == "fast_beam_search": return { ( f"beam_{params.beam}_" f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}_" - f"num_paths_{params.num_paths}_" - f"nbest_scale_{params.nbest_scale}" + f"max_states_{params.max_states}" ): hyps } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -374,6 +480,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -387,9 +494,12 @@ def decode_dataset( The neural model. sp: The BPE model. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -407,7 +517,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 50 else: - log_interval = 10 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -417,6 +527,7 @@ def decode_dataset( params=params, model=model, sp=sp, + word_table=word_table, decoding_graph=decoding_graph, batch=batch, ) @@ -499,6 +610,8 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", "fast_beam_search_nbest_oracle", "modified_beam_search", ) @@ -509,16 +622,15 @@ def main(): else: params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if params.decoding_method == "fast_beam_search": + if "fast_beam_search" in params.decoding_method: params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" - elif params.decoding_method == "fast_beam_search_nbest_oracle": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - params.suffix += f"-num-paths-{params.num_paths}" - params.suffix += f"-nbest-scale-{params.nbest_scale}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" @@ -539,9 +651,9 @@ def main(): sp = spm.SentencePieceProcessor() sp.load(params.bpe_model) - # and is defined in local/train_bpe_model.py + # and are defined in local/train_bpe_model.py params.blank_id = sp.piece_to_id("") - params.unk_id = sp.unk_id() + params.unk_id = sp.piece_to_id("") params.vocab_size = sp.get_piece_size() logging.info(params) @@ -583,13 +695,24 @@ def main(): model.device = device model.unk_id = params.unk_id - if params.decoding_method in ( - "fast_beam_search", - "fast_beam_search_nbest_oracle", - ): - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None + word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -612,6 +735,7 @@ def main(): params=params, model=model, sp=sp, + word_table=word_table, decoding_graph=decoding_graph, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 70afc3ea3..a8d730ad6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -44,16 +44,53 @@ Usage: --decoding-method modified_beam_search \ --beam-size 4 -(4) fast beam search +(4) fast beam search (one best) ./pruned_transducer_stateless4/decode.py \ --epoch 30 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless4/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless4/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless4/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless4/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 """ @@ -70,6 +107,9 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, greedy_search_batch, @@ -83,6 +123,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, setup_logger, @@ -150,6 +191,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -159,6 +207,11 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. """, ) @@ -174,27 +227,42 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4, + default=20.0, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, ) parser.add_argument( "--max-contexts", type=int, - default=4, + default=8, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( "--max-states", type=int, - default=8, + default=64, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -212,6 +280,24 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + return parser @@ -220,6 +306,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -243,9 +330,12 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -277,6 +367,49 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -324,14 +457,17 @@ def decode_one_batch( if params.decoding_method == "greedy_search": return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -341,6 +477,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -354,9 +491,12 @@ def decode_dataset( The neural model. sp: The BPE model. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -374,7 +514,7 @@ def decode_dataset( if params.decoding_method == "greedy_search": log_interval = 50 else: - log_interval = 10 + log_interval = 20 results = defaultdict(list) for batch_idx, batch in enumerate(dl): @@ -385,6 +525,7 @@ def decode_dataset( model=model, sp=sp, decoding_graph=decoding_graph, + word_table=word_table, batch=batch, ) @@ -466,6 +607,9 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method @@ -479,6 +623,11 @@ def main(): params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" @@ -592,10 +741,24 @@ def main(): model.to(device) model.eval() - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None + word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -617,6 +780,7 @@ def main(): params=params, model=model, sp=sp, + word_table=word_table, decoding_graph=decoding_graph, ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index c2ca07480..f87d23cc9 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -44,16 +44,53 @@ Usage: --decoding-method modified_beam_search \ --beam-size 4 -(4) fast beam search +(4) fast beam search (one best) ./pruned_transducer_stateless5/decode.py \ --epoch 28 \ --avg 15 \ --exp-dir ./pruned_transducer_stateless5/exp \ --max-duration 600 \ --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./pruned_transducer_stateless5/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless5/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 """ @@ -70,6 +107,9 @@ import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from beam_search import ( beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, fast_beam_search_one_best, greedy_search, greedy_search_batch, @@ -83,6 +123,7 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) +from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, setup_logger, @@ -128,7 +169,7 @@ def get_parser(): parser.add_argument( "--use-averaged-model", type=str2bool, - default=False, + default=True, help="Whether to load averaged model. Currently it only supports " "using --epoch. If True, it would decode with the averaged model " "over the epoch range from `epoch-avg` (excluded) to `epoch`." @@ -150,6 +191,13 @@ def get_parser(): help="Path to the BPE model", ) + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + parser.add_argument( "--decoding-method", type=str, @@ -159,6 +207,11 @@ def get_parser(): - beam_search - modified_beam_search - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. """, ) @@ -174,27 +227,42 @@ def get_parser(): parser.add_argument( "--beam", type=float, - default=4, + default=20.0, help="""A floating point value to calculate the cutoff score during beam search (i.e., `cutoff = max-score - beam`), which is the same as the `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, ) parser.add_argument( "--max-contexts", type=int, - default=4, + default=8, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( "--max-states", type=int, - default=8, + default=64, help="""Used only when --decoding-method is - fast_beam_search""", + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", ) parser.add_argument( @@ -212,6 +280,24 @@ def get_parser(): Used only when --decoding_method is greedy_search""", ) + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + add_model_arguments(parser) return parser @@ -222,6 +308,7 @@ def decode_one_batch( model: nn.Module, sp: spm.SentencePieceProcessor, batch: dict, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[str]]]: """Decode one batch and return the result in a dict. The dict has the @@ -245,9 +332,12 @@ def decode_one_batch( It is the return value from iterating `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation for the format of the `batch`. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return the decoding result. See above description for the format of the returned dict. @@ -279,6 +369,49 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) elif ( params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1 @@ -326,14 +459,17 @@ def decode_one_batch( if params.decoding_method == "greedy_search": return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} else: return {f"beam_size_{params.beam_size}": hyps} @@ -343,6 +479,7 @@ def decode_dataset( params: AttributeDict, model: nn.Module, sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[str], List[str]]]]: """Decode dataset. @@ -356,9 +493,12 @@ def decode_dataset( The neural model. sp: The BPE model. + word_table: + The word symbol table. decoding_graph: The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. Returns: Return a dict, whose key may be "greedy_search" if greedy search is used, or it may be "beam_7" if beam size of 7 is used. @@ -387,6 +527,7 @@ def decode_dataset( model=model, sp=sp, decoding_graph=decoding_graph, + word_table=word_table, batch=batch, ) @@ -468,6 +609,9 @@ def main(): "greedy_search", "beam_search", "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", "modified_beam_search", ) params.res_dir = params.exp_dir / params.decoding_method @@ -481,6 +625,11 @@ def main(): params.suffix += f"-beam-{params.beam}" params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: params.suffix += ( f"-{params.decoding_method}-beam-size-{params.beam_size}" @@ -594,10 +743,24 @@ def main(): model.to(device) model.eval() - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None + word_table = None num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") @@ -619,6 +782,7 @@ def main(): params=params, model=model, sp=sp, + word_table=word_table, decoding_graph=decoding_graph, ) diff --git a/icefall/decode.py b/icefall/decode.py index 94f3e88ba..3ba899b4e 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -308,9 +308,7 @@ class Nbest(object): del word_fsa.aux_labels word_fsa.scores.zero_() - word_fsa_with_epsilon_loops = k2.remove_epsilon_and_add_self_loops( - word_fsa - ) + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) path_to_utt_map = self.shape.row_ids(1) @@ -609,7 +607,7 @@ def rescore_with_n_best_list( num_paths: Size of nbest list. lm_scale_list: - A list of float representing LM score scales. + A list of floats representing LM score scales. nbest_scale: Scale to be applied to ``lattice.score`` when sampling paths using ``k2.random_paths``. From 0475d75d15d4fca2c6f7a6ab15c4674a429f460b Mon Sep 17 00:00:00 2001 From: ezerhouni <61225408+ezerhouni@users.noreply.github.com> Date: Thu, 23 Jun 2022 13:37:03 +0200 Subject: [PATCH 7/9] [Ready to be merged] Add RNN-LM to Conformer-CTC decoding (#439) --- egs/librispeech/ASR/RESULTS.md | 45 +- egs/librispeech/ASR/conformer_ctc/decode.py | 152 ++++- egs/librispeech/ASR/local/download_lm.py | 2 + .../ASR/local/prepare_lm_training_data.py | 172 +++++ .../ASR/local/sort_lm_training_data.py | 1 + egs/librispeech/ASR/local/train_bpe_model.py | 1 - egs/librispeech/ASR/prepare.sh | 103 ++- egs/ptb/LM/README.md | 18 + egs/ptb/LM/local/prepare_lm_training_data.py | 1 + egs/ptb/LM/local/sort_lm_training_data.py | 143 ++++ .../LM/local/test_prepare_lm_training_data.py | 62 ++ egs/ptb/LM/local/train_bpe_model.py | 1 + egs/ptb/LM/prepare.sh | 115 ++++ egs/ptb/LM/shared | 1 + icefall/decode.py | 187 +++++- icefall/dist.py | 46 +- icefall/rnn_lm/compute_perplexity.py | 237 +++++++ icefall/rnn_lm/dataset.py | 218 +++++++ icefall/rnn_lm/export.py | 167 +++++ icefall/rnn_lm/model.py | 120 ++++ icefall/rnn_lm/test_dataset.py | 71 ++ icefall/rnn_lm/test_dataset_ddp.py | 103 +++ icefall/rnn_lm/test_model.py | 69 ++ icefall/rnn_lm/train.py | 617 ++++++++++++++++++ icefall/utils.py | 49 +- 25 files changed, 2659 insertions(+), 42 deletions(-) create mode 100755 egs/librispeech/ASR/local/prepare_lm_training_data.py create mode 120000 egs/librispeech/ASR/local/sort_lm_training_data.py create mode 100644 egs/ptb/LM/README.md create mode 120000 egs/ptb/LM/local/prepare_lm_training_data.py create mode 100755 egs/ptb/LM/local/sort_lm_training_data.py create mode 100755 egs/ptb/LM/local/test_prepare_lm_training_data.py create mode 120000 egs/ptb/LM/local/train_bpe_model.py create mode 100755 egs/ptb/LM/prepare.sh create mode 120000 egs/ptb/LM/shared create mode 100755 icefall/rnn_lm/compute_perplexity.py create mode 100644 icefall/rnn_lm/dataset.py create mode 100644 icefall/rnn_lm/export.py create mode 100644 icefall/rnn_lm/model.py create mode 100755 icefall/rnn_lm/test_dataset.py create mode 100755 icefall/rnn_lm/test_dataset_ddp.py create mode 100755 icefall/rnn_lm/test_model.py create mode 100755 icefall/rnn_lm/train.py diff --git a/egs/librispeech/ASR/RESULTS.md b/egs/librispeech/ASR/RESULTS.md index 5eb07fae5..3c5027c77 100644 --- a/egs/librispeech/ASR/RESULTS.md +++ b/egs/librispeech/ASR/RESULTS.md @@ -1299,17 +1299,18 @@ You can find the tensorboard log at: +and the RNN-LM pre-trained model: + + The tensorboard log for training is available at diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 177e33a6e..0e8247b8d 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -30,7 +30,7 @@ from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler -from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.checkpoint import load_checkpoint from icefall.decode import ( get_lattice, nbest_decoding, @@ -38,15 +38,19 @@ from icefall.decode import ( one_best_decoding, rescore_with_attention_decoder, rescore_with_n_best_list, + rescore_with_rnn_lm, rescore_with_whole_lattice, ) from icefall.env import get_env_info from icefall.lexicon import Lexicon +from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( AttributeDict, get_texts, + load_averaged_model, setup_logger, store_transcripts, + str2bool, write_error_stats, ) @@ -93,7 +97,9 @@ def get_parser(): is the decoding result. - (5) attention-decoder. Extract n paths from the LM rescored lattice, the path with the highest score is the decoding result. - - (6) nbest-oracle. Its WER is the lower bound of any n-best + - (6) rnn-lm. Rescoring with attention-decoder and RNN LM. We assume + you have trained an RNN LM using ./rnn_lm/train.py + - (7) nbest-oracle. Its WER is the lower bound of any n-best rescoring method can achieve. Useful for debugging n-best rescoring method. """, @@ -105,7 +111,7 @@ def get_parser(): default=100, help="""Number of paths for n-best based decoding method. Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle """, ) @@ -116,7 +122,7 @@ def get_parser(): help="""The scale to be applied to `lattice.scores`. It's needed if you use any kinds of n-best based rescoring. Used only when "method" is one of the following values: - nbest, nbest-rescoring, attention-decoder, and nbest-oracle + nbest, nbest-rescoring, attention-decoder, rnn-lm, and nbest-oracle A smaller value results in more unique paths. """, ) @@ -139,11 +145,67 @@ def get_parser(): "--lm-dir", type=str, default="data/lm", - help="""The LM dir. + help="""The n-gram LM dir. It should contain either G_4_gram.pt or G_4_gram.fst.txt """, ) + parser.add_argument( + "--rnn-lm-exp-dir", + type=str, + default="rnn_lm/exp", + help="""Used only when --method is rnn-lm. + It specifies the path to RNN LM exp dir. + """, + ) + + parser.add_argument( + "--rnn-lm-epoch", + type=int, + default=7, + help="""Used only when --method is rnn-lm. + It specifies the checkpoint to use. + """, + ) + + parser.add_argument( + "--rnn-lm-avg", + type=int, + default=2, + help="""Used only when --method is rnn-lm. + It specifies the number of checkpoints to average. + """, + ) + + parser.add_argument( + "--rnn-lm-embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--rnn-lm-hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--rnn-lm-num-layers", + type=int, + default=4, + help="Number of RNN layers the model", + ) + parser.add_argument( + "--rnn-lm-tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + return parser @@ -173,6 +235,7 @@ def get_params() -> AttributeDict: def decode_one_batch( params: AttributeDict, model: nn.Module, + rnn_lm_model: Optional[nn.Module], HLG: Optional[k2.Fsa], H: Optional[k2.Fsa], bpe_model: Optional[spm.SentencePieceProcessor], @@ -205,6 +268,8 @@ def decode_one_batch( model: The neural model. + rnn_lm_model: + The neural model for RNN LM. HLG: The decoding graph. Used only when params.method is NOT ctc-decoding. H: @@ -330,6 +395,7 @@ def decode_one_batch( "nbest-rescoring", "whole-lattice-rescoring", "attention-decoder", + "rnn-lm", ] lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7] @@ -357,8 +423,6 @@ def decode_one_batch( G_with_epsilon_loops=G, lm_scale_list=None, ) - # TODO: pass `lattice` instead of `rescored_lattice` to - # `rescore_with_attention_decoder` best_path_dict = rescore_with_attention_decoder( lattice=rescored_lattice, @@ -370,6 +434,26 @@ def decode_one_batch( eos_id=eos_id, nbest_scale=params.nbest_scale, ) + elif params.method == "rnn-lm": + # lattice uses a 3-gram Lm. We rescore it with a 4-gram LM. + rescored_lattice = rescore_with_whole_lattice( + lattice=lattice, + G_with_epsilon_loops=G, + lm_scale_list=None, + ) + + best_path_dict = rescore_with_rnn_lm( + lattice=rescored_lattice, + num_paths=params.num_paths, + rnn_lm_model=rnn_lm_model, + model=model, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, + blank_id=0, + nbest_scale=params.nbest_scale, + ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -388,6 +472,7 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, + rnn_lm_model: Optional[nn.Module], HLG: Optional[k2.Fsa], H: Optional[k2.Fsa], bpe_model: Optional[spm.SentencePieceProcessor], @@ -405,6 +490,8 @@ def decode_dataset( It is returned by :func:`get_params`. model: The neural model. + rnn_lm_model: + The neural model for RNN LM. HLG: The decoding graph. Used only when params.method is NOT ctc-decoding. H: @@ -442,6 +529,7 @@ def decode_dataset( hyps_dict = decode_one_batch( params=params, model=model, + rnn_lm_model=rnn_lm_model, HLG=HLG, H=H, bpe_model=bpe_model, @@ -490,7 +578,7 @@ def save_results( test_set_name: str, results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): - if params.method == "attention-decoder": + if params.method in ("attention-decoder", "rnn-lm"): # Set it to False since there are too many logs. enable_log = False else: @@ -566,6 +654,10 @@ def main(): sos_id = graph_compiler.sos_id eos_id = graph_compiler.eos_id + params.num_classes = num_classes + params.sos_id = sos_id + params.eos_id = eos_id + if params.method == "ctc-decoding": HLG = None H = k2.ctc_topo( @@ -590,6 +682,7 @@ def main(): "nbest-rescoring", "whole-lattice-rescoring", "attention-decoder", + "rnn-lm", ): if not (params.lm_dir / "G_4_gram.pt").is_file(): logging.info("Loading G_4_gram.fst.txt") @@ -621,7 +714,11 @@ def main(): d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device) G = k2.Fsa.from_dict(d) - if params.method in ["whole-lattice-rescoring", "attention-decoder"]: + if params.method in [ + "whole-lattice-rescoring", + "attention-decoder", + "rnn-lm", + ]: # Add epsilon self-loops to G as we will compose # it with the whole lattice later G = k2.add_epsilon_self_loops(G) @@ -648,20 +745,40 @@ def main(): if params.avg == 1: load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) model.to(device) model.eval() num_param = sum([p.numel() for p in model.parameters()]) logging.info(f"Number of model parameters: {num_param}") + rnn_lm_model = None + if params.method == "rnn-lm": + rnn_lm_model = RnnLmModel( + vocab_size=params.num_classes, + embedding_dim=params.rnn_lm_embedding_dim, + hidden_dim=params.rnn_lm_hidden_dim, + num_layers=params.rnn_lm_num_layers, + tie_weights=params.rnn_lm_tie_weights, + ) + if params.rnn_lm_avg == 1: + load_checkpoint( + f"{params.rnn_lm_exp_dir}/epoch-{params.rnn_lm_epoch}.pt", + rnn_lm_model, + ) + rnn_lm_model.to(device) + else: + rnn_lm_model = load_averaged_model( + params.rnn_lm_exp_dir, + rnn_lm_model, + params.rnn_lm_epoch, + params.rnn_lm_avg, + device, + ) + rnn_lm_model.eval() + librispeech = LibriSpeechAsrDataModule(args) test_clean_cuts = librispeech.test_clean_cuts() @@ -678,6 +795,7 @@ def main(): dl=test_dl, params=params, model=model, + rnn_lm_model=rnn_lm_model, HLG=HLG, H=H, bpe_model=bpe_model, diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 94d23afed..030122aa7 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -23,6 +23,7 @@ This file downloads the following LibriSpeech LM files: - 4-gram.arpa.gz - librispeech-vocab.txt - librispeech-lexicon.txt + - librispeech-lm-norm.txt.gz from http://www.openslr.org/resources/11 and save them in the user provided directory. @@ -61,6 +62,7 @@ def main(out_dir: str): "4-gram.arpa.gz", "librispeech-vocab.txt", "librispeech-lexicon.txt", + "librispeech-lm-norm.txt.gz", ) for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"): diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py new file mode 100755 index 000000000..5070341f1 --- /dev/null +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Daniel Povey +# Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This script takes a `bpe.model` and a text file such as +./download/lm/librispeech-lm-norm.txt +and outputs the LM training data to a supplied directory such +as data/lm_training_bpe_500. The format is as follows: + +It creates a PyTorch archive (.pt file), say data/lm_training.pt, which is a +representation of a dict with the following format: + + 'words' -> a k2.RaggedTensor of two axes [word][token] with dtype torch.int32 + containing the BPE representations of each word, indexed by + integer word ID. (These integer word IDS are present in + 'lm_data'). The sentencepiece object can be used to turn the + words and BPE units into string form. + 'sentences' -> a k2.RaggedTensor of two axes [sentence][word] with dtype + torch.int32 containing all the sentences, as word-ids (we don't + output the string form of this directly but it can be worked out + together with 'words' and the bpe.model). + 'sentence_lengths' -> a 1-D torch.Tensor of dtype torch.int32, containing + number of BPE tokens of each sentence. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import sentencepiece as spm +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--bpe-model", + type=str, + help="Input BPE model, e.g. data/bpe_500/bpe.model", + ) + parser.add_argument( + "--lm-data", + type=str, + help="""Input LM training data as text, e.g. + download/pb.train.txt""", + ) + parser.add_argument( + "--lm-archive", + type=str, + help="""Path to output archive, e.g. data/bpe_500/lm_data.pt; + look at the source of this script to see the format.""", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + + if Path(args.lm_archive).exists(): + logging.warning(f"{args.lm_archive} exists - skipping") + return + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + # word2index is a dictionary from words to integer ids. No need to reserve + # space for epsilon, etc.; the words are just used as a convenient way to + # compress the sequences of BPE pieces. + word2index = dict() + + word2bpe = [] # Will be a list-of-list-of-int, representing BPE pieces. + sentences = [] # Will be a list-of-list-of-int, representing word-ids. + + if "librispeech-lm-norm" in args.lm_data: + num_lines_in_total = 40418261.0 + step = 5000000 + elif "valid" in args.lm_data: + num_lines_in_total = 5567.0 + step = 3000 + elif "test" in args.lm_data: + num_lines_in_total = 5559.0 + step = 3000 + else: + num_lines_in_total = None + step = None + + processed = 0 + + with open(args.lm_data) as f: + while True: + line = f.readline() + if line == "": + break + + if step and processed % step == 0: + logging.info( + f"Processed number of lines: {processed} " + f"({processed/num_lines_in_total*100: .3f}%)" + ) + processed += 1 + + line_words = line.split() + for w in line_words: + if w not in word2index: + w_bpe = sp.encode(w) + word2index[w] = len(word2bpe) + word2bpe.append(w_bpe) + sentences.append([word2index[w] for w in line_words]) + + logging.info("Constructing ragged tensors") + words = k2.ragged.RaggedTensor(word2bpe) + sentences = k2.ragged.RaggedTensor(sentences) + + output = dict(words=words, sentences=sentences) + + num_sentences = sentences.dim0 + logging.info(f"Computing sentence lengths, num_sentences: {num_sentences}") + sentence_lengths = [0] * num_sentences + for i in range(num_sentences): + if step and i % step == 0: + logging.info( + f"Processed number of lines: {i} " + f"({i/num_sentences*100: .3f}%)" + ) + + word_ids = sentences[i] + + # NOTE: If word_ids is a tensor with only 1 entry, + # token_ids is a torch.Tensor + token_ids = words[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + + # token_ids is a 1-D tensor containing the BPE tokens + # of the current sentence + + sentence_lengths[i] = token_ids.numel() + + output["sentence_lengths"] = torch.tensor( + sentence_lengths, dtype=torch.int32 + ) + + torch.save(output, args.lm_archive) + logging.info(f"Saved to {args.lm_archive}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/local/sort_lm_training_data.py b/egs/librispeech/ASR/local/sort_lm_training_data.py new file mode 120000 index 000000000..bb26b5f5c --- /dev/null +++ b/egs/librispeech/ASR/local/sort_lm_training_data.py @@ -0,0 +1 @@ +../../../ptb/LM/local/sort_lm_training_data.py \ No newline at end of file diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index bc5812810..42aba9572 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -38,7 +38,6 @@ def get_args(): "--lang-dir", type=str, help="""Input and output directory. - It should contain the training corpus: transcript_words.txt. The generated bpe.model is saved to this directory. """, ) diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 17a638502..94e003036 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -24,6 +24,7 @@ stop_stage=100 # - 4-gram.arpa # - librispeech-vocab.txt # - librispeech-lexicon.txt +# - librispeech-lm-norm.txt.gz # # - $dl_dir/musan # This directory contains the following directories downloaded from @@ -40,9 +41,9 @@ dl_dir=$PWD/download # It will generate data/lang_bpe_xxx, # data/lang_bpe_yyy if the array contains xxx, yyy vocab_sizes=( - # 5000 - # 2000 - # 1000 + 5000 + 2000 + 1000 500 ) @@ -278,3 +279,99 @@ if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then ./local/compile_lg.py --lang-dir $lang_dir done fi + +if [ $stage -le 11 ] && [ $stop_stage -ge 11 ]; then + log "Stage 11: Generate LM training data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + lang_dir=data/lang_bpe_${vocab_size} + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $dl_dir/lm/librispeech-lm-norm.txt \ + --lm-archive $out_dir/lm_data.pt + done +fi + +if [ $stage -le 12 ] && [ $stop_stage -ge 12 ]; then + log "Stage 12: Generate LM validation data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/valid.txt ]; then + files=$( + find "$dl_dir/LibriSpeech/dev-clean" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/dev-other" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $out_dir/valid.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + done +fi + +if [ $stage -le 13 ] && [ $stop_stage -ge 13 ]; then + log "Stage 13: Generate LM test data" + + for vocab_size in ${vocab_sizes[@]}; do + log "Processing vocab_size == ${vocab_size}" + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + + if [ ! -f $out_dir/test.txt ]; then + files=$( + find "$dl_dir/LibriSpeech/test-clean" -name "*.trans.txt" + find "$dl_dir/LibriSpeech/test-other" -name "*.trans.txt" + ) + for f in ${files[@]}; do + cat $f | cut -d " " -f 2- + done > $out_dir/test.txt + fi + + lang_dir=data/lang_bpe_${vocab_size} + ./local/prepare_lm_training_data.py \ + --bpe-model $lang_dir/bpe.model \ + --lm-data $out_dir/test.txt \ + --lm-archive $out_dir/lm_data-test.pt + done +fi + +if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then + log "Stage 14: Sort LM training data" + # Sort LM training data by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/lm_training_bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt + done +fi diff --git a/egs/ptb/LM/README.md b/egs/ptb/LM/README.md new file mode 100644 index 000000000..7629a950d --- /dev/null +++ b/egs/ptb/LM/README.md @@ -0,0 +1,18 @@ +## Description + +(Note: the experiments here are only about language modeling) + +ptb is short for Penn Treebank. + + +About the Penn Treebank corpus: + - This corpus is free for research purposes + - ptb.train.txt: train set + - ptb.valid.txt: development set (should be used just for tuning hyper-parameters, but not for training) + - ptb.test.txt: test set for reporting perplexity + +You can download the dataset from one of the following URLs: + +- https://github.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage +- http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz +- https://deepai.org/dataset/penn-treebank diff --git a/egs/ptb/LM/local/prepare_lm_training_data.py b/egs/ptb/LM/local/prepare_lm_training_data.py new file mode 120000 index 000000000..eebce957e --- /dev/null +++ b/egs/ptb/LM/local/prepare_lm_training_data.py @@ -0,0 +1 @@ +../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py new file mode 100755 index 000000000..af54dbd07 --- /dev/null +++ b/egs/ptb/LM/local/sort_lm_training_data.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file takes as input the filename of LM training data +generated by ./local/prepare_lm_training_data.py and sorts +it by sentence length. + +Sentence length equals to the number of BPE tokens in a sentence. +""" + +import argparse +import logging +from pathlib import Path + +import k2 +import numpy as np +import torch + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--in-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/lm_data.pt", + ) + + parser.add_argument( + "--out-lm-data", + type=str, + help="Input LM training data, e.g., data/bpe_500/sorted_lm_data.pt", + ) + + parser.add_argument( + "--out-statistics", + type=str, + help="Statistics about LM training data., data/bpe_500/statistics.txt", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + in_lm_data = Path(args.in_lm_data) + out_lm_data = Path(args.out_lm_data) + assert in_lm_data.is_file(), f"{in_lm_data}" + if out_lm_data.is_file(): + logging.warning(f"{out_lm_data} exists - skipping") + return + data = torch.load(in_lm_data) + words2bpe = data["words"] + sentences = data["sentences"] + sentence_lengths = data["sentence_lengths"] + + num_sentences = sentences.dim0 + assert num_sentences == sentence_lengths.numel(), ( + num_sentences, + sentence_lengths.numel(), + ) + + indices = torch.argsort(sentence_lengths, descending=True) + + sorted_sentences = sentences[indices.to(torch.int32)] + sorted_sentence_lengths = sentence_lengths[indices] + + # Check that sentences are ordered by length + assert num_sentences == sorted_sentences.dim0, ( + num_sentences, + sorted_sentences.dim0, + ) + + cur = None + for i in range(num_sentences): + word_ids = sorted_sentences[i] + token_ids = words2bpe[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + if cur is not None: + assert cur >= token_ids.numel(), (cur, token_ids.numel()) + + cur = token_ids.numel() + assert cur == sorted_sentence_lengths[i] + + data["sentences"] = sorted_sentences + data["sentence_lengths"] = sorted_sentence_lengths + torch.save(data, args.out_lm_data) + logging.info(f"Saved to {args.out_lm_data}") + + statistics = Path(args.out_statistics) + + # Write statistics + num_words = sorted_sentences.numel() + num_tokens = sentence_lengths.sum().item() + max_sentence_length = sentence_lengths[indices[0]] + min_sentence_length = sentence_lengths[indices[-1]] + + step = 10 + hist, bins = np.histogram( + sentence_lengths.numpy(), + bins=np.arange(1, max_sentence_length + step, step), + ) + + histogram = np.stack((bins[:-1], hist)).transpose() + + with open(statistics, "w") as f: + f.write(f"num_sentences: {num_sentences}\n") + f.write(f"num_words: {num_words}\n") + f.write(f"num_tokens: {num_tokens}\n") + f.write(f"max_sentence_length: {max_sentence_length}\n") + f.write(f"min_sentence_length: {min_sentence_length}\n") + f.write("histogram:\n") + f.write(" bin count percent\n") + for row in histogram: + f.write( + f"{int(row[0]):>5} {int(row[1]):>5} " + f"{100.*row[1]/num_sentences:.3f}%\n" + ) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py new file mode 100755 index 000000000..877720e7b --- /dev/null +++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +import sentencepiece as spm +import torch + + +def main(): + lm_training_data = Path("./data/bpe_500/lm_data.pt") + bpe_model = Path("./data/bpe_500/bpe.model") + if not lm_training_data.exists(): + logging.warning(f"{lm_training_data} does not exist - skipping") + return + + if not bpe_model.exists(): + logging.warning(f"{bpe_model} does not exist - skipping") + return + + sp = spm.SentencePieceProcessor() + sp.load(str(bpe_model)) + + data = torch.load(lm_training_data) + words2bpe = data["words"] + sentences = data["sentences"] + + ss = [] + unk = sp.decode(sp.unk_id()).strip() + for i in range(10): + s = sp.decode(words2bpe[sentences[i]].values.tolist()) + s = s.replace(unk, "") + ss.append(s) + + for s in ss: + print(s) + # You can compare the output with the first 10 lines of ptb.train.txt + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/ptb/LM/local/train_bpe_model.py b/egs/ptb/LM/local/train_bpe_model.py new file mode 120000 index 000000000..6f018a0e2 --- /dev/null +++ b/egs/ptb/LM/local/train_bpe_model.py @@ -0,0 +1 @@ +../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file diff --git a/egs/ptb/LM/prepare.sh b/egs/ptb/LM/prepare.sh new file mode 100755 index 000000000..70586785d --- /dev/null +++ b/egs/ptb/LM/prepare.sh @@ -0,0 +1,115 @@ +#!/usr/bin/env bash + +set -eou pipefail + +nj=15 +stage=-1 +stop_stage=100 + +dl_dir=$PWD/download +# The following files will be downloaded to $dl_dir +# - ptb.train.txt +# - ptb.valid.txt +# - ptb.test.txt + +. shared/parse_options.sh || exit 1 + +# vocab size for sentence piece models. +# It will generate data/bpe_xxx, data/bpe_yyy +# if the array contains xxx, yyy +vocab_sizes=( + 500 + 1000 + 2000 + 5000 +) + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +mkdir -p $dl_dir + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +log "dl_dir: $dl_dir" + +if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then + log "Stage -1: Download data" + if [ ! -f $dl_dir/.complete ]; then + url=https://raw.githubusercontent.com/townie/PTB-dataset-from-Tomas-Mikolov-s-webpage/master/data/ + wget --no-verbose --directory-prefix $dl_dir $url/ptb.train.txt + wget --no-verbose --directory-prefix $dl_dir $url/ptb.valid.txt + wget --no-verbose --directory-prefix $dl_dir $url/ptb.test.txt + touch $dl_dir/.complete + fi +fi + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "Stage 0: Train BPE model" + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/bpe_${vocab_size} + mkdir -p $out_dir + ./local/train_bpe_model.py \ + --out-dir $out_dir \ + --vocab-size $vocab_size \ + --transcript $dl_dir/ptb.train.txt + done +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Generate LM training data" + # Note: ptb.train.txt has already been normalized + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/bpe_${vocab_size} + mkdir -p $out_dir + ./local/prepare_lm_training_data.py \ + --bpe-model $out_dir/bpe.model \ + --lm-data $dl_dir/ptb.train.txt \ + --lm-archive $out_dir/lm_data.pt + + ./local/prepare_lm_training_data.py \ + --bpe-model $out_dir/bpe.model \ + --lm-data $dl_dir/ptb.valid.txt \ + --lm-archive $out_dir/lm_data-valid.pt + + ./local/prepare_lm_training_data.py \ + --bpe-model $out_dir/bpe.model \ + --lm-data $dl_dir/ptb.test.txt \ + --lm-archive $out_dir/lm_data-test.pt + done +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Sort LM training data" + # Sort LM training data generated in stage 1 + # by sentence length in descending order + # for ease of training. + # + # Sentence length equals to the number of BPE tokens + # in a sentence. + + for vocab_size in ${vocab_sizes[@]}; do + out_dir=data/bpe_${vocab_size} + mkdir -p $out_dir + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data.pt \ + --out-lm-data $out_dir/sorted_lm_data.pt \ + --out-statistics $out_dir/statistics.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-valid.pt \ + --out-lm-data $out_dir/sorted_lm_data-valid.pt \ + --out-statistics $out_dir/statistics-valid.txt + + ./local/sort_lm_training_data.py \ + --in-lm-data $out_dir/lm_data-test.pt \ + --out-lm-data $out_dir/sorted_lm_data-test.pt \ + --out-statistics $out_dir/statistics-test.txt + done +fi diff --git a/egs/ptb/LM/shared b/egs/ptb/LM/shared new file mode 120000 index 000000000..4c5e91438 --- /dev/null +++ b/egs/ptb/LM/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/icefall/decode.py b/icefall/decode.py index 3ba899b4e..680e29619 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -20,7 +20,34 @@ from typing import Dict, List, Optional, Union import k2 import torch -from icefall.utils import get_texts +from icefall.utils import add_eos, add_sos, get_texts + +DEFAULT_LM_SCALE = [ + 0.01, + 0.05, + 0.08, + 0.1, + 0.3, + 0.5, + 0.6, + 0.7, + 0.9, + 1.0, + 1.1, + 1.2, + 1.3, + 1.5, + 1.7, + 1.9, + 2.0, + 2.1, + 2.2, + 2.3, + 2.5, + 3.0, + 4.0, + 5.0, +] def _intersect_device( @@ -952,3 +979,161 @@ def rescore_with_attention_decoder( key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}" ans[key] = best_path return ans + + +def rescore_with_rnn_lm( + lattice: k2.Fsa, + num_paths: int, + rnn_lm_model: torch.nn.Module, + model: torch.nn.Module, + memory: torch.Tensor, + memory_key_padding_mask: Optional[torch.Tensor], + sos_id: int, + eos_id: int, + blank_id: int, + nbest_scale: float = 1.0, + ngram_lm_scale: Optional[float] = None, + attention_scale: Optional[float] = None, + rnn_lm_scale: Optional[float] = None, + use_double_scores: bool = True, +) -> Dict[str, k2.Fsa]: + """This function extracts `num_paths` paths from the given lattice and uses + an attention decoder to rescore them. The path with the highest score is + the decoding output. + + Args: + lattice: + An FsaVec with axes [utt][state][arc]. + num_paths: + Number of paths to extract from the given lattice for rescoring. + model: + A transformer model. See the class "Transformer" in + conformer_ctc/transformer.py for its interface. + memory: + The encoder memory of the given model. It is the output of + the last torch.nn.TransformerEncoder layer in the given model. + Its shape is `(T, N, C)`. + memory_key_padding_mask: + The padding mask for memory with shape `(N, T)`. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. + nbest_scale: + It's the scale applied to `lattice.scores`. A smaller value + leads to more unique paths at the risk of missing the correct path. + ngram_lm_scale: + Optional. It specifies the scale for n-gram LM scores. + attention_scale: + Optional. It specifies the scale for attention decoder scores. + rnn_lm_scale: + Optional. It specifies the scale for RNN LM scores. + Returns: + A dict of FsaVec, whose key contains a string + ngram_lm_scale_attention_scale and the value is the + best decoding path for each utterance in the lattice. + """ + nbest = Nbest.from_lattice( + lattice=lattice, + num_paths=num_paths, + use_double_scores=use_double_scores, + nbest_scale=nbest_scale, + ) + # nbest.fsa.scores are all 0s at this point + + nbest = nbest.intersect(lattice) + # Now nbest.fsa has its scores set. + # Also, nbest.fsa inherits the attributes from `lattice`. + assert hasattr(nbest.fsa, "lm_scores") + + am_scores = nbest.compute_am_scores() + ngram_lm_scores = nbest.compute_lm_scores() + + # The `tokens` attribute is set inside `compile_hlg.py` + assert hasattr(nbest.fsa, "tokens") + assert isinstance(nbest.fsa.tokens, torch.Tensor) + + path_to_utt_map = nbest.shape.row_ids(1).to(torch.long) + # the shape of memory is (T, N, C), so we use axis=1 here + expanded_memory = memory.index_select(1, path_to_utt_map) + + if memory_key_padding_mask is not None: + # The shape of memory_key_padding_mask is (N, T), so we + # use axis=0 here. + expanded_memory_key_padding_mask = memory_key_padding_mask.index_select( + 0, path_to_utt_map + ) + else: + expanded_memory_key_padding_mask = None + + # remove axis corresponding to states. + tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) + tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens) + tokens = tokens.remove_values_leq(0) + token_ids = tokens.tolist() + + if len(token_ids) == 0: + print("Warning: rescore_with_attention_decoder(): empty token-ids") + return None + + nll = model.decoder_nll( + memory=expanded_memory, + memory_key_padding_mask=expanded_memory_key_padding_mask, + token_ids=token_ids, + sos_id=sos_id, + eos_id=eos_id, + ) + assert nll.ndim == 2 + assert nll.shape[0] == len(token_ids) + + attention_scores = -nll.sum(dim=1) + + # Now for RNN LM + sos_tokens = add_sos(tokens, sos_id) + tokens_eos = add_eos(tokens, eos_id) + sos_tokens_row_splits = sos_tokens.shape.row_splits(1) + sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + + x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) + y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) + + x_tokens = x_tokens.to(torch.int64) + y_tokens = y_tokens.to(torch.int64) + sentence_lengths = sentence_lengths.to(torch.int64) + + rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) + assert rnn_lm_nll.ndim == 2 + assert rnn_lm_nll.shape[0] == len(token_ids) + + rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) + + ngram_lm_scale_list = DEFAULT_LM_SCALE + attention_scale_list = DEFAULT_LM_SCALE + rnn_lm_scale_list = DEFAULT_LM_SCALE + + if ngram_lm_scale: + ngram_lm_scale_list = [ngram_lm_scale] + + if attention_scale: + attention_scale_list = [attention_scale] + + if rnn_lm_scale: + rnn_lm_scale_list = [rnn_lm_scale] + + ans = dict() + for n_scale in ngram_lm_scale_list: + for a_scale in attention_scale_list: + for r_scale in rnn_lm_scale_list: + tot_scores = ( + am_scores.values + + n_scale * ngram_lm_scores.values + + a_scale * attention_scores + + r_scale * rnn_lm_scores + ) + ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) + max_indexes = ragged_tot_scores.argmax() + best_path = k2.index_fsa(nbest.fsa, max_indexes) + + key = f"ngram_lm_scale_{n_scale}_attention_scale_{a_scale}_rnn_lm_scale_{r_scale}" # noqa + ans[key] = best_path + return ans diff --git a/icefall/dist.py b/icefall/dist.py index 203c7c563..6334f9c13 100644 --- a/icefall/dist.py +++ b/icefall/dist.py @@ -21,14 +21,46 @@ import torch from torch import distributed as dist -def setup_dist(rank, world_size, master_port=None): - os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = ( - "12354" if master_port is None else str(master_port) - ) - dist.init_process_group("nccl", rank=rank, world_size=world_size) - torch.cuda.set_device(rank) +def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False): + """ + rank and world_size are used only if use_ddp_launch is False. + """ + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = "localhost" + + if "MASTER_PORT" not in os.environ: + os.environ["MASTER_PORT"] = ( + "12354" if master_port is None else str(master_port) + ) + + if use_ddp_launch is False: + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + else: + dist.init_process_group("nccl") def cleanup_dist(): dist.destroy_process_group() + + +def get_world_size(): + if "WORLD_SIZE" in os.environ: + return int(os.environ["WORLD_SIZE"]) + if dist.is_available() and dist.is_initialized(): + return dist.get_world_size() + else: + return 1 + + +def get_rank(): + if "RANK" in os.environ: + return int(os.environ["RANK"]) + elif dist.is_available() and dist.is_initialized(): + return dist.rank() + else: + return 1 + + +def get_local_rank(): + return int(os.environ.get("LOCAL_RANK", 0)) diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py new file mode 100755 index 000000000..550801a8f --- /dev/null +++ b/icefall/rnn_lm/compute_perplexity.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + ./rnn_lm/compute_perplexity.py \ + --epoch 4 \ + --avg 2 \ + --lm-data ./data/bpe_500/sorted_lm_data-test.pt + +""" + +import argparse +import logging +import math +from pathlib import Path + +import torch +from dataset import get_dataloader +from model import RnnLmModel + +from icefall.checkpoint import average_checkpoints, load_checkpoint +from icefall.utils import AttributeDict, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=49, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + parser.add_argument( + "--avg", + type=int, + default=20, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--lm-data", + type=str, + help="Path to the LM test data for computing perplexity", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=3, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--batch-size", + type=int, + default=50, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--max-sent-len", + type=int, + default=100, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--sos-id", + type=int, + default=1, + help="SOS ID", + ) + + parser.add_argument( + "--eos-id", + type=int, + default=1, + help="EOS ID", + ) + + parser.add_argument( + "--blank-id", + type=int, + default=0, + help="Blank ID", + ) + return parser + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + args.lm_data = Path(args.lm_data) + + params = AttributeDict(vars(args)) + + setup_logger(f"{params.exp_dir}/log-ppl/") + logging.info("Computing perplexity started") + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + logging.info("About to create model") + model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + hidden_dim=params.hidden_dim, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + ) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + model.to(device) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if start >= 0: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + model.eval() + num_param = sum([p.numel() for p in model.parameters()]) + num_param_requires_grad = sum( + [p.numel() for p in model.parameters() if p.requires_grad] + ) + + logging.info(f"Number of model parameters: {num_param}") + logging.info( + f"Number of model parameters (requires_grad): " + f"{num_param_requires_grad} " + f"({num_param_requires_grad/num_param_requires_grad*100}%)" + ) + + logging.info(f"Loading LM test data from {params.lm_data}") + test_dl = get_dataloader( + filename=params.lm_data, + is_distributed=False, + params=params, + ) + + tot_loss = 0.0 + num_tokens = 0 + num_sentences = 0 + for batch_idx, batch in enumerate(test_dl): + x, y, sentence_lengths = batch + x = x.to(device) + y = y.to(device) + sentence_lengths = sentence_lengths.to(device) + + nll = model(x, y, sentence_lengths) + loss = nll.sum().cpu().item() + + tot_loss += loss + num_tokens += sentence_lengths.sum().cpu().item() + num_sentences += x.size(0) + + ppl = math.exp(tot_loss / num_tokens) + logging.info( + f"total nll: {tot_loss}, num tokens: {num_tokens}, " + f"num sentences: {num_sentences}, ppl: {ppl:.3f}" + ) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + + +if __name__ == "__main__": + main() diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py new file mode 100644 index 000000000..598e329c4 --- /dev/null +++ b/icefall/rnn_lm/dataset.py @@ -0,0 +1,218 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Tuple + +import k2 +import torch +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from icefall.utils import AttributeDict, add_eos, add_sos + + +class LmDataset(torch.utils.data.Dataset): + def __init__( + self, + sentences: k2.RaggedTensor, + words: k2.RaggedTensor, + sentence_lengths: torch.Tensor, + max_sent_len: int, + batch_size: int, + ): + """ + Args: + sentences: + A ragged tensor of dtype torch.int32 with 2 axes [sentence][word]. + words: + A ragged tensor of dtype torch.int32 with 2 axes [word][token]. + sentence_lengths: + A 1-D tensor of dtype torch.int32 containing number of tokens + of each sentence. + max_sent_len: + Maximum sentence length. It is used to change the batch size + dynamically. In general, we try to keep the product of + "max_sent_len in a batch" and "num_of_sent in a batch" being + a constant. + batch_size: + The expected batch size. It is changed dynamically according + to the "max_sent_len". + + See `../local/prepare_lm_training_data.py` for how `sentences` and + `words` are generated. We assume that `sentences` are sorted by length. + See `../local/sort_lm_training_data.py`. + """ + super().__init__() + self.sentences = sentences + self.words = words + + sentence_lengths = sentence_lengths.tolist() + + assert batch_size > 0, batch_size + assert max_sent_len > 1, max_sent_len + batch_indexes = [] + num_sentences = sentences.dim0 + cur = 0 + while cur < num_sentences: + sz = sentence_lengths[cur] // max_sent_len + 1 + # Assume the current sentence has 3 * max_sent_len tokens, + # in the worst case, the subsequent sentences also have + # this number of tokens, we should reduce the batch size + # so that this batch will not contain too many tokens + actual_batch_size = batch_size // sz + 1 + actual_batch_size = min(actual_batch_size, batch_size) + end = cur + actual_batch_size + end = min(end, num_sentences) + this_batch_indexes = torch.arange(cur, end).tolist() + batch_indexes.append(this_batch_indexes) + cur = end + assert batch_indexes[-1][-1] == num_sentences - 1 + + self.batch_indexes = k2.RaggedTensor(batch_indexes) + + def __len__(self) -> int: + """Return number of batches in this dataset""" + return self.batch_indexes.dim0 + + def __getitem__(self, i: int) -> k2.RaggedTensor: + """Get the i'th batch in this dataset + Return a ragged tensor with 2 axes [sentence][token]. + """ + assert 0 <= i < len(self), i + + # indexes is a 1-D tensor containing sentence indexes + indexes = self.batch_indexes[i] + + # sentence_words is a ragged tensor with 2 axes + # [sentence][word] + sentence_words = self.sentences[indexes] + + # in case indexes contains only 1 entry, the returned + # sentence_words is a 1-D tensor, we have to convert + # it to a ragged tensor + if isinstance(sentence_words, torch.Tensor): + sentence_words = k2.RaggedTensor(sentence_words.unsqueeze(0)) + + # sentence_word_tokens is a ragged tensor with 3 axes + # [sentence][word][token] + sentence_word_tokens = self.words.index(sentence_words) + assert sentence_word_tokens.num_axes == 3 + + sentence_tokens = sentence_word_tokens.remove_axis(1) + return sentence_tokens + + +class LmDatasetCollate: + def __init__(self, sos_id: int, eos_id: int, blank_id: int): + """ + Args: + sos_id: + Token ID of the SOS symbol. + eos_id: + Token ID of the EOS symbol. + blank_id: + Token ID of the blank symbol. + """ + self.sos_id = sos_id + self.eos_id = eos_id + self.blank_id = blank_id + + def __call__( + self, batch: List[k2.RaggedTensor] + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Return a tuple containing 3 tensors: + + - x, a 2-D tensor of dtype torch.int32; each row contains tokens + for a sentence starting with `self.sos_id`. It is padded to + the max sentence length with `self.blank_id`. + + - y, a 2-D tensor of dtype torch.int32; each row contains tokens + for a sentence ending with `self.eos_id` before padding. + Then it is padded to the max sentence length with + `self.blank_id`. + + - lengths, a 2-D tensor of dtype torch.int32, containing the number of + tokens of each sentence before padding. + """ + # The batching stuff has already been done in LmDataset + assert len(batch) == 1 + sentence_tokens = batch[0] + row_splits = sentence_tokens.shape.row_splits(1) + sentence_token_lengths = row_splits[1:] - row_splits[:-1] + sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id) + sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id) + + x = sentence_tokens_with_sos.pad( + mode="constant", padding_value=self.blank_id + ) + y = sentence_tokens_with_eos.pad( + mode="constant", padding_value=self.blank_id + ) + sentence_token_lengths += 1 # plus 1 since we added a SOS + + return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths + + +def get_dataloader( + filename: str, + is_distributed: bool, + params: AttributeDict, +) -> torch.utils.data.DataLoader: + """Get dataloader for LM training. + + Args: + filename: + Path to the file containing LM data. The file is assumed to + be generated by `../local/sort_lm_training_data.py`. + is_distributed: + True if using DDP training. False otherwise. + params: + Set `get_params()` from `rnn_lm/train.py` + Returns: + Return a dataloader containing the LM data. + """ + lm_data = torch.load(filename) + + words = lm_data["words"] + sentences = lm_data["sentences"] + sentence_lengths = lm_data["sentence_lengths"] + + dataset = LmDataset( + sentences=sentences, + words=words, + sentence_lengths=sentence_lengths, + max_sent_len=params.max_sent_len, + batch_size=params.batch_size, + ) + if is_distributed: + sampler = DistributedSampler(dataset, shuffle=True, drop_last=False) + else: + sampler = None + + collate_fn = LmDatasetCollate( + sos_id=params.sos_id, + eos_id=params.eos_id, + blank_id=params.blank_id, + ) + + dataloader = DataLoader( + dataset, + batch_size=1, + collate_fn=collate_fn, + sampler=sampler, + shuffle=sampler is None, + ) + return dataloader diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py new file mode 100644 index 000000000..094035fce --- /dev/null +++ b/icefall/rnn_lm/export.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python3 +# +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. + +import argparse +import logging +from pathlib import Path + +import torch +from model import RnnLmModel + +from icefall.checkpoint import load_checkpoint +from icefall.utils import AttributeDict, load_averaged_model, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=29, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--avg", + type=int, + default=5, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=3, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=True, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--jit", + type=str2bool, + default=True, + help="""True to save a model after applying torch.jit.script. + """, + ) + + return parser + + +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = AttributeDict({}) + params.update(vars(args)) + + logging.info(params) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + hidden_dim=params.hidden_dim, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + ) + + model.to(device) + + if params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + model = load_averaged_model( + params.exp_dir, model, params.epoch, params.avg, device + ) + + model.to("cpu") + model.eval() + + if params.jit: + logging.info("Using torch.jit.script") + model = torch.jit.script(model) + filename = params.exp_dir / "cpu_jit.pt" + model.save(str(filename)) + logging.info(f"Saved to {filename}") + else: + logging.info("Not using torch.jit.script") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py new file mode 100644 index 000000000..88b2cc41f --- /dev/null +++ b/icefall/rnn_lm/model.py @@ -0,0 +1,120 @@ +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +import torch +import torch.nn.functional as F + +from icefall.utils import make_pad_mask + + +class RnnLmModel(torch.nn.Module): + def __init__( + self, + vocab_size: int, + embedding_dim: int, + hidden_dim: int, + num_layers: int, + tie_weights: bool = False, + ): + """ + Args: + vocab_size: + Vocabulary size of BPE model. + embedding_dim: + Input embedding dimension. + hidden_dim: + Hidden dimension of RNN layers. + num_layers: + Number of RNN layers. + tie_weights: + True to share the weights between the input embedding layer and the + last output linear layer. See https://arxiv.org/abs/1608.05859 + and https://arxiv.org/abs/1611.01462 + """ + super().__init__() + + self.input_embedding = torch.nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=embedding_dim, + ) + + self.rnn = torch.nn.LSTM( + input_size=embedding_dim, + hidden_size=hidden_dim, + num_layers=num_layers, + batch_first=True, + ) + + self.output_linear = torch.nn.Linear( + in_features=hidden_dim, out_features=vocab_size + ) + + self.vocab_size = vocab_size + if tie_weights: + logging.info("Tying weights") + assert embedding_dim == hidden_dim, (embedding_dim, hidden_dim) + self.output_linear.weight = self.input_embedding.weight + else: + logging.info("Not tying weights") + + def forward( + self, x: torch.Tensor, y: torch.Tensor, lengths: torch.Tensor + ) -> torch.Tensor: + """ + Args: + x: + A 2-D tensor with shape (N, L). Each row + contains token IDs for a sentence and starts with the SOS token. + y: + A shifted version of `x` and with EOS appended. + lengths: + A 1-D tensor of shape (N,). It contains the sentence lengths + before padding. + Returns: + Return a 2-D tensor of shape (N, L) containing negative log-likelihood + loss values. Note: Loss values for padding positions are set to 0. + """ + assert x.ndim == y.ndim == 2, (x.ndim, y.ndim) + assert lengths.ndim == 1, lengths.ndim + assert x.shape == y.shape, (x.shape, y.shape) + + batch_size = x.size(0) + assert lengths.size(0) == batch_size, (lengths.size(0), batch_size) + + # embedding is of shape (N, L, embedding_dim) + embedding = self.input_embedding(x) + + # Note: We use batch_first==True + rnn_out, _ = self.rnn(embedding) + logits = self.output_linear(rnn_out) + + # Note: No need to use `log_softmax()` here + # since F.cross_entropy() expects unnormalized probabilities + + # nll_loss is of shape (N*L,) + # nll -> negative log-likelihood + nll_loss = F.cross_entropy( + logits.reshape(-1, self.vocab_size), y.reshape(-1), reduction="none" + ) + # Set loss values for padding positions to 0 + mask = make_pad_mask(lengths).reshape(-1) + nll_loss.masked_fill_(mask, 0) + + nll_loss = nll_loss.reshape(batch_size, -1) + + return nll_loss diff --git a/icefall/rnn_lm/test_dataset.py b/icefall/rnn_lm/test_dataset.py new file mode 100755 index 000000000..bf961f54b --- /dev/null +++ b/icefall/rnn_lm/test_dataset.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import k2 +import torch +from rnn_lm.dataset import LmDataset, LmDatasetCollate + + +def main(): + sentences = k2.RaggedTensor( + [[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]] + ) + words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]]) + + num_sentences = sentences.dim0 + + sentence_lengths = [0] * num_sentences + for i in range(num_sentences): + word_ids = sentences[i] + + # NOTE: If word_ids is a tensor with only 1 entry, + # token_ids is a torch.Tensor + token_ids = words[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + + # token_ids is a 1-D tensor containing the BPE tokens + # of the current sentence + + sentence_lengths[i] = token_ids.numel() + + sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32) + + indices = torch.argsort(sentence_lengths, descending=True) + sentences = sentences[indices.to(torch.int32)] + sentence_lengths = sentence_lengths[indices] + + dataset = LmDataset( + sentences=sentences, + words=words, + sentence_lengths=sentence_lengths, + max_sent_len=3, + batch_size=4, + ) + + collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0) + dataloader = torch.utils.data.DataLoader( + dataset, batch_size=1, collate_fn=collate_fn + ) + + for i in dataloader: + print(i) + # I've checked the output manually; the output is as expected. + + +if __name__ == "__main__": + main() diff --git a/icefall/rnn_lm/test_dataset_ddp.py b/icefall/rnn_lm/test_dataset_ddp.py new file mode 100755 index 000000000..48fbb19cb --- /dev/null +++ b/icefall/rnn_lm/test_dataset_ddp.py @@ -0,0 +1,103 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import k2 +import torch +import torch.multiprocessing as mp +from rnn_lm.dataset import LmDataset, LmDatasetCollate +from torch import distributed as dist + + +def generate_data(): + sentences = k2.RaggedTensor( + [[0, 1, 2], [1, 0, 1], [0, 1], [1, 3, 0, 2, 0], [3], [0, 2, 1]] + ) + words = k2.RaggedTensor([[3, 6], [2, 8, 9, 3], [5], [5, 6, 7, 8, 9]]) + + num_sentences = sentences.dim0 + + sentence_lengths = [0] * num_sentences + for i in range(num_sentences): + word_ids = sentences[i] + + # NOTE: If word_ids is a tensor with only 1 entry, + # token_ids is a torch.Tensor + token_ids = words[word_ids] + if isinstance(token_ids, k2.RaggedTensor): + token_ids = token_ids.values + + # token_ids is a 1-D tensor containing the BPE tokens + # of the current sentence + + sentence_lengths[i] = token_ids.numel() + + sentence_lengths = torch.tensor(sentence_lengths, dtype=torch.int32) + + indices = torch.argsort(sentence_lengths, descending=True) + sentences = sentences[indices.to(torch.int32)] + sentence_lengths = sentence_lengths[indices] + + return sentences, words, sentence_lengths + + +def run(rank, world_size): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12352" + + dist.init_process_group("nccl", rank=rank, world_size=world_size) + torch.cuda.set_device(rank) + + sentences, words, sentence_lengths = generate_data() + + dataset = LmDataset( + sentences=sentences, + words=words, + sentence_lengths=sentence_lengths, + max_sent_len=3, + batch_size=4, + ) + sampler = torch.utils.data.distributed.DistributedSampler( + dataset, shuffle=True, drop_last=False + ) + + collate_fn = LmDatasetCollate(sos_id=1, eos_id=-1, blank_id=0) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=1, + collate_fn=collate_fn, + sampler=sampler, + shuffle=False, + ) + + for i in dataloader: + print(f"rank: {rank}", i) + + dist.destroy_process_group() + + +def main(): + world_size = 2 + mp.spawn(run, args=(world_size,), nprocs=world_size, join=True) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/icefall/rnn_lm/test_model.py b/icefall/rnn_lm/test_model.py new file mode 100755 index 000000000..5a216a3fb --- /dev/null +++ b/icefall/rnn_lm/test_model.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +# Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from rnn_lm.model import RnnLmModel + + +def test_rnn_lm_model(): + vocab_size = 4 + model = RnnLmModel( + vocab_size=vocab_size, embedding_dim=10, hidden_dim=10, num_layers=2 + ) + x = torch.tensor( + [ + [1, 3, 2, 2], + [1, 2, 2, 0], + [1, 2, 0, 0], + ] + ) + y = torch.tensor( + [ + [3, 2, 2, 1], + [2, 2, 1, 0], + [2, 1, 0, 0], + ] + ) + lengths = torch.tensor([4, 3, 2]) + nll_loss = model(x, y, lengths) + print(nll_loss) + """ + tensor([[1.1180, 1.3059, 1.2426, 1.7773], + [1.4231, 1.2783, 1.7321, 0.0000], + [1.4231, 1.6752, 0.0000, 0.0000]], grad_fn=) + """ + + +def test_rnn_lm_model_tie_weights(): + model = RnnLmModel( + vocab_size=10, + embedding_dim=10, + hidden_dim=10, + num_layers=2, + tie_weights=True, + ) + assert model.input_embedding.weight is model.output_linear.weight + + +def main(): + test_rnn_lm_model() + test_rnn_lm_model_tie_weights() + + +if __name__ == "__main__": + torch.manual_seed(20211122) + main() diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py new file mode 100755 index 000000000..bb5f03fb9 --- /dev/null +++ b/icefall/rnn_lm/train.py @@ -0,0 +1,617 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Usage: + ./rnn_lm/train.py \ + --start-epoch 0 \ + --world-size 2 \ + --num-epochs 1 \ + --use-fp16 0 \ + --embedding-dim 800 \ + --hidden-dim 200 \ + --num-layers 2\ + --batch-size 400 + +""" + +import argparse +import logging +import math +from pathlib import Path +from shutil import copyfile +from typing import Optional, Tuple + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +import torch.optim as optim +from dataset import get_dataloader +from lhotse.utils import fix_random_seed +from model import RnnLmModel +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.nn.utils import clip_grad_norm_ +from torch.utils.tensorboard import SummaryWriter + +from icefall.checkpoint import load_checkpoint +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=10, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=0, + help="""Resume training from from this epoch. + If it is positive, it will load checkpoint from + exp_dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="rnn_lm/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, logs, etc, are saved + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--batch-size", + type=int, + default=50, + ) + + parser.add_argument( + "--lm-data", + type=str, + default="data/lm_training_bpe_500/sorted_lm_data.pt", + help="LM training data", + ) + + parser.add_argument( + "--lm-data-valid", + type=str, + default="data/lm_training_bpe_500/sorted_lm_data-valid.pt", + help="LM validation data", + ) + + parser.add_argument( + "--vocab-size", + type=int, + default=500, + help="Vocabulary size of the model", + ) + + parser.add_argument( + "--embedding-dim", + type=int, + default=2048, + help="Embedding dim of the model", + ) + + parser.add_argument( + "--hidden-dim", + type=int, + default=2048, + help="Hidden dim of the model", + ) + + parser.add_argument( + "--num-layers", + type=int, + default=3, + help="Number of RNN layers the model", + ) + + parser.add_argument( + "--tie-weights", + type=str2bool, + default=False, + help="""True to share the weights between the input embedding layer and the + last output linear layer + """, + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters.""" + + params = AttributeDict( + { + "max_sent_len": 200, + "sos_id": 1, + "eos_id": 1, + "blank_id": 0, + "lr": 1e-3, + "weight_decay": 1e-6, + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 200, + "reset_interval": 2000, + "valid_interval": 5000, + "env_info": get_env_info(), + } + ) + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, +) -> None: + """Load checkpoint from file. + + If params.start_epoch is positive, it will load the checkpoint from + `params.start_epoch - 1`. Otherwise, this function does nothing. + + Apart from loading state dict for `model`, `optimizer` and `scheduler`, + it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + optimizer: + The optimizer that we are using. + scheduler: + The learning rate scheduler we are using. + Returns: + Return None. + """ + if params.start_epoch <= 0: + return + + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + logging.info(f"Loading checkpoint: {filename}") + saved_params = load_checkpoint( + filename, + model=model, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: nn.Module, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + params=params, + optimizer=optimizer, + scheduler=scheduler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + model: nn.Module, + x: torch.Tensor, + y: torch.Tensor, + sentence_lengths: torch.Tensor, + is_training: bool, +) -> Tuple[torch.Tensor, MetricsTracker]: + """Compute the negative log-likelihood loss given a model and its input. + Args: + model: + The NN model, e.g., RnnLmModel. + x: + A 2-D tensor. Each row contains BPE token IDs for a sentence. Also, + each row starts with SOS ID. + y: + A 2-D tensor. Each row is a shifted version of the corresponding row + in `x` but ends with an EOS ID (before padding). + sentence_lengths: + A 1-D tensor containing number of tokens of each sentence + before padding. + is_training: + True for training. False for validation. + """ + with torch.set_grad_enabled(is_training): + device = model.device + x = x.to(device) + y = y.to(device) + sentence_lengths = sentence_lengths.to(device) + + nll = model(x, y, sentence_lengths) + loss = nll.sum() + + num_tokens = sentence_lengths.sum().item() + + loss_info = MetricsTracker() + # Note: Due to how MetricsTracker() is designed, + # we use "frames" instead of "num_tokens" as a key here + loss_info["frames"] = num_tokens + loss_info["loss"] = loss.detach().item() + return loss, loss_info + + +def compute_validation_loss( + params: AttributeDict, + model: nn.Module, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process. The validation loss + is saved in `params.valid_loss`. + """ + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + x, y, sentence_lengths = batch + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + model=model, + x=x, + y=y, + sentence_lengths=sentence_lengths, + is_training=False, + ) + + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: nn.Module, + optimizer: torch.optim.Optimizer, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all sentences is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + x, y, sentence_lengths = batch + batch_size = x.size(0) + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + model=model, + x=x, + y=y, + sentence_lengths=sentence_lengths, + is_training=True, + ) + + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + optimizer.zero_grad() + loss.backward() + clip_grad_norm_(model.parameters(), 5.0, 2.0) + optimizer.step() + + if batch_idx % params.log_interval == 0: + # Note: "frames" here means "num_tokens" + this_batch_ppl = math.exp(loss_info["loss"] / loss_info["frames"]) + tot_ppl = math.exp(tot_loss["loss"] / tot_loss["frames"]) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}, ppl: {this_batch_ppl}] " + f"tot_loss[{tot_loss}, ppl: {tot_ppl}], " + f"batch size: {batch_size}" + ) + + if tb_writer is not None: + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + tb_writer.add_scalar( + "train/current_ppl", this_batch_ppl, params.batch_idx_train + ) + + tb_writer.add_scalar( + "train/tot_ppl", tot_ppl, params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + + valid_info = compute_validation_loss( + params=params, + model=model, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + + valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"]) + logging.info( + f"Epoch {params.cur_epoch}, validation: {valid_info}, " + f"ppl: {valid_ppl}" + ) + + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + tb_writer.add_scalar( + "train/valid_ppl", valid_ppl, params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + is_distributed = world_size > 1 + + fix_random_seed(params.seed) + if is_distributed: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + logging.info(params) + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + + logging.info(f"Device: {device}") + + logging.info("About to create model") + model = RnnLmModel( + vocab_size=params.vocab_size, + embedding_dim=params.embedding_dim, + hidden_dim=params.hidden_dim, + num_layers=params.num_layers, + tie_weights=params.tie_weights, + ) + + checkpoints = load_checkpoint_if_available(params=params, model=model) + + model.to(device) + if is_distributed: + model = DDP(model, device_ids=[rank]) + + model.device = device + + optimizer = optim.Adam( + model.parameters(), + lr=params.lr, + weight_decay=params.weight_decay, + ) + if checkpoints: + logging.info("Load optimizer state_dict from checkpoint") + optimizer.load_state_dict(checkpoints["optimizer"]) + + logging.info(f"Loading LM training data from {params.lm_data}") + train_dl = get_dataloader( + filename=params.lm_data, + is_distributed=is_distributed, + params=params, + ) + + logging.info(f"Loading LM validation data from {params.lm_data_valid}") + valid_dl = get_dataloader( + filename=params.lm_data_valid, + is_distributed=is_distributed, + params=params, + ) + + # Note: No learning rate scheduler is used here + for epoch in range(params.start_epoch, params.num_epochs): + if is_distributed: + train_dl.sampler.set_epoch(epoch) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + optimizer=optimizer, + train_dl=train_dl, + valid_dl=valid_dl, + tb_writer=tb_writer, + world_size=world_size, + ) + + save_checkpoint( + params=params, + model=model, + optimizer=optimizer, + rank=rank, + ) + + logging.info("Done!") + + if is_distributed: + torch.distributed.barrier() + cleanup_dist() + + +def main(): + parser = get_parser() + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/icefall/utils.py b/icefall/utils.py index b38574f0c..10a2e6301 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -35,6 +35,8 @@ import torch.distributed as dist import torch.nn as nn from torch.utils.tensorboard import SummaryWriter +from icefall.checkpoint import average_checkpoints + Pathlike = Union[str, Path] @@ -90,7 +92,11 @@ def str2bool(v): def setup_logger( - log_filename: Pathlike, log_level: str = "info", use_console: bool = True + log_filename: Pathlike, + log_level: str = "info", + rank: int = 0, + world_size: int = 1, + use_console: bool = True, ) -> None: """Setup log level. @@ -100,12 +106,16 @@ def setup_logger( log_level: The log level to use, e.g., "debug", "info", "warning", "error", "critical" + rank: + Rank of this node in DDP training. + world_size: + Number of nodes in DDP training. + use_console: + True to also print logs to console. """ now = datetime.now() date_time = now.strftime("%Y-%m-%d-%H-%M-%S") - if dist.is_available() and dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() + if world_size > 1: formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa log_filename = f"{log_filename}-{date_time}-{rank}" else: @@ -799,3 +809,34 @@ def optim_step_and_measure_param_change( delta = l2_norm(p_orig - p_new) / l2_norm(p_orig) relative_change[n] = delta.item() return relative_change + + +def load_averaged_model( + model_dir: str, + model: torch.nn.Module, + epoch: int, + avg: int, + device: torch.device, +): + """ + Load a model which is the average of all checkpoints + + :param model_dir: a str of the experiment directory + :param model: a torch.nn.Module instance + + :param epoch: the last epoch to load from + :param avg: how many models to average from + :param device: move model to this device + + :return: A model averaged + """ + + # start cannot be negative + start = max(epoch - avg + 1, 0) + filenames = [f"{model_dir}/epoch-{i}.pt" for i in range(start, epoch + 1)] + + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + + return model From c391bfd1000e0b2d7ff906be529d4c5fd1bcb7b0 Mon Sep 17 00:00:00 2001 From: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Date: Fri, 24 Jun 2022 10:40:46 +0800 Subject: [PATCH 8/9] fix errors for soft connection (#443) --- egs/ptb/LM/local/prepare_lm_training_data.py | 2 +- egs/ptb/LM/local/train_bpe_model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/ptb/LM/local/prepare_lm_training_data.py b/egs/ptb/LM/local/prepare_lm_training_data.py index eebce957e..abc00d421 120000 --- a/egs/ptb/LM/local/prepare_lm_training_data.py +++ b/egs/ptb/LM/local/prepare_lm_training_data.py @@ -1 +1 @@ -../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file +../../../librispeech/ASR/local/prepare_lm_training_data.py \ No newline at end of file diff --git a/egs/ptb/LM/local/train_bpe_model.py b/egs/ptb/LM/local/train_bpe_model.py index 6f018a0e2..6fad36421 120000 --- a/egs/ptb/LM/local/train_bpe_model.py +++ b/egs/ptb/LM/local/train_bpe_model.py @@ -1 +1 @@ -../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file +../../../librispeech/ASR/local/train_bpe_model.py \ No newline at end of file From c0ea33473868d964bb1242207ac9011b6bde3d08 Mon Sep 17 00:00:00 2001 From: Tiance Wang Date: Fri, 24 Jun 2022 19:31:09 +0800 Subject: [PATCH 9/9] fix bug of concatenating list to tuple (#444) --- egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index bf3917df0..0fa4b6907 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -1195,7 +1195,7 @@ class RandomCombine(nn.Module): ans = torch.matmul(stacked_inputs, weights) # ans: (*, num_channels) - ans = ans.reshape(inputs[0].shape[:-1] + [num_channels]) + ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) # The following if causes errors for torch script in torch 1.6.0 # if __name__ == "__main__":