diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index 5088497a1..a2d32013a 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -2,6 +2,56 @@
### Aishell training result(Stateless Transducer)
+#### Pruned transducer stateless 7 streaming
+[./pruned_transducer_stateless7_streaming](./pruned_transducer_stateless7_streaming)
+
+It's a streaming version of Zipformer with pruned RNNT loss.
+
+| decoding method        | test | dev  | comment                                 |
+|------------------------|------|------|---------------------------------------|
+| greedy search | 6.95 | 6.29 | --epoch 44 --avg 15 --max-duration 600 |
+| modified beam search | 6.51 | 5.90 | --epoch 44 --avg 15 --max-duration 600 |
+| fast beam search | 6.73 | 6.09 | --epoch 44 --avg 15 --max-duration 600 |
+
+The training command is:
+
+```bash
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0,1"
+
+./pruned_transducer_stateless7_streaming/train.py \
+ --world-size 2 \
+ --num-epochs 50 \
+ --use-fp16 1 \
+ --context-size 1 \
+ --max-duration 800 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --enable-musan 0 \
+ --spec-aug-time-warp-factor 20
+```
+
+**Caution**: It uses `--context-size=1`.
+
+The decoding command is:
+```bash
+for m in greedy_search modified_beam_search fast_beam_search ; do
+ ./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 44 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --lang-dir data/lang_char \
+ --context-size 1 \
+ --decoding-method $m
+done
+```
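+
+You can also decode chunk by chunk in simulated streaming mode with
+`./pruned_transducer_stateless7_streaming/streaming_decode.py`. The command below
+is a sketch based on the usage notes in that script; keep `--context-size 1` to
+match the training command above, and adjust `--epoch`/`--avg` to your own
+checkpoints:
+
+```bash
+./pruned_transducer_stateless7_streaming/streaming_decode.py \
+  --epoch 44 \
+  --avg 15 \
+  --decode-chunk-len 32 \
+  --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+  --context-size 1 \
+  --decoding-method greedy_search \
+  --num-decode-streams 2000
+```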
+
+Pretrained models, training logs, decoding logs, tensorboard and decoding results
+are available at
+
+
+
+
#### Pruned transducer stateless 7
[./pruned_transducer_stateless7](./pruned_transducer_stateless7)
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md
new file mode 120000
index 000000000..a784292cd
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/README.md
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/README.md
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py
new file mode 120000
index 000000000..a074d6085
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/asr_datamodule.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py
new file mode 120000
index 000000000..8554e44cc
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/beam_search.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/beam_search.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
new file mode 100755
index 000000000..f5ae836fd
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode.py
@@ -0,0 +1,735 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method greedy_search
+
+(2) beam search (not recommended)
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method beam_search \
+ --beam-size 4
+
+(3) modified beam search
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+
+(4) fast beam search (one best)
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method fast_beam_search \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+
+(5) fast beam search (nbest)
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method fast_beam_search_nbest \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(6) fast beam search (nbest oracle WER)
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method fast_beam_search_nbest_oracle \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64 \
+ --num-paths 200 \
+ --nbest-scale 0.5
+
+(7) fast beam search (with LG)
+./pruned_transducer_stateless7_streaming/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --max-duration 600 \
+ --decode-chunk-len 32 \
+ --decoding-method fast_beam_search_nbest_LG \
+ --beam 20.0 \
+ --max-contexts 8 \
+ --max-states 64
+"""
+
+
+import argparse
+import logging
+import math
+import os
+import time
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from beam_search import (
+ beam_search,
+ fast_beam_search_one_best,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from train import add_model_arguments, get_params, get_transducer_model
+
+from icefall import ContextGraph
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=False,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless3/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="The lang dir",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ - fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=1,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--context-score",
+ type=float,
+ default=2,
+ help="""
+ The bonus score of each token for the context biasing words/phrases.
+ Used only when --decoding_method is modified_beam_search.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-file",
+ type=str,
+ default="",
+ help="""
+ The path of the context biasing lists, one word/phrase each line
+ Used only when --decoding_method is modified_beam_search.
+ """,
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ token_table: k2.SymbolTable,
+ batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ token_table:
+ It maps token ID to a string.
+ batch:
+ It is the return value from iterating
+ `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
+ for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ # at entry, feature is (N, T, C)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
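+    # Tail padding: append 30 frames of LOG_EPS to every utterance and extend
+    # feature_lens to match, so that the end of each utterance can still be
+    # decoded by the chunk-wise streaming encoder.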
+ feature_lens += 30
+ feature = torch.nn.functional.pad(
+ feature,
+ pad=(0, 0, 0, 30),
+ value=LOG_EPS,
+ )
+ encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
+
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search_one_best(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ context_graph=context_graph,
+ )
+ else:
+ hyp_tokens = []
+ batch_size = encoder_out.size(0)
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyp_tokens.append(hyp)
+
+ hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens]
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": hyps}
+ elif params.decoding_method == "fast_beam_search":
+ return {
+ (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ ): hyps
+ }
+ else:
+ key = f"beam_size_{params.beam_size}"
+ if params.has_contexts:
+ key += f"-context-score-{params.context_score}"
+ else:
+ key += "-no-context-words"
+ return {key: hyps}
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ token_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ token_table:
+ It maps a token ID to a string.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ for batch_idx, batch in enumerate(dl):
+ texts = batch["supervisions"]["text"]
+ cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
+
+ hyps_dict = decode_one_batch(
+ params=params,
+ model=model,
+ token_table=token_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ batch=batch,
+ )
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ assert len(hyps) == len(texts)
+ for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
+ ref_words = ref_text.split()
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(texts)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+ return results
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "fast_beam_search",
+ "modified_beam_search",
+ )
+
+ if os.path.exists(params.context_file):
+ params.has_contexts = True
+ else:
+ params.has_contexts = False
+
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ if params.has_contexts:
+ params.suffix += f"-context-score-{params.context_score}"
+ else:
+            params.suffix += "-no-context-words"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
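+    # Note: --decode-chunk-len is in frames before subsampling; after the 2x
+    # subsampling in the feature embedding, the encoder's decode_chunk_size is
+    # decode_chunk_len // 2, which is what the assert below checks.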
+ assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
+ model.encoder.decode_chunk_size,
+ params.decode_chunk_len,
+ )
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
+
+ if params.decoding_method == "modified_beam_search":
+ if os.path.exists(params.context_file):
+ contexts_text = []
+ for line in open(params.context_file).readlines():
+ contexts_text.append(line.strip())
+ contexts = graph_compiler.texts_to_ids(contexts_text)
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ aishell = AishellAsrDataModule(args)
+
+ test_cuts = aishell.test_cuts()
+ dev_cuts = aishell.valid_cuts()
+
+ test_dl = aishell.test_dataloaders(test_cuts)
+ dev_dl = aishell.test_dataloaders(dev_cuts)
+
+ test_sets = ["test", "dev"]
+ test_dls = [test_dl, dev_dl]
+
+ for test_set, test_dl in zip(test_sets, test_dls):
+ start = time.time()
+ results_dict = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ token_table=lexicon.token_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ )
+        logging.info(f"Elapsed time for {test_set}: {time.time() - start}")
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
new file mode 120000
index 000000000..ca8fed319
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py
new file mode 120000
index 000000000..33944d0d2
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/decoder.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py
new file mode 120000
index 000000000..b9aa0ae08
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless2/encoder_interface.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
new file mode 120000
index 000000000..72e43c297
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn-zh.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
new file mode 120000
index 000000000..3b36924ef
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-for-ncnn.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
new file mode 120000
index 000000000..eca5e2956
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx-zh.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
new file mode 120000
index 000000000..57a0cd0a0
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export-onnx.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py
new file mode 120000
index 000000000..2acafdc61
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/export.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
new file mode 120000
index 000000000..5d9c6ba00
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_pretrained.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py
new file mode 120000
index 000000000..457131699
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_export.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
new file mode 120000
index 000000000..2b8fa3cbb
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/jit_trace_pretrained.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py
new file mode 120000
index 000000000..ecfb6dd8a
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/joiner.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py
new file mode 120000
index 000000000..e17d4f734
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/model.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/model.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py
new file mode 120000
index 000000000..8eea90e04
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/ncnn_custom_layer.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
new file mode 120000
index 000000000..28bf7bb82
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_check.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py
new file mode 120000
index 000000000..c8548d459
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_model_wrapper.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
new file mode 120000
index 000000000..ae4d9bb04
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/onnx_pretrained.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py
new file mode 120000
index 000000000..81ac4a89a
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/optim.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py
new file mode 120000
index 000000000..9510b8fde
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/pretrained.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py
new file mode 120000
index 000000000..2428b74b9
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py
new file mode 120000
index 000000000..b8b8ba432
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py
@@ -0,0 +1 @@
+../pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
new file mode 120000
index 000000000..92c3904af
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py
new file mode 120000
index 000000000..1199a61d6
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
new file mode 100755
index 000000000..6b4f183cf
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py
@@ -0,0 +1,627 @@
+#!/usr/bin/env python3
+# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Usage:
+./pruned_transducer_stateless7_streaming/streaming_decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --decode-chunk-len 32 \
+ --exp-dir ./pruned_transducer_stateless7_streaming/exp \
+ --decoding-method greedy_search \
+ --num-decode-streams 2000
+"""
+
+import argparse
+import logging
+import math
+import os
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import numpy as np
+import torch
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from decode_stream import DecodeStream
+from kaldifeat import Fbank, FbankOptions
+from lhotse import CutSet
+from streaming_beam_search import (
+ fast_beam_search_one_best,
+ greedy_search,
+ modified_beam_search,
+)
+from torch.nn.utils.rnn import pad_sequence
+from train import add_model_arguments, get_params, get_transducer_model
+from zipformer import stack_states, unstack_states
+
+from icefall import ContextGraph
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_error_stats,
+)
+
+LOG_EPS = math.log(1e-10)
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=28,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=15,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=str,
+        default="pruned_transducer_stateless7_streaming/exp",
+        help="The experiment dir",
+    )
+
+    parser.add_argument(
+        "--lang-dir",
+        type=str,
+        default="data/lang_char",
+        help="The lang dir",
+    )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Supported decoding methods are:
+ greedy_search
+ modified_beam_search
+ fast_beam_search
+ """,
+ )
+
+ parser.add_argument(
+        "--num-active-paths",
+        type=int,
+        default=4,
+        help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=32,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--num-decode-streams",
+ type=int,
+ default=2000,
+        help="The number of streams that can be decoded in parallel.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_chunk(
+ params: AttributeDict,
+ model: nn.Module,
+ decode_streams: List[DecodeStream],
+) -> List[int]:
+    """Decode one chunk of frames for each stream in decode_streams and
+    return the indexes of the finished streams in a List.
+
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ decode_streams:
+ A List of DecodeStream, each belonging to a utterance.
+ Returns:
+ Return a List containing which DecodeStreams are finished.
+ """
+ device = model.device
+
+ features = []
+ feature_lens = []
+ states = []
+ processed_lens = []
+
+ for stream in decode_streams:
+ feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
+ features.append(feat)
+ feature_lens.append(feat_len)
+ states.append(stream.states)
+ processed_lens.append(stream.done_frames)
+
+ feature_lens = torch.tensor(feature_lens, device=device)
+ features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
+
+ # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
+ # factor in encoders is 8.
+ # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
+ tail_length = 23
+ if features.size(1) < tail_length:
+ pad_length = tail_length - features.size(1)
+ feature_lens += pad_length
+ features = torch.nn.functional.pad(
+ features,
+ (0, 0, 0, pad_length),
+ mode="constant",
+ value=LOG_EPS,
+ )
+
+ states = stack_states(states)
+ processed_lens = torch.tensor(processed_lens, device=device)
+
+ encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
+ x=features,
+ x_lens=feature_lens,
+ states=states,
+ )
+
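+    # Project the encoder output to the joiner dimension once here; the streaming
+    # search routines below expect the already-projected encoder output.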
+ encoder_out = model.joiner.encoder_proj(encoder_out)
+
+ if params.decoding_method == "greedy_search":
+ greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
+ elif params.decoding_method == "fast_beam_search":
+ processed_lens = processed_lens + encoder_out_lens
+ fast_beam_search_one_best(
+ model=model,
+ encoder_out=encoder_out,
+ processed_lens=processed_lens,
+ streams=decode_streams,
+ beam=params.beam,
+ max_states=params.max_states,
+ max_contexts=params.max_contexts,
+ )
+ elif params.decoding_method == "modified_beam_search":
+ modified_beam_search(
+ model=model,
+ streams=decode_streams,
+ encoder_out=encoder_out,
+ num_active_paths=params.num_active_paths,
+ )
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+
+ states = unstack_states(new_states)
+
+ finished_streams = []
+ for i in range(len(decode_streams)):
+ decode_streams[i].states = states[i]
+ decode_streams[i].done_frames += encoder_out_lens[i]
+ if decode_streams[i].done:
+ finished_streams.append(i)
+
+ return finished_streams
+
+
+def decode_dataset(
+ cuts: CutSet,
+ params: AttributeDict,
+ model: nn.Module,
+ token_table: Optional[k2.SymbolTable] = None,
+ decoding_graph: Optional[k2.Fsa] = None,
+ context_graph: Optional[ContextGraph] = None,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ cuts:
+ Lhotse Cutset containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+      token_table:
+        It maps a token ID to a string.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+ """
+ device = model.device
+
+ opts = FbankOptions()
+ opts.device = device
+ opts.frame_opts.dither = 0
+ opts.frame_opts.snip_edges = False
+ opts.frame_opts.samp_freq = 16000
+ opts.mel_opts.num_bins = 80
+
+ log_interval = 50
+
+ decode_results = []
+ # Contain decode streams currently running.
+ decode_streams = []
+ for num, cut in enumerate(cuts):
+ # each utterance has a DecodeStream.
+ initial_states = model.encoder.get_init_state(device=device)
+ decode_stream = DecodeStream(
+ params=params,
+ cut_id=cut.id,
+ initial_states=initial_states,
+ decoding_graph=decoding_graph,
+ device=device,
+ )
+
+ audio: np.ndarray = cut.load_audio()
+ # audio.shape: (1, num_samples)
+ assert len(audio.shape) == 2
+ assert audio.shape[0] == 1, "Should be single channel"
+ assert audio.dtype == np.float32, audio.dtype
+
+ # The trained model is using normalized samples
+        assert audio.max() <= 1, "Should be normalized to [-1, 1]"
+
+ samples = torch.from_numpy(audio).squeeze(0)
+
+ fbank = Fbank(opts)
+ feature = fbank(samples.to(device))
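+        # set_features appends decode_chunk_len frames of tail padding so the
+        # final partial chunk of this utterance can still be decoded.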
+ decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
+ decode_stream.ground_truth = cut.supervisions[0].text
+
+ decode_streams.append(decode_stream)
+
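+        # Keep at most num_decode_streams utterances in flight: once the limit is
+        # reached, decode them chunk by chunk in parallel and remove finished
+        # streams before adding new ones.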
+ while len(decode_streams) >= params.num_decode_streams:
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ decode_streams[i].ground_truth.split(),
+ [
+ token_table[result]
+ for result in decode_streams[i].decoding_result()
+ ],
+ )
+ )
+ del decode_streams[i]
+
+ if num % log_interval == 0:
+ logging.info(f"Cuts processed until now is {num}.")
+
+ # decode final chunks of last sequences
+ while len(decode_streams):
+ finished_streams = decode_one_chunk(
+ params=params, model=model, decode_streams=decode_streams
+ )
+ for i in sorted(finished_streams, reverse=True):
+ decode_results.append(
+ (
+ decode_streams[i].id,
+ decode_streams[i].ground_truth.split(),
+ [
+ token_table[result]
+ for result in decode_streams[i].decoding_result()
+ ],
+ )
+ )
+ del decode_streams[i]
+
+ if params.decoding_method == "greedy_search":
+ key = "greedy_search"
+ elif params.decoding_method == "fast_beam_search":
+ key = (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ )
+ elif params.decoding_method == "modified_beam_search":
+ key = f"num_active_paths_{params.num_active_paths}"
+ else:
+ raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
+ return {key: decode_results}
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt"
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt"
+ with open(errs_filename, "w") as f:
+ wer = write_error_stats(
+ f, f"{test_set_name}-{key}", results, enable_log=True
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt"
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ params.res_dir = params.exp_dir / "streaming" / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ # for streaming
+ params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
+
+ # for fast_beam_search
+ if params.decoding_method == "fast_beam_search":
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+ )
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+                if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+ model.device = device
+
+ decoding_graph = None
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+
+ if params.decoding_method == "modified_beam_search":
+ if os.path.exists(params.context_file):
+ contexts_text = []
+ for line in open(params.context_file).readlines():
+ contexts_text.append(line.strip())
+ contexts = graph_compiler.texts_to_ids(contexts_text)
+ context_graph = ContextGraph(params.context_score)
+ context_graph.build(contexts)
+ else:
+ context_graph = None
+ else:
+ context_graph = None
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ aishell = AishellAsrDataModule(args)
+
+ test_cuts = aishell.test_cuts()
+ valid_cuts = aishell.valid_cuts()
+
+ test_sets = ["test", "valid"]
+ cuts = [test_cuts, valid_cuts]
+
+ for test_set, test_cut in zip(test_sets, cuts):
+ results_dict = decode_dataset(
+ cuts=test_cut,
+ params=params,
+ model=model,
+ token_table=lexicon.token_table,
+ decoding_graph=decoding_graph,
+ context_graph=context_graph,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=test_set,
+ results_dict=results_dict,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py
new file mode 120000
index 000000000..1259849e0
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/test_model.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
new file mode 100755
index 000000000..2e1044658
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train.py
@@ -0,0 +1,1251 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+#                                                       Mingshuang Luo,
+#                                                       Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7_streaming/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --max-duration 300
+
+# For mixed precision training:
+
+./pruned_transducer_stateless7_streaming/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --max-duration 550
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,4,3,2,4",
+ help="Number of zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dims",
+ type=str,
+ default="1024,1024,2048,2048,1024",
+ help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=str,
+ default="8,8,8,8,8",
+ help="Number of attention heads in the zipformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--encoder-dims",
+ type=str,
+ default="384,384,384,384,384",
+ help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+ )
+
+ parser.add_argument(
+ "--attention-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+ not the same as embedding dimension.""",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dims",
+ type=str,
+ default="256,256,256,256,256",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
+ " worse.",
+ )
+
+ parser.add_argument(
+ "--zipformer-downsampling-factors",
+ type=str,
+ default="1,2,4,8,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernels",
+ type=str,
+ default="31,31,31,31,31",
+ help="Sizes of kernels in convolution modules",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--short-chunk-size",
+ type=int,
+ default=50,
+ help="""Chunk length of dynamic training, the chunk size would be either
+ max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
+ """,
+ )
+
+ parser.add_argument(
+ "--num-left-chunks",
+ type=int,
+ default=4,
+        help="How many left-context chunks can be seen when calculating attention.",
+ )
+
+ parser.add_argument(
+ "--decode-chunk-len",
+ type=int,
+ default=32,
+ help="The chunk size for decoding (in frames before subsampling)",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7_streaming/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.05, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=2000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Zipformer and Transformer
+ def to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
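+    # e.g. to_int_tuple("2,4,3,2,4") -> (2, 4, 3, 2, 4): one value per encoder stack.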
+
+ encoder = Zipformer(
+ num_features=params.feature_dim,
+ output_downsampling_factor=2,
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
+ encoder_dims=to_int_tuple(params.encoder_dims),
+ attention_dim=to_int_tuple(params.attention_dims),
+ encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+ nhead=to_int_tuple(params.nhead),
+ feedforward_dim=to_int_tuple(params.feedforward_dims),
+ cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+ num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+ num_left_chunks=params.num_left_chunks,
+ short_chunk_size=params.short_chunk_size,
+ decode_chunk_size=params.decode_chunk_len // 2,
+ )
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+        # Ramp the scale on the simple loss down from 1.0 at the start of
+        # training to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
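+        # For example, with the default warm_step=2000 and simple_loss_scale=0.5:
+        #   batch 0       -> simple_loss_scale 1.0,  pruned_loss_scale 0.1
+        #   batch 1000    -> simple_loss_scale 0.75, pruned_loss_scale 0.55
+        #   batch >= 2000 -> simple_loss_scale 0.5,  pruned_loss_scale 1.0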
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+        The learning rate scheduler; we call step_batch() after every batch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
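+            # i.e. an exponentially-decayed running sum of recent batch stats,
+            # with decay factor (1 - 1/reset_interval).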
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+        oov="<unk>",
+ )
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
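+    # Token id 0 is reserved for the blank symbol, so the vocabulary size is
+    # the largest token id in the lexicon plus one.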
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
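+    # ScaledAdam expects one list of parameter names per parameter group;
+    # it uses them when logging its gradient-clipping diagnostics.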
+ optimizer = ScaledAdam(
+ model.parameters(),
+ lr=params.base_lr,
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
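+    # Eden decays the learning rate as a function of both the number of batches
+    # (time constant lr_batches) and the number of epochs (lr_epochs) seen so far.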
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+            2 ** 22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ aishell = AishellAsrDataModule(args)
+
+ train_cuts = aishell.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ # T = ((c.num_frames - 7) // 2 + 1) // 2
+ # tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ # if T < len(tokens):
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. "
+ # f"Number of frames (before subsampling): {c.num_frames}. "
+ # f"Number of frames (after subsampling): {T}. "
+ # f"Text: {c.supervisions[0].text}. "
+ # f"Tokens: {tokens}. "
+ # f"Number of tokens: {len(tokens)}"
+ # )
+        # return False
+
+ return True
+
+ # train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = aishell.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = aishell.valid_cuts()
+ valid_dl = aishell.valid_dataloaders(valid_cuts)
+
+ # if not params.print_diagnostics:
+ # scan_pessimistic_batches_for_oom(
+ # model=model,
+ # train_dl=train_dl,
+ # optimizer=optimizer,
+ # graph_compiler=graph_compiler,
+ # params=params,
+ # )
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
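+    # Start the grad scaler at 1.0; when use_fp16 is enabled, the loop in
+    # train_one_epoch doubles it periodically while it stays below 8.0.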
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = graph_compiler.texts_to_ids(supervisions["text"])
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
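+# Restrict each process to a single intra-op/inter-op CPU thread to avoid
+# CPU oversubscription across DDP processes and dataloader workers.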
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py
new file mode 100755
index 000000000..88eb34104
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/train2.py
@@ -0,0 +1,1253 @@
+#!/usr/bin/env python3
+# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+./pruned_transducer_stateless7_streaming/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --max-duration 300
+
+# For mix precision training:
+
+./pruned_transducer_stateless7_streaming/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --max-duration 550
+"""
+
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import AishellAsrDataModule
+from decoder import Decoder
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from model import Transducer
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer2 import Zipformer
+
+from icefall import diagnostics
+from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.lexicon import Lexicon
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,4,3,2,4",
+ help="Number of zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dims",
+ type=str,
+ default="1024,1024,2048,2048,1024",
+ help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=str,
+ default="8,8,8,8,8",
+ help="Number of attention heads in the zipformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--encoder-dims",
+ type=str,
+ default="384,384,384,384,384",
+ help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+ )
+
+ parser.add_argument(
+ "--attention-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+ not the same as embedding dimension.""",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dims",
+ type=str,
+ default="256,256,256,256,256",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
+ " worse.",
+ )
+
+ parser.add_argument(
+ "--zipformer-downsampling-factors",
+ type=str,
+ default="1,2,4,8,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernels",
+ type=str,
+ default="31,31,31,31,31",
+ help="Sizes of kernels in convolution modules",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--short-chunk-size",
+ type=int,
+ default=50,
+ help="""Chunk length of dynamic training, the chunk size would be either
+ max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
+ """,
+ )
+
+ parser.add_argument(
+ "--num-left-chunks",
+ type=int,
+ default=4,
+ help="How many left context can be seen in chunks when calculating attention.",
+ )
+
+ parser.add_argument(
+ "--decode-chunk-len",
+ type=int,
+ default=32,
+ help="The chunk size for decoding (in frames before subsampling)",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7_streaming/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=str,
+ default="data/lang_char",
+ help="""The lang dir
+ It contains language related input files such as
+ "lexicon.txt"
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.05, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=3.5,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=2000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+ - encoder_dim: Hidden dim for multi-head attention model.
+
+ - num_decoder_layers: Number of decoder layer of transformer decoder.
+
+ - warm_step: The warmup period that dictates the decay of the
+ scale on "simple" (un-pruned) loss.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 3000, # For the 100h subset, use 800
+ # parameters for zipformer
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed.
+ "warm_step": 2000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Zipformer and Transformer
+ def to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+ encoder = Zipformer(
+ num_features=params.feature_dim,
+ output_downsampling_factor=2,
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
+ encoder_dims=to_int_tuple(params.encoder_dims),
+ attention_dim=to_int_tuple(params.attention_dims),
+ encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+ nhead=to_int_tuple(params.nhead),
+ feedforward_dim=to_int_tuple(params.feedforward_dims),
+ cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+ num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+ num_left_chunks=params.num_left_chunks,
+ short_chunk_size=params.short_chunk_size,
+ decode_chunk_size=params.decode_chunk_len // 2,
+ is_pnnx=True,
+ )
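+    # Note: unlike train.py, this file builds the encoder from zipformer2 with
+    # is_pnnx=True, so that the trained streaming model can later be exported
+    # via PNNX/ncnn.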
+ return encoder
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_transducer_model(params: AttributeDict) -> nn.Module:
+ encoder = get_encoder_model(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = Transducer(
+ encoder=encoder,
+ decoder=decoder,
+ joiner=joiner,
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute transducer loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+ The model for training. It is an instance of Zipformer in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+ feature = feature.to(device)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.warm_step
+
+ texts = batch["supervisions"]["text"]
+ y = graph_compiler.texts_to_ids(texts)
+ y = k2.RaggedTensor(y).to(device)
+
+ with torch.set_grad_enabled(is_training):
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ )
+
+ s = params.simple_loss_scale
+        # Ramp the scale on the simple loss down from 1.0 at the start of
+        # training to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+
+ loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+        The learning rate scheduler; we call step_batch() after every batch.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ lexicon = Lexicon(params.lang_dir)
+ graph_compiler = CharCtcTrainingGraphCompiler(
+ lexicon=lexicon,
+ device=device,
+        oov="<unk>",
+ )
+
+ params.blank_id = 0
+ params.vocab_size = max(lexicon.tokens) + 1
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_transducer_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+ optimizer = ScaledAdam(
+ model.parameters(),
+ lr=params.base_lr,
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 512
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ aishell = AishellAsrDataModule(args)
+
+ train_cuts = aishell.train_cuts()
+
+ def remove_short_and_long_utt(c: Cut):
+ # Keep only utterances with duration between 1 second and 20 seconds
+ #
+ # Caution: There is a reason to select 20.0 here. Please see
+ # ../local/display_manifest_statistics.py
+ #
+ # You should use ../local/display_manifest_statistics.py to get
+ # an utterance duration distribution for your dataset to select
+ # the threshold
+ if c.duration < 1.0 or c.duration > 20.0:
+ logging.warning(
+ f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ )
+ return False
+
+ # In pruned RNN-T, we require that T >= S
+ # where T is the number of feature frames after subsampling
+ # and S is the number of tokens in the utterance
+
+ # In ./zipformer.py, the conv module uses the following expression
+ # for subsampling
+ # T = ((c.num_frames - 7) // 2 + 1) // 2
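+        # For example, with the default 10 ms frame shift, a 5-second cut has
+        # c.num_frames = 500, so T = ((500 - 7) // 2 + 1) // 2 = 123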
+ # tokens = sp.encode(c.supervisions[0].text, out_type=str)
+
+ # if T < len(tokens):
+ # logging.warning(
+ # f"Exclude cut with ID {c.id} from training. "
+ # f"Number of frames (before subsampling): {c.num_frames}. "
+ # f"Number of frames (after subsampling): {T}. "
+ # f"Text: {c.supervisions[0].text}. "
+ # f"Tokens: {tokens}. "
+ # f"Number of tokens: {len(tokens)}"
+ # )
+ # return False
+
+ return True
+
+ train_cuts = train_cuts.filter(remove_short_and_long_utt)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # Only load the sampler's state dict when resuming from a checkpoint
+        # that was saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = aishell.train_dataloaders(
+ train_cuts, sampler_state_dict=sampler_state_dict
+ )
+
+ valid_cuts = aishell.valid_cuts()
+ valid_dl = aishell.valid_dataloaders(valid_cuts)
+
+ if not params.print_diagnostics:
+ scan_pessimistic_batches_for_oom(
+ model=model,
+ train_dl=train_dl,
+ optimizer=optimizer,
+ graph_compiler=graph_compiler,
+ params=params,
+ )
+
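+    # The GradScaler performs dynamic loss scaling for mixed-precision
+    # training; with use_fp16 disabled it is effectively a no-op.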
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
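+        # `epoch` is 1-based; the scheduler and the sampler expect 0-based
+        # epoch indices, and re-fixing the random seed gives each epoch a
+        # different but reproducible shuffling.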
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ graph_compiler=graph_compiler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+      graph_compiler:
+        The compiler used to convert transcripts to token ids.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = graph_compiler.texts_to_ids(supervisions["text"])
+ num_tokens = sum(len(i) for i in y)
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def scan_pessimistic_batches_for_oom(
+ model: Union[nn.Module, DDP],
+ train_dl: torch.utils.data.DataLoader,
+ optimizer: torch.optim.Optimizer,
+ graph_compiler: CharCtcTrainingGraphCompiler,
+ params: AttributeDict,
+):
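+    """Run a forward/backward pass on the worst-case batches selected by the
+    sampler (e.g. the ones with the largest total duration) so that an
+    out-of-memory error, if any, shows up before the actual training starts.
+    """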
+ from lhotse.dataset import find_pessimistic_batches
+
+ logging.info(
+ "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+ )
+ batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+ for criterion, cuts in batches.items():
+ batch = train_dl.dataset[cuts]
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ graph_compiler=graph_compiler,
+ batch=batch,
+ is_training=True,
+ )
+ loss.backward()
+ optimizer.zero_grad()
+ except Exception as e:
+ if "CUDA out of memory" in str(e):
+ logging.error(
+ "Your GPU ran out of memory with the current "
+ "max_duration setting. We recommend decreasing "
+ "max_duration and trying again.\n"
+ f"Failing criterion: {criterion} "
+ f"(={crit_values[criterion]}) ..."
+ )
+ display_and_save_batch(batch, params=params, graph_compiler=graph_compiler)
+ raise
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+
+
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py
new file mode 120000
index 000000000..ec183baa7
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
\ No newline at end of file
diff --git a/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
new file mode 120000
index 000000000..12dbda888
--- /dev/null
+++ b/egs/aishell/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer2.py
\ No newline at end of file