From a25af9d61dd764f1a9b6a807d6d7b53dd9e66af1 Mon Sep 17 00:00:00 2001 From: pkufool Date: Wed, 2 Aug 2023 14:43:23 +0800 Subject: [PATCH] Add subformer into zipformer --- egs/librispeech/ASR/subformer/__init__.py | 0 .../ASR/subformer/asr_datamodule.py | 1 + egs/librispeech/ASR/subformer/beam_search.py | 1 + egs/librispeech/ASR/subformer/decode.py | 834 ++++++++ egs/librispeech/ASR/subformer/decoder.py | 1 + .../ASR/subformer/encoder_interface.py | 1 + egs/librispeech/ASR/subformer/joiner.py | 1 + egs/librispeech/ASR/subformer/mixformer.py | 484 +++++ egs/librispeech/ASR/subformer/model.py | 217 ++ egs/librispeech/ASR/subformer/optim.py | 1 + egs/librispeech/ASR/subformer/scaling.py | 1 + egs/librispeech/ASR/subformer/subformer.py | 1875 +++++++++++++++++ egs/librispeech/ASR/subformer/subsampling.py | 1 + egs/librispeech/ASR/subformer/train.py | 1395 ++++++++++++ egs/librispeech/ASR/subformer/zipformer.py | 1 + egs/librispeech/ASR/zipformer/scaling.py | 52 + 16 files changed, 4866 insertions(+) create mode 100644 egs/librispeech/ASR/subformer/__init__.py create mode 120000 egs/librispeech/ASR/subformer/asr_datamodule.py create mode 120000 egs/librispeech/ASR/subformer/beam_search.py create mode 100755 egs/librispeech/ASR/subformer/decode.py create mode 120000 egs/librispeech/ASR/subformer/decoder.py create mode 120000 egs/librispeech/ASR/subformer/encoder_interface.py create mode 120000 egs/librispeech/ASR/subformer/joiner.py create mode 100644 egs/librispeech/ASR/subformer/mixformer.py create mode 100644 egs/librispeech/ASR/subformer/model.py create mode 120000 egs/librispeech/ASR/subformer/optim.py create mode 120000 egs/librispeech/ASR/subformer/scaling.py create mode 100644 egs/librispeech/ASR/subformer/subformer.py create mode 120000 egs/librispeech/ASR/subformer/subsampling.py create mode 100755 egs/librispeech/ASR/subformer/train.py create mode 120000 egs/librispeech/ASR/subformer/zipformer.py diff --git a/egs/librispeech/ASR/subformer/__init__.py b/egs/librispeech/ASR/subformer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/egs/librispeech/ASR/subformer/asr_datamodule.py b/egs/librispeech/ASR/subformer/asr_datamodule.py new file mode 120000 index 000000000..a074d6085 --- /dev/null +++ b/egs/librispeech/ASR/subformer/asr_datamodule.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/asr_datamodule.py \ No newline at end of file diff --git a/egs/librispeech/ASR/subformer/beam_search.py b/egs/librispeech/ASR/subformer/beam_search.py new file mode 120000 index 000000000..8554e44cc --- /dev/null +++ b/egs/librispeech/ASR/subformer/beam_search.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/beam_search.py \ No newline at end of file diff --git a/egs/librispeech/ASR/subformer/decode.py b/egs/librispeech/ASR/subformer/decode.py new file mode 100755 index 000000000..7a3f4e4be --- /dev/null +++ b/egs/librispeech/ASR/subformer/decode.py @@ -0,0 +1,834 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./zipformer/decode.py \ + --epoch 28 \ + --avg 15 \ + --exp-dir ./zipformer/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + make_pad_mask, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. 
+ You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. 
+ Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + if params.causal: + # this seems to cause insertions at the end of the utterance if used with zipformer. 
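+        # Added note: the code below appends `chunk_size` frames of LOG_EPS
+        # (log of a tiny value, i.e. effectively silence) to the end of the
+        # features and extends feature_lens to match, presumably so that a
+        # causal model still sees a complete final chunk.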
+ pad_len = int(params.chunk_size) + feature_lens += pad_len + feature = torch.nn.functional.pad( + feature, + pad=(0, 0, 0, pad_len), + value=LOG_EPS, + ) + + x, x_lens = model.encoder_embed(feature, feature_lens) + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, encoder_out_lens = model.encoder( + x, x_lens, src_key_padding_mask + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += 
f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. 
+ errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if params.causal: + assert ( + "," not in params.chunk_size + ), "chunk_size should be one value in decoding." + assert ( + "," not in params.left_context_frames + ), "left_context_frames should be one value in decoding." + params.suffix += f"-chunk-{params.chunk_size}" + params.suffix += f"-left-context-{params.left_context_frames}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise 
ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. 
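+    # Setting return_cuts=True makes the dataloader attach the original Cut
+    # objects to each batch (available as batch["supervisions"]["cut"]), which
+    # is where decode_dataset() reads the cut ids from.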
+ args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/subformer/decoder.py b/egs/librispeech/ASR/subformer/decoder.py new file mode 120000 index 000000000..cab465d2b --- /dev/null +++ b/egs/librispeech/ASR/subformer/decoder.py @@ -0,0 +1 @@ +../zipformer/decoder.py \ No newline at end of file diff --git a/egs/librispeech/ASR/subformer/encoder_interface.py b/egs/librispeech/ASR/subformer/encoder_interface.py new file mode 120000 index 000000000..b9aa0ae08 --- /dev/null +++ b/egs/librispeech/ASR/subformer/encoder_interface.py @@ -0,0 +1 @@ +../pruned_transducer_stateless2/encoder_interface.py \ No newline at end of file diff --git a/egs/librispeech/ASR/subformer/joiner.py b/egs/librispeech/ASR/subformer/joiner.py new file mode 120000 index 000000000..444cb5f15 --- /dev/null +++ b/egs/librispeech/ASR/subformer/joiner.py @@ -0,0 +1 @@ +../zipformer/joiner.py \ No newline at end of file diff --git a/egs/librispeech/ASR/subformer/mixformer.py b/egs/librispeech/ASR/subformer/mixformer.py new file mode 100644 index 000000000..a3549ccdb --- /dev/null +++ b/egs/librispeech/ASR/subformer/mixformer.py @@ -0,0 +1,484 @@ +#!/usr/bin/env python3 +# Copyright (c) 2023 Xiaomi Corp. (author: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import warnings +from typing import List, Optional, Tuple, Union +import logging +import torch +import random +from encoder_interface import EncoderInterface +from scaling import ( + Balancer, + BiasNorm, + Dropout2, + ChunkCausalDepthwiseConv1d, + ActivationDropoutAndLinear, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + Whiten, + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. 
+ penalize_abs_values_gt, + softmax, + ScheduledFloat, + FloatLike, + limit_param_value, + convert_num_channels, +) +from subformer import ( + BypassModule, + CompactRelPositionalEncoding, + LearnedDownsamplingModule, + SubformerEncoder, + SubformerEncoderLayer, +) +from zipformer import ( + DownsampledZipformer2Encoder, + SimpleDownsample, + SimpleUpsample, + Zipformer2Encoder, + Zipformer2EncoderLayer, +) +from torch import Tensor, nn + + +class Mixformer(EncoderInterface): + def __init__( + self, + structure: str = "ZZS(S(S)S)SZ", + output_downsampling_factor: int = 2, + downsampling_factor: Tuple[int] = (1, 1, 2, 2, 1), + encoder_dim: Union[int, Tuple[int]] = ( + 192, + 192, + 256, + 384, + 512, + 384, + 256, + 192, + ), + num_encoder_layers: Union[int, Tuple[int]] = ( + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + ), + encoder_unmasked_dim: Union[int, Tuple[int]] = (192, 192, 192), + query_head_dim: Union[int, Tuple[int]] = (32,), + value_head_dim: Union[int, Tuple[int]] = (12,), + pos_head_dim: Union[int, Tuple[int]] = (4,), + pos_dim: int = (48,), + num_heads: Union[int, Tuple[int]] = (4,), + feedforward_dim: Union[int, Tuple[int]] = ( + 512, + 768, + 1024, + 1536, + 2048, + 1536, + 1024, + 768, + ), + cnn_module_kernel: Union[int, Tuple[int]] = (15, 31, 31), + encoder_chunk_sizes: Tuple[Tuple[int, ...]] = ((128, 1024),), + memory_dim: int = -1, + dropout: Optional[FloatLike] = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + ) -> None: + super(Mixformer, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + + num_zip_encoders = len([s for s in structure if s == 'Z']) + num_sub_encoders = len([s for s in structure if s == 'S']) + num_encoders = num_zip_encoders + num_sub_encoders + num_downsamplers = len([s for s in structure if s == '(']) + + def _to_tuple(x, length): + """Converts a single int or a 1-tuple of an int to a tuple with the same length + as downsampling_factor""" + assert isinstance(x, tuple) + if len(x) == 1: + x = x * length + else: + assert len(x) == length and isinstance( + x[0], int + ) + return x + + self.output_downsampling_factor = output_downsampling_factor # int + self.downsampling_factor = ( + downsampling_factor + ) = _to_tuple(downsampling_factor, num_zip_encoders + num_downsamplers) # tuple + self.encoder_dim = encoder_dim = _to_tuple(encoder_dim, num_encoders) # tuple + self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple( + encoder_unmasked_dim, num_zip_encoders + ) # tuple + num_encoder_layers = _to_tuple(num_encoder_layers, num_encoders) + self.query_head_dim = query_head_dim = _to_tuple(query_head_dim, num_encoders) + self.value_head_dim = value_head_dim = _to_tuple(value_head_dim, num_encoders) + pos_head_dim = _to_tuple(pos_head_dim, num_encoders) + pos_dim = _to_tuple(pos_dim, num_encoders) + self.num_heads = num_heads = _to_tuple(num_heads, num_encoders) + feedforward_dim = _to_tuple(feedforward_dim, num_encoders) + self.cnn_module_kernel = cnn_module_kernel = _to_tuple( + cnn_module_kernel, num_zip_encoders + ) + encoder_chunk_sizes = _to_tuple(encoder_chunk_sizes, num_sub_encoders) + + self.causal = causal + + # for u, d in zip(encoder_unmasked_dim, encoder_dim): + # assert u <= d + + # each one will be Zipformer2Encoder, DownsampledZipformer2Encoder, + # SubformerEncoder or DownsampledSubformerEncoder + zip_encoders = [] + sub_encoders = [] + downsamplers = [] + bypasses = [] + + layer_indexes = [] + + cur_max_dim = 0 + + downsampling_factors_list = [] 
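+        # Added comment (not in the original recipe) illustrating how the loop
+        # below parses `structure`, e.g. the default "ZZS(S(S)S)SZ":
+        #   'Z' -> a Zipformer2Encoder (wrapped in DownsampledZipformer2Encoder
+        #          if its downsampling_factor != 1),
+        #   'S' -> a SubformerEncoder stack,
+        #   '(' -> a LearnedDownsamplingModule, opening a downsampled region,
+        #   ')' -> the matching upsample plus a BypassModule, closing it.
+        # cur_downsampling_factor() below returns the product of the factors of
+        # all currently open '(' regions and is used to scale the final
+        # layerdrop rate.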
+ def cur_downsampling_factor(): + c = 1 + for d in downsampling_factors_list: c *= d + return c + + zip_encoder_dim = [] + zip_downsampling_factor = [] + for s in structure: + if s == "Z": + i = len(zip_encoders) + len(sub_encoders) + j = len(zip_encoders) + k = len(downsamplers) + len(zip_encoders) + assert encoder_unmasked_dim[j] <= encoder_dim[i] + zip_encoder_dim.append(encoder_dim[i]) + encoder_layer = Zipformer2EncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim[i], + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + pos_head_dim=pos_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + dropout=dropout, + cnn_module_kernel=cnn_module_kernel[j], + causal=causal, + ) + + # For the segment of the warmup period, we let the Conv2dSubsampling + # layer learn something. Then we start to warm up the other encoders. + encoder = Zipformer2Encoder( + encoder_layer, + num_encoder_layers[i], + pos_dim=pos_dim[i], + dropout=dropout, + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (downsampling_factor[k] ** 0.5), + ) + + if downsampling_factor[k] != 1: + encoder = DownsampledZipformer2Encoder( + encoder, + dim=encoder_dim[i], + downsample=downsampling_factor[k], + dropout=dropout, + ) + zip_downsampling_factor.append(downsampling_factor[k]) + layer_indexes.append(len(zip_encoders)) + zip_encoders.append(encoder) + elif s == 'S': + i = len(zip_encoders) + len(sub_encoders) + j = len(sub_encoders) + if len(sub_encoders) == 0: + cur_max_dim = encoder_dim[i] + encoder_layer = SubformerEncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_head_dim[i], + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + memory_dim=memory_dim, + dropout=dropout, + causal=causal, + ) + cur_max_dim = max(cur_max_dim, encoder_dim[i]) + encoder = SubformerEncoder( + encoder_layer, + num_encoder_layers[i], + embed_dim=cur_max_dim, + dropout=dropout, + chunk_sizes=encoder_chunk_sizes[j], + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (cur_downsampling_factor() ** 0.5), + ) + layer_indexes.append(len(sub_encoders)) + sub_encoders.append(encoder) + elif s =='(': + i = len(zip_encoders) + len(downsamplers) + downsampler = LearnedDownsamplingModule(cur_max_dim, + downsampling_factor[i]) + downsampling_factors_list.append(downsampling_factor[i]) + layer_indexes.append(len(downsamplers)) + downsamplers.append(downsampler) + else: + assert s == ')' + bypass = BypassModule(cur_max_dim, straight_through_rate=0.0) + layer_indexes.append(len(bypasses)) + bypasses.append(bypass) + downsampling_factors_list.pop() + + logging.info(f"cur_downsampling_factor={cur_downsampling_factor()}") + + self.zip_encoder_dim = zip_encoder_dim + self.zip_downsampling_factor = zip_downsampling_factor + self.layer_indexes = layer_indexes + self.structure = structure + self.zip_encoders = nn.ModuleList(zip_encoders) + self.sub_encoders = nn.ModuleList(sub_encoders) + self.downsamplers = nn.ModuleList(downsamplers) + self.bypasses = nn.ModuleList(bypasses) + + self.encoder_pos = CompactRelPositionalEncoding(64, pos_head_dim[0], + dropout_rate=0.15, + length_factor=1.0) + + self.downsample_output = SimpleDownsample( + max(encoder_dim), + downsample=output_downsampling_factor, + dropout=dropout, + ) + + def 
_get_full_dim_output(self, outputs: List[Tensor]): + num_encoders = len(self.zip_encoders) + 1 + assert len(outputs) == num_encoders + output_dim = max(self.encoder_dim) + output_pieces = [outputs[-1]] + cur_dim = self.encoder_dim[-1] + for i in range(num_encoders - 2, -1, -1): + d = list(outputs[i].shape)[-1] + if d > cur_dim: + this_output = outputs[i] + output_pieces.append(this_output[..., cur_dim:d]) + cur_dim = d + assert cur_dim == output_dim, (cur_dim, output_dim) + return torch.cat(output_pieces, dim=-1) + + def get_feature_masks(self, x: Tensor) -> Union[List[float], List[Tensor]]: + """ + In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of + randomized feature masks, one per encoder. + On e.g. 15% of frames, these masks will zero out all enocder dims larger than + some supplied number, e.g. >256, so in effect on those frames we are using + a smaller encoer dim. + + We generate the random masks at this level because we want the 2 masks to 'agree' + all the way up the encoder stack. This will mean that the 1st mask will have + mask values repeated self.zipformer_subsampling_factor times. + + Args: + x: the embeddings (needed for the shape and dtype and device), of shape + (1, batch_size, encoder_dims0) + """ + num_encoders = len(self.zip_encoders) + if not self.training: + return [1.0] * num_encoders + + (num_frames0, batch_size, _encoder_dims0) = x.shape + + assert self.encoder_dim[0] == _encoder_dims0 + + feature_mask_dropout_prob = 0.125 + + # mask1 shape: (1, batch_size, 1) + mask1 = ( + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype) + + # mask2 has additional sequences masked, about twice the number. + mask2 = torch.logical_and( + mask1, + ( + torch.rand(1, batch_size, 1, device=x.device) + > feature_mask_dropout_prob + ).to(x.dtype), + ) + + # dim: (1, batch_size, 2) + mask = torch.cat((mask1, mask2), dim=-1) + + feature_masks = [] + for i in range(num_encoders): + channels = self.zip_encoder_dim[i] + feature_mask = torch.ones( + 1, batch_size, channels, dtype=x.dtype, device=x.device + ) + u1 = self.encoder_unmasked_dim[i] + u2 = u1 + (channels - u1) // 2 + + feature_mask[:, :, u1:u2] *= mask[..., 0:1] + feature_mask[:, :, u2:] *= mask[..., 1:2] + + feature_masks.append(feature_mask) + + return feature_masks + + def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]: + """ + Return attention offset of shape (1 or batch_size, seq_len, seq_len), interpreted as (1 or batch_size, tgt_seq_len, + src_seq_len); this reflects masking, if causal == True, otherwise will be all zeros. + + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + src_key_padding_mask: optional key-padding mask of shape (batch_size, seq_len) with True in masked positions. + """ + seq_len, batch_size, _num_channels = x.shape + + ans = torch.zeros(batch_size, seq_len, seq_len, device=x.device) + + if self.causal: + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + src_t = t + tgt_t = t.unsqueeze(-1) + attn_mask = (src_t > tgt_t) + ans.masked_fill_(attn_mask, float('-inf')) + + if src_key_padding_mask is not None: + ans.masked_fill_(src_key_padding_mask.unsqueeze(1), float('-inf')) + # now ans: (batch_size, seq_len, seq_len). 
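+        # Zero entries mean "may attend"; the -inf entries (future frames when
+        # causal, and padded key positions) are meant to be added to the raw
+        # attention scores, so those positions get zero weight after softmax.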
+ return ans + + + def forward( + self, + x: Tensor, + x_lens: Tensor, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + x: + The input tensor. Its shape is (seq_len, batch_size, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (output_seq_len, batch_size, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + outputs = [] + + attn_offsets = [ self._get_attn_offset(x, src_key_padding_mask) ] + pos_embs = [ self.encoder_pos(x) ] + downsample_info = [] + + if torch.jit.is_scripting(): + feature_masks = [1.0] * len(self.zip_encoders) + else: + feature_masks = self.get_feature_masks(x) + + for s, i in zip(self.structure, self.layer_indexes): + if s == 'Z': + encoder = self.zip_encoders[i] + ds = self.zip_downsampling_factor[i] + x = convert_num_channels(x, self.zip_encoder_dim[i]) + x = encoder( + x, + feature_mask=feature_masks[i], + src_key_padding_mask=( + None + if src_key_padding_mask is None + else src_key_padding_mask[..., ::ds] + ), + ) + outputs.append(x) + elif s == 'S': + encoder = self.sub_encoders[i] # one encoder stack + x = encoder(x, + pos_embs[-1], + attn_offset=attn_offsets[-1]) + + # only the last output of subformer will be used to combine the + # final output. + if i == len(self.sub_encoders) - 1: + outputs.append(x) + # x will have the maximum dimension up till now, even if + # `encoder` uses lower dim in its layers. + elif s == '(': + downsampler = self.downsamplers[i] + + indexes, weights, x_new = downsampler(x) + downsample_info.append((indexes, weights, x)) + x = x_new + + pos_embs.append(downsampler.downsample_pos_emb(pos_embs[-1], indexes)) + + attn_offsets.append(downsampler.downsample_attn_offset(attn_offsets[-1], + indexes, + weights)) + else: + assert s == ')' # upsample and bypass + indexes, weights, x_orig = downsample_info.pop() + _attn_offset = attn_offsets.pop() + _pos_emb = pos_embs.pop() + x_orig = convert_num_channels(x_orig, x.shape[-1]) + + x = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights) + + bypass = self.bypasses[i] + x = bypass(x_orig, x) + + # Only "balanced" structure is supported now + assert len(downsample_info) == 0, len(downsample_info) + + # if the last output has the largest dimension, x will be unchanged, + # it will be the same as outputs[-1]. Otherwise it will be concatenated + # from different pieces of 'outputs', taking each dimension from the + # most recent output that has it present. + x = self._get_full_dim_output(outputs) + x = self.downsample_output(x) + # class Downsample has this rounding behavior.. + assert self.output_downsampling_factor == 2 + if torch.jit.is_scripting(): + lengths = (x_lens + 1) // 2 + else: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + lengths = (x_lens + 1) // 2 + + return x, lengths diff --git a/egs/librispeech/ASR/subformer/model.py b/egs/librispeech/ASR/subformer/model.py new file mode 100644 index 000000000..7fcab04ae --- /dev/null +++ b/egs/librispeech/ASR/subformer/model.py @@ -0,0 +1,217 @@ +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface + +from icefall.utils import add_sos, make_pad_mask +from scaling import ScaledLinear + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder_embed: nn.Module, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder_embed: + It is a Convolutional 2D subsampling module. It converts + an input of shape (N, T, idim) to an output of of shape + (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2. + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output contains + unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder_embed = encoder_embed + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, + vocab_size, + initial_scale=0.25, + ) + self.simple_lm_proj = ScaledLinear( + decoder_dim, + vocab_size, + initial_scale=0.25, + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + ) -> torch.Tensor: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + Returns: + Return the transducer loss. 
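+          Concretely, a tuple (simple_loss, pruned_loss): the smoothed "simple"
+          RNN-T loss computed from the simple am/lm projections, and the pruned
+          RNN-T loss computed over the pruned (am, lm) ranges; both use
+          reduction="sum" over the batch.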
+ + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M") + x, x_lens = self.encoder_embed(x, x_lens) + # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M") + + src_key_padding_mask = make_pad_mask(x_lens) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (encoder_out.size(0), 4), + dtype=torch.int64, + device=encoder_out.device, + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + # if self.training and random.random() < 0.25: + # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) + # if self.training and random.random() < 0.25: + # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction="sum", + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). 
+ logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + reduction="sum", + ) + + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/subformer/optim.py b/egs/librispeech/ASR/subformer/optim.py new file mode 120000 index 000000000..207eecfcd --- /dev/null +++ b/egs/librispeech/ASR/subformer/optim.py @@ -0,0 +1 @@ +../zipformer/optim.py \ No newline at end of file diff --git a/egs/librispeech/ASR/subformer/scaling.py b/egs/librispeech/ASR/subformer/scaling.py new file mode 120000 index 000000000..58e4b0a0f --- /dev/null +++ b/egs/librispeech/ASR/subformer/scaling.py @@ -0,0 +1 @@ +../zipformer/scaling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/subformer/subformer.py b/egs/librispeech/ASR/subformer/subformer.py new file mode 100644 index 000000000..d117ff617 --- /dev/null +++ b/egs/librispeech/ASR/subformer/subformer.py @@ -0,0 +1,1875 @@ +#!/usr/bin/env python3 +# Copyright (c) 2023 Xiaomi Corp. (author: Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import warnings +from typing import List, Optional, Tuple, Union +import logging +import torch +import random +from encoder_interface import EncoderInterface +from scaling import ( + Balancer, + BiasNorm, + Dropout2, + ActivationDropoutAndLinear, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. + Whiten, + Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons. + penalize_abs_values_gt, + softmax, + ScheduledFloat, + FloatLike, + limit_param_value, + clip_grad, + convert_num_channels, + AbsValuePenalizer, +) +from torch import Tensor, nn + + +class Subformer(EncoderInterface): + """ + Args: + + Note: all "int or Tuple[int]" arguments below will be treated as lists of the same length + as downsampling_factor if they are single ints or one-element tuples. The length of + downsampling_factor defines the number of stacks. + + + structure (str): determines the structure of the module, S is encoder stack, + open-parenthesis is downsampling operation, close-parenthesis is a corresponding + upsampling operation (but not all parentheses have to be closed if you want + the whole stack to downsample.) + encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per + encoder stack (i.e. one per "S" in structure). + encoder_chunk_sizes (Tuple[Tuple[int]]): A tuple containing either one tuple or + one tuple per encoder stack. Each element tuple is a list of the chunk sizes + that we use during training, e.g. (128, 1024); we go through these round-robin + in successive layers. + downsampling_factor (Tuple[int]): downsampling factor for each downsampling + operation (each open-parenthesis). 
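+           If a single value is given, it is repeated for every '(' in
+           `structure`.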
+ num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack + query_head_dim (int or Tuple[int]): dimension of query and key per attention + head: per stack, if a tuple.. + value_head_dim (int or Tuple[int]): dimension of value in each attention head + pos_head_dim (int or Tuple[int]): dimension of positional-encoding projection per + attention head + num_heads: (int or Tuple[int]): number of heads in the self-attention mechanism. + Must be at least 4. + feedforward_dim (int or Tuple[int]): hidden dimension in feedforward modules + + pos_dim (int): the dimension of each positional-encoding vector prior to projection, + e.g. 128. + + dropout (float): dropout rate + warmup_batches (float): number of batches to warm up over; this controls + dropout of encoder layers. + causal (bool): if True, use causal attention-mask. + memory_dim: if supplied and >0, will be the dimension of the memory embeddings + passed into the zipformer (e.g. this might be the output of another + Subformer used to create embedding vectors.) + """ + def __init__( + self, + structure: str = "S(S)S", + encoder_dim: Tuple[int, ...] = (384, 512, 384), + downsampling_factor: Tuple[int, ...] = (2,), + encoder_chunk_sizes: Tuple[Tuple[int, ...]] = ((128,1024),), + num_encoder_layers: Union[int, Tuple[int, ...]] = (4,), + query_head_dim: Tuple[int, ...] = (24,), + value_head_dim: Tuple[int, ...] = (12,), + num_heads: Tuple[int, ...] = (8,), + feedforward_dim: Tuple[int, ...] = (1536,), + memory_dim: int = -1, + pos_dim: int = 4, + dropout: Optional[FloatLike] = None, # see code below for default + warmup_batches: float = 4000.0, + causal: bool = False, + ) -> None: + super(Subformer, self).__init__() + + if dropout is None: + dropout = ScheduledFloat((0.0, 0.3), + (20000.0, 0.1)) + + num_encoders = len([s for s in structure if s == 'S']) + num_downsamplers = len([s for s in structure if s == '(']) + # when we upsample, we use the same downsampling object that we + # downsampled with, but we also need a BypassModule at that point. 
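+        # Added example: the default structure "S(S)S" builds an encoder stack
+        # at the full frame rate, then '(' inserts a LearnedDownsamplingModule
+        # (factor 2 by default), the middle stack runs on the shortened
+        # sequence, and ')' upsamples back and combines with the
+        # pre-downsampling stream through a BypassModule.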
+ num_bypass = len([s for s in structure if s == ')']) + + def _to_tuple(x): + """ Converts a single int or a 1-tuple of an int to a tuple with the same length + as num_encoders""" + assert isinstance(x, tuple) + if len(x) == 1: + x = x * num_encoders + else: + assert len(x) == num_encoders + return x + + self.encoder_dim = encoder_dim + encoder_chunk_sizes = _to_tuple(encoder_chunk_sizes) + num_encoder_layers = _to_tuple(num_encoder_layers) + query_head_dim = _to_tuple(query_head_dim) + value_head_dim = _to_tuple(value_head_dim) + num_heads = _to_tuple(num_heads) + feedforward_dim = _to_tuple(feedforward_dim) + self.causal = causal + + + if len(downsampling_factor) == 1: + downsampling_factor = downsampling_factor * num_downsamplers + assert len(downsampling_factor) == num_downsamplers + + # each one will be SubformerEncoder or DownsampledSubformerEncoder + encoders = [] + downsamplers = [] + bypasses = [] + + layer_indexes = [] + + cur_max_dim = encoder_dim[0] + + downsampling_factors_list = [] + def cur_downsampling_factor(): + c = 1 + for d in downsampling_factors_list: c *= d + return c + + for s in structure: + if s == 'S': + i = len(encoders) + encoder_layer = SubformerEncoderLayer( + embed_dim=encoder_dim[i], + pos_dim=pos_dim, + num_heads=num_heads[i], + query_head_dim=query_head_dim[i], + value_head_dim=value_head_dim[i], + feedforward_dim=feedforward_dim[i], + memory_dim=memory_dim, + dropout=dropout, + causal=causal, + ) + cur_max_dim = max(cur_max_dim, encoder_dim[i]) + encoder = SubformerEncoder( + encoder_layer, + num_encoder_layers[i], + embed_dim=cur_max_dim, + dropout=dropout, + chunk_sizes=encoder_chunk_sizes[i], + warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + final_layerdrop_rate=0.035 * (cur_downsampling_factor() ** 0.5), + ) + layer_indexes.append(len(encoders)) + encoders.append(encoder) + elif s =='(': + i = len(downsamplers) + downsampler = LearnedDownsamplingModule(cur_max_dim, + downsampling_factor[i]) + downsampling_factors_list.append(downsampling_factor[i]) + layer_indexes.append(len(downsamplers)) + downsamplers.append(downsampler) + else: + assert s == ')' + bypass = BypassModule(cur_max_dim, straight_through_rate=0.0) + layer_indexes.append(len(bypasses)) + bypasses.append(bypass) + downsampling_factors_list.pop() + + logging.info(f"cur_downsampling_factor={cur_downsampling_factor()}") + + self.layer_indexes = layer_indexes + self.structure = structure + self.encoders = nn.ModuleList(encoders) + self.downsamplers = nn.ModuleList(downsamplers) + self.bypasses = nn.ModuleList(bypasses) + + self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim, + dropout_rate=0.15, + length_factor=1.0) + + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + src_key_padding_mask: Optional[torch.Tensor] = None, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + The input tensor. Its shape is (batch_size, seq_len, feature_dim). + x_lens: + A tensor of shape (batch_size,) containing the number of frames in + `x` before padding. + src_key_padding_mask: + The mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. 
+ memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) + memory_key_padding_mask: optionally the mask for padding of memory input (for source- + attention), of shape (batch_size, memory_len); True means + masked position. May be None. + + Returns: + Return a tuple containing 2 tensors: + - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim)) + - lengths, a tensor of shape (batch_size,) containing the number + of frames in `embeddings` before padding. + """ + outputs = [] + + + if self.training and memory is not None: + batch_size = x.shape[1] + # setting memory to zero should be equivalent to not using the + # memory input at all, since the Attention module has no biases. + memory_dropout_rate = 0.05 + memory = memory * (torch.rand(batch_size, 1, device=memory.device) > + memory_dropout_rate) + + attn_offsets = [ self._get_attn_offset(x, src_key_padding_mask) ] + pos_embs = [ self.encoder_pos(x) ] + downsample_info = [] + + for s, i in zip(self.structure, self.layer_indexes): + if s == 'S': + encoder = self.encoders[i] # one encoder stack + x = encoder(x, + pos_embs[-1], + attn_offset=attn_offsets[-1], + memory=memory, + memory_key_padding_mask=memory_key_padding_mask) + # x will have the maximum dimension up till now, even if + # `encoder` uses lower dim in its layers. + elif s == '(': + downsampler = self.downsamplers[i] + + indexes, weights, x_new = downsampler(x) + downsample_info.append((indexes, weights, x)) + x = x_new + + pos_embs.append(downsampler.downsample_pos_emb(pos_embs[-1], indexes)) + + attn_offsets.append(downsampler.downsample_attn_offset(attn_offsets[-1], + indexes, + weights)) + + else: + assert s == ')' # upsample and bypass + indexes, weights, x_orig = downsample_info.pop() + _attn_offset = attn_offsets.pop() + _pos_emb = pos_embs.pop() + x_orig = convert_num_channels(x_orig, x.shape[-1]) + + x = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights) + + bypass = self.bypasses[i] + x = bypass(x_orig, x) + + # d = self.output_downsampling_factor + # lengths = (x_lens + d - 1) // d + + + # The next code block will only run in the case of "unbalanced" structures, e.g. + # if structure == "S(S(S)S", where there are unmatched right-parentheses. + cur_indexes = None + while len(downsample_info) > 0: + indexes, weights, x_orig = downsample_info.pop() + if cur_indexes is not None: + # keep only a subset of the indexes and weights, corresponding + # to later downsampling operations. + indexes = torch.gather(indexes, dim=1, index=cur_indexes) + weights = torch.gather(weights, dim=1, index=cur_indexes) + + cur_indexes = indexes + + x_lens = (weights != 0).sum(dim=1) + x_orig = convert_num_channels(x_orig, x.shape[-1]) + x_orig, x = LearnedDownsamplingModule.apply_weights(x_orig, x, indexes, weights) + + + return x, x_lens + + + def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]: + """ + Return attention offset of shape (1 or batch_size, seq_len, seq_len), interpreted as (1 or batch_size, tgt_seq_len, + src_seq_len); this reflects masking, if causal == True, otherwise will be all zeros. + + Args: + x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). + src_key_padding_mask: optional key-padding mask of shape (batch_size, seq_len) with True in masked positions. 
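+        Returns:
+          A tensor that is zero at allowed positions and -inf at disallowed
+          ones (future frames when `causal` is True, plus padded key
+          positions); it is intended to be added to the attention scores
+          before the softmax.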
+ """ + seq_len, batch_size, _num_channels = x.shape + + ans = torch.zeros(1, seq_len, seq_len, device=x.device) + + if self.causal: + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + src_t = t + tgt_t = t.unsqueeze(-1) + attn_mask = (src_t > tgt_t) + ans.masked_fill_(attn_mask, float('-inf')) + + if src_key_padding_mask is not None: + ans.masked_fill_(src_key_padding_mask.unsqueeze(1), float('-inf')) + # now ans: (batch_size, seq_len, seq_len). + + return ans + + + +def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: + return ScheduledFloat((0.0, x), + (20000.0, ratio * x), + default=x) + +def _balancer_schedule(min_prob: float): + return ScheduledFloat((0.0, 0.4), (8000.0, min_prob)) + + + +class SubformerEncoderLayer(nn.Module): + """ + Args: + embed_dim: the number of expected features in the input (required). + nhead: the number of heads in the multiheadattention models (required). + feedforward_dim: the dimension of the feedforward network model (default=2048). + dropout: the dropout value (default=0.1). + + Examples:: + >>> encoder_layer = SubformerEncoderLayer(embed_dim=512, nhead=8) + >>> src = torch.rand(10, 32, 512) + >>> pos_emb = torch.rand(32, 19, 512) + >>> out = encoder_layer(src, pos_emb) + """ + def __init__( + self, + embed_dim: int, + num_heads: int, + query_head_dim: int, + value_head_dim: int, + pos_dim: int, + feedforward_dim: int, + dropout: FloatLike = 0.1, + causal: bool = False, + memory_dim: int = -1, + attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0), + const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.0), default=0), + ff2_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), + ff3_skip_rate: FloatLike = ScheduledFloat((0.0, 0.1), (4000.0, 0.01), (50000.0, 0.0)), + bypass_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), (4000.0, 0.02), default=0), + ) -> None: + super(SubformerEncoderLayer, self).__init__() + self.embed_dim = embed_dim + + # self.bypass implements layer skipping as well as bypass; see its default values. + self.bypass = BypassModule(embed_dim, skip_rate=bypass_skip_rate) + + # bypass_mid is bypass used in the middle of the layer. + self.bypass_mid = BypassModule(embed_dim) + + + # skip probability for dynamic modules (meaning: anything but feedforward). + self.attention_skip_rate = copy.deepcopy(attention_skip_rate) + + + # ff2_skip_rate is to prevent the ff2 module from having output that's too big + # compared to its residual. 
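+        # Note: the skip rates here are FloatLike (typically ScheduledFloat) values
+        # that depend on the module's batch_count, which train.py sets via
+        # set_batch_count(); e.g. the default ff2_skip_rate decays from 0.1 at
+        # batch 0 to 0.01 at batch 4000 and to 0.0 by batch 50000, so this
+        # regularization is strongest early in training and is later removed.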
+ self.ff2_skip_rate = copy.deepcopy(ff2_skip_rate) + self.ff3_skip_rate = copy.deepcopy(ff3_skip_rate) + + self.const_attention_rate = copy.deepcopy(const_attention_rate) + + self.self_attn_weights = RelPositionMultiheadAttentionWeights( + embed_dim, num_heads=num_heads, + query_head_dim=query_head_dim, pos_dim=pos_dim, + dropout=0.0, + ) + + + self.self_attn1 = Attention(embed_dim, embed_dim, num_heads, + value_head_dim) + + self.self_attn2 = Attention(embed_dim, embed_dim, num_heads, + value_head_dim) + + if memory_dim > 0: + self.attn_weights = MultiheadAttentionWeights( + memory_dim, + embed_dim, + num_heads=num_heads, + head_dim=query_head_dim, + dropout=0.0, + ) + self.src_attn1 = Attention(memory_dim, embed_dim, num_heads, + value_head_dim) + self.src_attn2 = Attention(memory_dim, embed_dim, num_heads, + value_head_dim) + + + self.feed_forward1 = FeedforwardModule(embed_dim, + (feedforward_dim * 3) // 4, + dropout) + + self.feed_forward2 = FeedforwardModule(embed_dim, + feedforward_dim, + dropout) + + self.feed_forward3 = FeedforwardModule(embed_dim, + (feedforward_dim * 5) // 4, + dropout) + + self.nonlin_attention = NonlinAttention(embed_dim, + hidden_channels=3 * embed_dim // 4) + + + #self.attention_squeeze = AttentionSqueeze(embed_dim, embed_dim // 2) + + self.norm = BiasNorm(embed_dim) + + self.balancer1 = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.45, max_positive=0.55, + min_abs=0.2, max_abs=4.0, + ) + + # balancer for output of NonlinAttentionModule + self.balancer_na = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.3, max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.004), (4000.0, 0.02)), + prob=0.05, # out of concern for memory usage + ) + + # balancer for output of feedforward2, prevent it from staying too + # small. give this a very small probability, even at the start of + # training, it's to fix a rare problem and it's OK to fix it slowly. + self.balancer_ff2 = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.3, max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.1), default=0.0), + max_abs=2.0, + prob=0.05, + ) + + self.balancer_ff3 = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.3, max_positive=0.7, + min_abs=ScheduledFloat((0.0, 0.0), (4000.0, 0.2), default=0.0), + max_abs=4.0, + prob=0.05, + ) + + self.whiten = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(4.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01) + + self.balancer2 = Balancer( + embed_dim, channel_dim=-1, + min_positive=0.45, max_positive=0.55, + min_abs=0.1, max_abs=4.0, + ) + + def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]: + if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting(): + return None + batch_size = x.shape[1] + mask = (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype) + return mask + + + def sequence_dropout(self, x: Tensor, dropout_rate: float) -> Tensor: + """ + Apply sequence-level dropout to x. + x shape: (seq_len, batch_size, embed_dim) + """ + dropout_mask = self.get_sequence_dropout_mask(x, dropout_rate) + if dropout_mask is None: + return x + else: + return x * dropout_mask + + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + attn_offset: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + """ + Pass the input through the encoder layer. 
+ Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: (batch_size, seq_len, seq_len, pos_dim), with e.g. pos_dim=4: relatie positional + embedding tensor. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len). -inf for masked position. + src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means + masked position. May be None. + + Returns: + A tensor which has the same shape as src + """ + src_orig = src + + # dropout rate for non-feedforward submodules + attention_skip_rate = float(self.attention_skip_rate) if self.training else 0.0 + + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + attn_weights = self.self_attn_weights( + src, + pos_emb=pos_emb, + attn_offset=attn_offset, + ) + + if memory is not None and hasattr(self, 'attn_weights'): + src_attn_weights = self.attn_weights(memory, src, memory_key_padding_mask) + + src = src + self.feed_forward1(src) + + attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate) + + if True: + selected_attn_weights = attn_weights[0:1] + if random.random() < float(self.const_attention_rate): + # Make attention weights constant. The intention is to + # encourage these modules to do something similar to an + # averaging-over-time operation. + # only need the mask, can just use the 1st one and expand later + selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype) + selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True)) + + + na = self.balancer_na(self.nonlin_attention(src, + selected_attn_weights[0:1])) + + src = src + (na if attn_dropout_mask is None else na * attn_dropout_mask) + + self_attn = self.self_attn1( + src, attn_weights) + + src = src + (self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask) + + if memory is not None and hasattr(self, 'attn_weights'): + src = src + self.sequence_dropout(self.src_attn1(memory, src_attn_weights), + attention_skip_rate) + + src = src + self.sequence_dropout(self.balancer_ff2(self.feed_forward2(src)), + float(self.ff2_skip_rate)) + + # bypass in the middle of the layer. + src = self.bypass_mid(src_orig, src) + + self_attn = self.self_attn2( + src, attn_weights) + + src = src + (self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask) + + if memory is not None and hasattr(self, 'attn_weights'): + src = src + self.sequence_dropout(self.src_attn2(memory, src_attn_weights), + attention_skip_rate) + + src = src + self.sequence_dropout(self.balancer_ff3(self.feed_forward3(src)), + float(self.ff3_skip_rate)) + + src = self.balancer1(src) + src = self.norm(src) + + src = self.bypass(src_orig, src) + + src = self.balancer2(src) + src = self.whiten(src) + + return src + +class SubformerEncoder(nn.Module): + r"""SubformerEncoder is a stack of N encoder layers + + Args: + encoder_layer: an instance of the SubformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + embed_dim: the embedding dimension to use for the bypass (may exceed the + dimension of encoder_layer, as it may not operate on the full + dimension). 
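+        warmup_begin, warmup_end: training batch indexes between which the
+            per-layer layerdrop rate (each layer's bypass skip_rate) is annealed
+            from initial_layerdrop_rate down to final_layerdrop_rate; each layer
+            is assigned its own slice of this interval.
+        chunk_sizes: chunk sizes, in frames, that successive layers alternate
+            between for attention; a chunk size is only used for a sequence whose
+            length it divides exactly, otherwise the full sequence is one chunk.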
+ + Examples:: + >>> encoder_layer = SubformerEncoderLayer(embed_dim=512, nhead=8) + >>> zipformer_encoder = SubformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = zipformer_encoder(src) + """ + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + embed_dim: int, + dropout: float, + warmup_begin: float, + warmup_end: float, + chunk_sizes: Tuple[int, ...] = (128, 2048), + initial_layerdrop_rate: float = 0.5, + final_layerdrop_rate: float = 0.05, + ) -> None: + super().__init__() + + self.chunk_sizes = chunk_sizes + + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + + self.bypass = BypassModule(embed_dim) + + assert 0 <= warmup_begin <= warmup_end + + delta = (1. / num_layers) * (warmup_end - warmup_begin) + cur_begin = warmup_begin # interpreted as a training batch index + for i in range(num_layers): + cur_end = cur_begin + delta + self.layers[i].bypass.skip_rate = ScheduledFloat((cur_begin, initial_layerdrop_rate), + (cur_end, final_layerdrop_rate), + default=0.0) + cur_begin = cur_end + + def embed_dim(self): + return self.bypass.embed_dim() + + def forward( + self, + src: Tensor, + pos_emb: Tensor, + attn_offset: Optional[Tensor] = None, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). + pos_emb: positional embedding tensor, of shape (batch_size, seq_len, seq_len, pos_dim), + e.g. pos_dim=4. + feature_mask: something that broadcasts with src, that we'll multiply `src` + by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) + attn_offset: the attention offset (does masking and related tasks), of shape + broadcasting with (batch_size, seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len). + memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) + memory_key_padding_mask: optionally the mask for padding of memory input (for source- + attention), of shape (batch_size, memory_len); True means + masked position. May be None. + + Returns: a Tensor with the same shape as src. + """ + output = convert_num_channels(src, self.layers[0].embed_dim) + + chunk_sizes, chunk_indexes = self._get_chunk_sizes(src) + b = src.shape[1] # batch_size + + pos_embs = [ self._pos_emb_to_chunk_size(pos_emb, c) for c in chunk_sizes ] + attn_offsets = [ self._attn_offset_to_chunk_size(attn_offset, b, c) for c in chunk_sizes ] + # TODO: support this for memory also; would require duplicating it maybe; + # or could modify the interior code to just assume chunking + # when doing cross-attention. + for i, mod in enumerate(self.layers): + ci = chunk_indexes[i] + c = chunk_sizes[ci] + output = self._to_chunk_size(output, c) + output = mod( + output, + pos_embs[ci], + attn_offset=attn_offsets[ci], + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, + ) + + #if feature_mask is not None: + # output = output * feature_mask + + output = self._to_chunk_size(output, src.shape[0]) + + output = convert_num_channels(output, self.bypass.embed_dim()) + src = convert_num_channels(src, self.bypass.embed_dim()) + + return self.bypass(src, output) + + def _get_chunk_sizes(self, src: Tensor) -> Tuple[List[int], List[int]]: + """ + Decide the chunk sizes (in frames) to use for each layer. 
+ Args: + src: the input embeddings, of shape (seq_len, batch_size, embed_dim) + Returns: (chunk_sizes, chunk_indexes), where: + chunk_sizes: a list of the unique chunk sizes to use, e.g. [ 128, 256 ] + chunk_indexes: a list of indexes into chunk_sizes, one per layer. + """ + seq_len = src.shape[0] + chunk_indexes = [] + chunk_sizes = [] + for i, chunk_size in enumerate(self.chunk_sizes): + chunk_sizes.append(chunk_size if seq_len % chunk_size == 0 + else seq_len) + + num_chunk_sizes = len(self.chunk_sizes) + for i in range(self.num_layers): + chunk_indexes.append(i % num_chunk_sizes) + + return chunk_sizes, chunk_indexes + + def _to_chunk_size(self, src: Tensor, chunk_size: int) -> Tensor: + """ + Reshape embeddings 'src' to have a different chunk size (in frames) by + changing the batch size. + """ + (seq_len, batch_size, num_channels) = src.shape + if chunk_size == seq_len: + return src + src = src.transpose(0, 1).contiguous().reshape(-1, chunk_size, num_channels) + return src.transpose(0, 1).contiguous() + + + def _attn_offset_to_chunk_size(self, attn_offset: Tensor, batch_size: int, chunk_size: int) -> Tensor: + """ + Break up attention offset into a given chunk size + """ + (_batch_size, seq_len, seq_len) = attn_offset.shape + if seq_len == chunk_size: + return attn_offset + if _batch_size != batch_size: + assert _batch_size == 1 + attn_offset = attn_offset.expand(batch_size, seq_len, seq_len) + + assert seq_len % chunk_size == 0 + + num_chunks = seq_len // chunk_size + + batch_stride, tgt_stride, src_stride = attn_offset.stride() + + # have the 'chunk' dimension first so it has larger stride than the original batch; this + # is to match what happens to the embeddings in 'src' where the time-stride is first. + attn_offset = attn_offset.as_strided((num_chunks, batch_size, chunk_size, chunk_size), + ((tgt_stride + src_stride) * chunk_size, batch_stride, + tgt_stride, src_stride)) + + return attn_offset.contiguous().reshape(num_chunks * batch_size, chunk_size, chunk_size) + + + def _pos_emb_to_chunk_size(self, pos_emb: Tensor, chunk_size: int) -> Tensor: + """ + Break up positional embedding tensor into a given chunk size + """ + (batch_size, seq_len, seq_len, pos_dim) = pos_emb.shape + if seq_len == chunk_size: + return pos_emb + assert seq_len % chunk_size == 0 + + num_chunks = seq_len // chunk_size + + batch_stride, tgt_stride, src_stride, channel_stride = pos_emb.stride() + + pos_emb = pos_emb.as_strided((num_chunks, batch_size, chunk_size, chunk_size, pos_dim), + ((tgt_stride + src_stride) * chunk_size, batch_stride, + tgt_stride, src_stride, channel_stride)) + + return pos_emb.contiguous().reshape(num_chunks * batch_size, + chunk_size, chunk_size, + pos_dim) + + + +class BypassModule(nn.Module): + """ + An nn.Module that implements a learnable bypass scale, and also randomized per-sequence + layer-skipping. The bypass is limited during early stages of training to be close to + "straight-through", i.e. to not do the bypass operation much initially, in order to + force all the modules to learn something. 
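+
+    Concretely, forward() computes src_orig + (src - src_orig) * bypass_scale,
+    so a bypass_scale of 0 returns src_orig unchanged (full bypass) and a
+    bypass_scale of 1 returns src.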
+ """ + def __init__( + self, + embed_dim: int, + skip_rate: FloatLike = 0.0, + straight_through_rate: FloatLike = 0.0, + scale_min: FloatLike = ScheduledFloat((0.0, 0.9), (20000.0, 0.2), default=0), + scale_max: FloatLike = 1.0): + super().__init__() + self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5)) + self.skip_rate = copy.deepcopy(skip_rate) + self.straight_through_rate = copy.deepcopy(straight_through_rate) + self.scale_min = copy.deepcopy(scale_min) + self.scale_max = copy.deepcopy(scale_max) + + def embed_dim(self): + return self.bypass_scale.numel() + + def _get_bypass_scale(self, batch_size: int): + # returns bypass-scale of shape (num_channels,), + # or (batch_size, num_channels,). This is actually the + # scale on the non-residual term, so 0 correponds to bypassing + # this module. + if torch.jit.is_scripting() or not self.training: + return self.bypass_scale + else: + ans = limit_param_value(self.bypass_scale, + min=float(self.scale_min), + max=float(self.scale_max)) + skip_rate = float(self.skip_rate) + if skip_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) > skip_rate + ans = ans * mask + # now ans is of shape (batch_size, num_channels), and is zero for sequences + # on which we have randomly chosen to do layer-skipping. + straight_through_rate = float(self.straight_through_rate) + if straight_through_rate != 0.0: + mask = torch.rand((batch_size, 1), device=ans.device) < straight_through_rate + ans = torch.maximum(ans, mask.to(ans.dtype)) + + return ans + + def forward(self, + src_orig: Tensor, + src: Tensor): + """ + Args: src_orig and src are both of shape (seq_len, batch_size, num_channels) + Returns: something with the same shape as src and src_orig + """ + bypass_scale = self._get_bypass_scale(src.shape[1]) + return src_orig + (src - src_orig) * bypass_scale + + +class LearnedDownsamplingModule(nn.Module): + """ + Module that allows you to choose which frames to keep for transformer-type + modules. Effectively downsampling, but not necessarily "evenly"- you just + keep some proportion of frames determined by the embedding. + + Args: + embed_dim: embedding dimension + downsampling_factor: factor to downsample by, e.g. 2 or 4. There is no + fundamental reason why this has to be an integer, but we make it so + anyway. + """ + def __init__(self, + embed_dim: int, + downsampling_factor: int): + assert downsampling_factor > 1 + + super().__init__() + + self.to_scores = nn.Linear(embed_dim, 1, bias=False) + self.to_scores.lr_scale = 0.5 + # score_balancer is just to keep the magnitudes of the scores in + # a fixed range and keep them balanced around zero, to stop + # these drifting around. + # largish range used to keep grads relatively small and avoid overflow in grads. + self.score_balancer = Balancer(1, channel_dim=-1, + min_positive=1/(2*downsampling_factor), + max_positive=0.6, + min_abs=1.0, + max_abs=4.0, + prob=ScheduledFloat((0.0, 1.0), (8000.0, 0.25), default=0.0)) + + + # below are for diagnostics. + self.copy_weights1 = nn.Identity() + self.copy_weights2 = nn.Identity() + + self.downsampling_factor = downsampling_factor + + + def forward(self, + x: Tensor) -> Tuple[Tensor, Tensor, Tensor]: + """ + Args: + x: a Tensor of shape (seq_len, batch_size, embed_dim) + + Returns: (frame_indexes, weights, kept) + + frame_indexes: a Tensor of integer type, of shape (batch_size, reduced_seq_len) + where reduced_seq_len = (seq_len + d - 1) // d. 
It contains elements + 0 <= frame_indees < seq_len, in sorted (increasing) order + + weights: a Tensor of shape (batch_size, reduced_seq_len), + corresponding to the kept frames; these will be between 0 and 1, but + mostly exactly 1. + """ + (seq_len, batch_size, _) = x.shape + scores = self.to_scores(x) # (seq_len, batch_size, 1) + scores = self.score_balancer(scores) + + scores = scores.squeeze(-1).t() # (batch_size, seq_len) + + # sscores, indexes: (batch_size, seq_len) + sscores, indexes = scores.sort(dim=-1, descending=True) + + + weights = sscores.clamp(min=0.0, max=1.0) + weights = self.copy_weights1(weights) + + if self.training: + d = self.downsampling_factor + seq_len_reduced = (seq_len + d - 1) // d + + weights_discarded = weights[:, seq_len_reduced:2*seq_len_reduced] + missing = seq_len_reduced - weights_discarded.shape[1] + if missing != 0: + weights_discarded = torch.cat((weights_discarded, + torch.zeros(batch_size, missing, + device=weights.device, + dtype=weights.dtype)), + dim=1) + + if random.random() < 0.01 or __name__ == '__main__': + logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}") + + + if random.random() < 0.5: + # flipping it half the time increases the randomness, so gives an extra incentive + # to avoid nonzero weights in the discarded half + weights_discarded = weights_discarded.flip(dims=(1,)) + + weights = weights[:, :seq_len_reduced] - weights_discarded + else: + # test mode. because the sequence might be short, we keep all nonzero scores; + # and there is no need for any penalty. + + # need to work out seq_len_reduced. + seq_len_reduced = max(1, + (weights > 0.0).to(torch.int32).sum(dim=-1).max().item()) + if random.random() < 0.02: + logging.info(f"seq_len={seq_len}, seq_len_reduced={seq_len_reduced}") + weights = weights[:, :seq_len_reduced] + + indexes = indexes[:, :seq_len_reduced] + + + weights = self.copy_weights2(weights) + + # re-sort the indexes we kept, on index value, so that + # masking for causal models will be in the correct order. + # (actually this may not really matter, TODO: see whether we + # can remove this??) + indexes, reorder = indexes.sort(dim=-1) + weights = torch.gather(weights, dim=-1, index=reorder) + + x_downsampled = self.downsample(x, indexes) + return indexes, weights, x_downsampled + + + def downsample(self, x: Tensor, indexes: Tensor) -> Tensor: + """ + Downsamples x via indexing with the indexes obtained from the + forward() function. + + Args: + x: tensor of shape (seq_len, batch_size, num_channels) + indexes: integer indexes of shape (batch_size, seq_len_reduced), with elements + 0 <= indexes < seq_len. + Returns: + x_downsampled, of shape (seq_len_reduced, batch_size, num_channels) + """ + indexes_expanded = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1]) + # indexe_expanded: (seq_len_reduced, batch_size, num_channels) + ans = torch.gather(x, dim=0, index=indexes_expanded) + + if __name__ == '__main__': + # temp, for testing + x_reconstructed = self.upsample(x, ans, indexes) + assert torch.allclose(x, x_reconstructed) + + return ans + + + def downsample_pos_emb(self, pos_emb: Tensor, indexes: Tensor) -> Tensor: + """ + Downsample positional embedding tensor with the provided indexes. + Args: + pos_emb: (batch_size, seq_len, seq_len, pos_dim) + interpreted as (batch_size, tgt_seq_len, src_seq_len, pos_dim). 
+ indexes: (batch_size, seq_len_reduced), containing integer elements + 0 <= indexes < seq_len. + Returns: + downsampled_pos_len: (batch_size, seq_len_reduced, seq_len_reduced, pos_dim) + """ + + (batch_size, seq_len_reduced) = indexes.shape + (_, _, seq_len, pos_dim) = pos_emb.shape + + tgt_indexes = indexes.reshape(batch_size, seq_len_reduced, 1, 1).expand( + batch_size, seq_len_reduced, seq_len, pos_dim) + + pos_emb = torch.gather(pos_emb, dim=1, index=tgt_indexes) + # now pos_emb: (batch_size, seq_len_reduced, seq_len, pos_dim) + + src_indexes = indexes.reshape(batch_size, 1, seq_len_reduced, 1).expand( + batch_size, seq_len_reduced, seq_len_reduced, pos_dim) + + pos_emb = torch.gather(pos_emb, dim=2, index=src_indexes) + # now pos_emb: (batch_size, seq_len_reduced, seq_len_reduced, pos_dim) + return pos_emb + + + def downsample_attn_offset(self, + attn_offset: Tensor, + indexes: Tensor, + weights: Tensor, + eps: float = 1.0e-03) -> Tensor: + """ + Downsamples attn_offset and also modifies it to account for the weights in `weights`. + Args: + attn_offset: a Tensor of shape (1 or batch_size, seq_len, seq_len), interpreted as + (1 or batch_size, tgt_seq_len, src_seq_len) + indexes: a Tensor of shape (batch_size, reduced_seq_len) containing elements + 0 <= indexes < seq_len. + weights: a Tensor of shape (batch_size, reduced_seq_len) containing weights + between 0 and 1; most will be 1. + Returns: + attn_offset_downsampled, a Tensor of shape (batch_size, reduced_seq_len, reduced_seq_len) + """ + (batch_size, seq_len_reduced) = indexes.shape + seq_len = attn_offset.shape[-1] + assert len(attn_offset.shape) == 3 # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len) + attn_offset = attn_offset.expand(batch_size, seq_len, seq_len) + + if torch.is_autocast_enabled(): + # it's possible to get large gradients at this point; clip these at + # this point to reduce the extent to which it has to reduce the + # grad_scale. + weights = clip_grad(weights, 5000.0) + + attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand( + batch_size, seq_len_reduced, seq_len)) + attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand( + batch_size, seq_len_reduced, seq_len_reduced)) + # unsqueeze at position 1 so the extra cost relates to the source position. + attn_offset = attn_offset + (weights + eps).log().unsqueeze(1) + + return attn_offset + + + @staticmethod + def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor, + weights: Optional[Tensor] = None) -> Tensor: + """ + Upsamples, reversing the downsample() operation and filling in + any not-chosen frames with their original value before downsampling + (or with whatever x_orig contains). + + Args: + x_orig: (seq_len, batch_size, num_channels) + x: (seq_len_reduced, batch_size, num_channels) + indexes: (batch_size, seq_len_reduced), contains original frame indexes + weights: optional tensor of shape (batch_size, seq_len_reduced) + + Downsamples x via indexing with the indexes obtained from the + forward() function. + + Args: + x: tensor of shape (seq_len, batch_size, indexes) + weights: a tensor of shape (batch_size, seq_len_reduced) containing weights between + 0 and 1, where 1 means fully use this x value and 0 means use x_orig + indexes: integer indexes of shape (batch_size, seq_len_reduced), with elements + 0 <= indexes < seq_len. 
+ """ + (seq_len, batch_size, num_channels) = x_orig.shape + + x_weight = 1.0 if weights is None else weights.t().unsqueeze(-1) + # x_weight: (seq_len_reduced, batch_size, 1) if a tensor + + orig_x_weight = torch.ones(batch_size, seq_len, + device=x.device, dtype=x.dtype) + if weights is None: + orig_x_weight.scatter_(dim=1, index=indexes, value=0.) + else: + orig_x_weight.scatter_(dim=1, index=indexes, + src=(1. - weights).to(x.dtype)) + + indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels) + # indexes now: (seq_len_reduced, batch_size, num_channels) + + ans = torch.zeros_like(x_orig) + + ans.scatter_(dim=0, index=indexes, src=(x * x_weight)) + + # add in x_orig in the frames that were not originally kept. + return ans + x_orig * orig_x_weight.t().unsqueeze(-1) + + @staticmethod + def apply_weights(x_orig: Tensor, x: Tensor, indexes: Tensor, + weights: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: + """ + Downsamples x_orig to have the same shape as x and applies the weights, + returning interpolated x and downsampled x_orig. This is similar to + `upsample`, but is for the case where you don't want to keep the frames + that were not sampled. + + Args: + x_orig: (seq_len, batch_size, num_channels) + x: (seq_len_reduced, batch_size, num_channels) + indexes: (batch_size, seq_len_reduced), contains original frame indexes + weights: optional tensor of shape (batch_size, seq_len_reduced) + + Returns (x_orig, x) after the downsampling and interpolation, of shapes + both (seq_len_reduced, batch_size, num_channels). + """ + (seq_len, batch_size, num_channels) = x_orig.shape + weights = 1.0 if weights is None else weights.t().unsqueeze(-1) + + indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels) + # indexes now: (seq_len_reduced, batch_size, num_channels) + x_orig = torch.gather(x_orig, dim=0, index=indexes) + + x = x * weights + x_orig * (1.0 - weights) + + return x_orig, x + + + +class CompactRelPositionalEncoding(torch.nn.Module): + """ + Relative positional encoding module. This version is "compact" meaning it is able to encode + the important information about the relative position in a relatively small number of dimensions. + The goal is to make it so that small differences between large relative offsets (e.g. 1000 vs. 1001) + make very little difference to the embedding. Such differences were potentially important + when encoding absolute position, but not important when encoding relative position because there + is now no need to compare two large offsets with each other. + + Our embedding works done by projecting the interval [-infinity,infinity] to a finite interval + using the atan() function, before doing the fourier transform of that fixed interval. The + atan() function would compress the "long tails" too small, + making it hard to distinguish between different magnitudes of large offsets, so we use a logarithmic + function to compress large offsets to a smaller range before applying atan(). + Scalings are chosen in such a way that the embedding can clearly distinguish invidual offsets as long + as they are quite close to the origin, e.g. abs(offset) <= about sqrt(embedding_dim) + + + Args: + embed_dim: Temporary embedding dimension used inside this module + pos_dim: Smaller positional-encoding dim used after a projecction. + dropout_rate: Dropout rate. + max_len: Maximum input length: just a heuristic for initialization. 
+ length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives + less weight to small differences of offset near the origin. + pos_dim: dimension at the output of this module. + """ + def __init__( + self, + embed_dim: int, + pos_dim: int, + dropout_rate: FloatLike, + max_len: int = 1000, + length_factor: float = 1.0, + ) -> None: + """Construct a CompactRelPositionalEncoding object.""" + super(CompactRelPositionalEncoding, self).__init__() + self.embed_dim = embed_dim + assert embed_dim % 2 == 0 + self.dropout = Dropout2(dropout_rate) + self.pe = None + assert length_factor >= 1.0 + self.length_factor = length_factor + self.extend_pe(torch.tensor(0.0).expand(max_len)) + + # linear transformation for positional encoding. + self.linear_pos = ScaledLinear(embed_dim, + pos_dim, + bias=False, + initial_scale=0.05) + + + def extend_pe(self, x: Tensor) -> None: + """Reset the positional encodings.""" + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(0) >= x.size(0) * 2 - 1: + # Note: TorchScript doesn't implement operator== for torch.Device + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + + T = x.size(0) + # if T == 4, x would contain [ -3, -2, 1, 0, 1, 2, 3 ] + x = torch.arange(-(T-1), T, + device=x.device).to(torch.float32).unsqueeze(1) + + freqs = 1 + torch.arange(self.embed_dim // 2, device=x.device) + + # `compression_length` this is arbitrary/heuristic, if it is larger we have more resolution + # for small time offsets but less resolution for large time offsets. + compression_length = (self.embed_dim ** 0.5) + # x_compressed, like X, goes from -infinity to infinity as T goes from -infinity to infinity; + # but it does so more slowly than T for large absolute values of T. + # The formula is chosen so that d(x_compressed )/dx is 1 around x == 0, which + # is important. + x_compressed = compression_length * x.sign() * ((x.abs() + compression_length).log() - math.log(compression_length)) + + # if self.length_factor == 1.0, then length_scale is chosen so that the + # FFT can exactly separate points close to the origin (T == 0). So this + # part of the formulation is not really heuristic. + # But empirically, for ASR at least, length_factor > 1.0 seems to work better. + length_scale = self.length_factor * self.embed_dim / (2.0 * math.pi) + + # note for machine implementations: if atan is not available, we can use: + # x.sign() * ((1 / (x.abs() + 1)) - 1) * (-math.pi/2) + # check on wolframalpha.com: plot(sign(x) * (1 / ( abs(x) + 1) - 1 ) * -pi/2 , atan(x)) + x_atan = (x_compressed / length_scale).atan() # results between -pi and pi + + cosines = (x_atan * freqs).cos() + sines = (x_atan * freqs).sin() + + pe = torch.zeros(x.shape[0], self.embed_dim, device=x.device) + pe[:, 0::2] = cosines + pe[:, 1::2] = sines + pe[:, -1] = 1.0 # for bias. + + self.pe = pe.to(dtype=x.dtype) + + + def forward(self, x: torch.Tensor) -> Tensor: + """Create positional encoding. + + Args: + x (torch.Tensor): Input tensor (seq_len, batch_size, num_channels_in) + + Returns: + positional embedding, of shape (batch_size, seq_len, seq_len, pos_dim). 
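+            For example, an input of shape (100, 8, 256) with pos_dim=4 gives a
+            result of shape (8, 100, 100, 4); entry [b, t, s] holds the embedding
+            of the relative offset between target frame t and source frame s.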
+ """ + self.extend_pe(x) + seq_len = x.size(0) + pos_emb = self.pe[ + self.pe.size(0) // 2 - seq_len + 1 : self.pe.size(0) // 2 + seq_len, + : + ] + pos_emb = pos_emb.unsqueeze(0) + pos_emb = self.dropout(pos_emb) + pos_emb = self.linear_pos(pos_emb) + + # currenly pos_emb: (1, 2*seq_len-1, pos_dim) + pos_dim = pos_emb.shape[-1] + batch_size = x.size(1) + # it doesn't really matter which one we make positive and which negative here, it + # would just flip the meaning of the embedding. + + + # expand the '1' dimension to seq_len; this introduces a dimension that + # 'does nothing', just creates copies, as a workaround for lack of torch support + # for negative strides. + pos_emb = pos_emb.expand(seq_len, 2*seq_len-1, pos_dim).contiguous() + + (useless_stride, seq_stride, channel_stride) = pos_emb.stride() + + pos_emb = pos_emb.as_strided((batch_size, seq_len, seq_len, pos_dim), + (0, useless_stride-seq_stride, seq_stride, channel_stride), + storage_offset=seq_stride * (seq_len - 1)) + + return pos_emb # (batch_size, seq_len, seq_len, pos_dim) + + + + +class RelPositionMultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head attention weights with relative position encoding; + in this version, the positions for each frame are passed in (in order to support + + + Various other modules consume the resulting attention weights: see, for example, the + SimpleAttention module which allows you to compute conventional attention. + + This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", + we have to write up the differences. + + + Args: + embed_dim: number of channels at the input to this module, e.g. 256 + num_heads: number of heads to compute weights for, e.g. 8 + query_head_dim: dimension of the query (and key), per head. e.g. 24. + pos_dim: dimension of the projected positional encoding, e.g. 4. + dropout: dropout probability for attn_output_weights. Default: 0.0. + pos_emb_skip_rate: probability for skipping the pos_emb part of the scores on + any given call to forward(), in training time. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + query_head_dim: int, + pos_dim: int, + dropout: float = 0.0, + pos_emb_skip_rate: FloatLike = ScheduledFloat((0.0, 0.5), + (4000.0, 0.0)) + ) -> None: + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.query_head_dim = query_head_dim + self.pos_dim = pos_dim + self.dropout = dropout + self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate) + self.score_penalty = AbsValuePenalizer( + limit=25.0, penalty=1.0e-04, prob=0.1) + self.name = None # for diagnostics, will be set in train.py + + key_head_dim = query_head_dim + in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, + initial_scale=query_head_dim**-0.25) + + self.whiten_keys = Whiten(num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025) + + # add a balancer for the keys that runs with very small probability, and + # tries to enforce that all dimensions have mean around zero. 
The + # weights produced by this module are invariant to adding a constant to + # the keys, so the derivative of the bias is mathematically zero; but + # due to how Adam/ScaledAdam work, it can learn a fairly large nonzero + # bias because the small numerical roundoff tends to have a non-random + # sign. This module is intended to prevent that. Use a very small + # probability; that should be suffixient to fix the problem. + self.balance_keys = Balancer(key_head_dim * num_heads, + channel_dim=-1, + min_positive=0.4, + max_positive=0.6, + min_abs=0.0, + max_abs=100.0, + prob=0.025) + + + # the following are for diagnosics only, see --print-diagnostics option + self.copy_pos_query = Identity() + self.copy_query = Identity() + + + def forward( + self, + x: Tensor, + pos_emb: Tensor, + attn_offset: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + x: input of shape (seq_len, batch_size, embed_dim) + pos_emb: Positional embedding tensor, of shape (1, 2*seq_len - 2, pos_dim) + + attn_offset: a Tensor of shape broadcasting with (batch_size, seq_len, seq_len), + interpreted as (batch_size, tgt_seq_len, src_seq_len), if provided this + contains values (probably <= 0) to be added to the logprobs of the attention; + this may combine the log of 'weights' of ChooseDownsamplingModule with + any attn_mask that enforces causality. + pos_emb: a Tensor of shape broadcasting with (batch_size, seq_len, seq_len, pos_dim) + (e.g. pos_dim=4), encoding relative positions. + + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, seq_len, seq_len) + interpreted as (hum_heads, batch_size, tgt_seq_len, src_seq_len). + """ + x = self.in_proj(x) + query_head_dim = self.query_head_dim + pos_dim = self.pos_dim + num_heads = self.num_heads + + seq_len, batch_size, _ = x.shape + + query_dim = query_head_dim * num_heads + + q = x[...,0:query_dim] + k = x[...,query_dim:2*query_dim] + # p is the position-encoding query + p = x[...,2*query_dim:] + assert p.shape[-1] == num_heads * pos_dim + + + q = self.copy_query(q) # for diagnostics only, does nothing. + k = self.whiten_keys(self.balance_keys(k)) # does nothing in the forward pass. + p = self.copy_pos_query(p) # for diagnostics only, does nothing. + + + q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_dim) + k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) + + q = q.permute(2, 1, 0, 3) # (head, batch, tgt_seq_len, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, src_seq_len) + + # attn_scores: (num_heads, batch_size, tgt_seq_len, src_esq_len) + attn_scores = torch.matmul(q, k) + + if not self.training or random.random() >= float(self.pos_emb_skip_rate): + # pos_emb: (batch_size, tgt_seq_len, src_seq_len, pos_dim) + p = p.permute(1, 0, 3, 2) # (batch_size, tgt_seq_len, pos_dim, num_heads) + + pos_scores = torch.matmul(pos_emb, p) + # pos_scores: (batch_size, tgt_seq_len, src_seq_len, num_heads) + pos_scores = pos_scores.permute(3, 0, 1, 2) + # pos_scores: (num_heads, batch_size, tgt_seq_len, src_seq_len) + attn_scores = attn_scores + pos_scores + + attn_scores = self.score_penalty(attn_scores) + + # attn_offset includes key-padding mask and attention-mask, plus any weights + # from the subsampling. 
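+        # attn_offset has shape (batch_size or 1, seq_len, seq_len) and broadcasts
+        # against attn_scores of shape (num_heads, batch_size, seq_len, seq_len);
+        # -inf entries mask positions out entirely, while the log-weights added by
+        # LearnedDownsamplingModule.downsample_attn_offset only down-weight them.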
+ attn_scores = attn_scores + attn_offset + + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. + attn_weights = softmax(attn_scores, dim=-1) + + if random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + + def _print_attn_entropy( + self, + attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( + dim=-1).mean(dim=(1,2)) + logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}") + + +class Attention(nn.Module): + """ + The simplest possible attention module. This one works with already-computed attention + weights, e.g. as computed by RelPositionMultiheadAttentionWeights. + + Args: + embed_dim_in: the input embedding dimension + embed_dim_out: the output embedding dimension (normally the same as input) + num_heads: the number of attention heads + value_head_dim: the value dimension per head + """ + def __init__( + self, + embed_dim_in: int, + embed_dim_out: int, + num_heads: int, + value_head_dim: int, + ) -> None: + super().__init__() + self.in_proj = nn.Linear(embed_dim_in, + num_heads * value_head_dim, + bias=False) + + self.out_proj = ScaledLinear(num_heads * value_head_dim, + embed_dim_out, bias=False, + initial_scale=0.05) + + self.whiten = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(7.5, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01) + + + def forward( + self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """ + Args: + x: input tensor, of shape (seq_len, batch_size, embed_dim) + attn_weights: a tensor of shape (num_heads, batch_size, query_len, key_len), + Expect attn_weights.sum(dim=-1) == 1. + Returns: + a tensor with the same shape as x. + """ + (num_heads, batch_size, query_len, key_len) = attn_weights.shape + + x = self.in_proj(x) # (key_len, batch_size, num_heads * value_head_dim) + x = x.reshape(key_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, key_len, value_head_dim) + value_head_dim = x.shape[-1] + + # todo: see whether there is benefit in overriding matmul + x = torch.matmul(attn_weights, x) + # v: (num_heads, batch_size, query_len, value_head_dim) + + x = x.permute(2, 1, 0, 3).contiguous().view( + query_len, batch_size, num_heads * value_head_dim) + + # returned value is of shape (query_len, batch_size, embed_dim), like the input. + x = self.out_proj(x) + x = self.whiten(x) + + return x + + +class MultiheadAttentionWeights(nn.Module): + r"""Module that computes multi-head cross-attention weights. Allows src and target + to have different dims. + + Args: + key_embed_dim: number of channels of the thing that we'll project to + make the query (corresponds to source). e.g. 256 + query_embed_dim: number of channels of the thing that we'll project to + make the query (corresponds to target). e.g. 256 + num_heads: number of heads to compute weights for, e.g. 
8 + head_dim: dimension of the query and key, per head. e.g. 24. + dropout: dropout probability for attn_output_weights. Default: 0.0. + """ + + def __init__( + self, + key_embed_dim: int, + query_embed_dim: int, + num_heads: int, + head_dim: int, + dropout: float = 0.0, + + ) -> None: + super().__init__() + self.key_embed_dim = key_embed_dim + self.query_embed_dim = query_embed_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.dropout = dropout + self.score_penalty = AbsValuePenalizer( + limit=25.0, penalty=1.0e-04, prob=0.1) + self.name = None # for diagnostics, will be set in train.py + + # the initial_scale is supposed to take over the "scaling" factor of + # head_dim ** -0.5 that has been used in previous forms of attention, + # dividing it between the query and key. Note: this module is intended + # to be used with the ScaledAdam optimizer; with most other optimizers, + # it would be necessary to apply the scaling factor in the forward function. + self.query_in_proj = ScaledLinear(query_embed_dim, + head_dim * num_heads, + bias=True, + initial_scale=head_dim ** -0.25) + + # weights produced by this module are invariant to adding a constant to + # the keys, so we don't need a bias for the keys. + self.key_in_proj = ScaledLinear(key_embed_dim, + head_dim * num_heads, + bias=False, + initial_scale=head_dim ** -0.25) + + self.whiten_keys = Whiten(num_groups=num_heads, + whitening_limit=_whitening_schedule(3.0), + prob=(0.025, 0.25), + grad_scale=0.025) + + + + def forward( + self, + key: Tensor, + query: Tensor, + key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r""" + Args: + key: input of shape (key_len, batch_size, key_embed_dim) + query: input of shape (query_len, batch_size, query_embed_dim) + key_padding_mask: an optional bool tensor of shape (batch_size, key_len). Positions that + are True in this mask will be ignored as sources in the attention weighting. + Returns: + a tensor of attention weights, of shape (hum_heads, batch_size, query_len, key_len) + """ + q = self.query_in_proj(query) + k = self.key_in_proj(key) + + head_dim = self.head_dim + num_heads = self.num_heads + + query_len, batch_size, _ = q.shape + key_len, _batch_size, _ = k.shape + assert _batch_size == batch_size + + k = self.whiten_keys(k) # does nothing in the forward pass. + + q = q.reshape(query_len, batch_size, num_heads, head_dim) + k = k.reshape(key_len, batch_size, num_heads, head_dim) + + # tgt_seq_len refers to target, src_seq_len refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, tgt_seq_len, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, src_seq_len) + + attn_scores = torch.matmul(q, k) + + attn_scores = self.score_penalty(attn_scores) + + assert attn_scores.shape == (num_heads, batch_size, query_len, key_len) + + if key_padding_mask is not None: + assert key_padding_mask.shape == (batch_size, key_len), key_padding_mask.shape + attn_scores = attn_scores.masked_fill( + key_padding_mask.unsqueeze(1), + -1000, + ) + + # We use our own version of softmax, defined in scaling.py, which should + # save a little of the memory used in backprop by, if we are in + # automatic mixed precision mode (amp / autocast), by only storing the + # half-precision output for backprop purposes. 
+ attn_weights = softmax(attn_scores, dim=-1) + + if random.random() < 0.001: + self._print_attn_entropy(attn_weights) + + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) + + return attn_weights + + + def _print_attn_entropy( + self, + attn_weights: Tensor): + # attn_weights: (num_heads, batch_size, seq_len, seq_len) + (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape + + with torch.no_grad(): + with torch.cuda.amp.autocast(enabled=False): + attn_weights = attn_weights.to(torch.float32) + attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( + dim=-1).mean(dim=(1,2)) + logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}") + + + +class FeedforwardModule(nn.Module): + """Feedforward module in Subformer model. + """ + def __init__(self, + embed_dim: int, + feedforward_dim: int, + dropout: FloatLike): + super(FeedforwardModule, self).__init__() + self.in_proj = nn.Linear(embed_dim, feedforward_dim) + + self.hidden_balancer = Balancer(feedforward_dim, + channel_dim=-1, + min_positive=0.3, + max_positive=1.0, + min_abs=0.75, + max_abs=5.0) + + # shared_dim=0 means we share the dropout mask along the time axis + self.out_proj = ActivationDropoutAndLinear(feedforward_dim, embed_dim, + activation='SwooshL', + dropout_p=dropout, + dropout_shared_dim=0, bias=True, + initial_scale=0.1) + + self.out_whiten = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(7.5), + prob=(0.025, 0.25), + grad_scale=0.01) + + def forward(self, + x: Tensor): + x = self.in_proj(x) + x = self.hidden_balancer(x) + # out_proj contains SwooshL activation, then dropout, then linear. + x = self.out_proj(x) + x = self.out_whiten(x) + return x + + +class NonlinAttention(nn.Module): + """This is like the ConvolutionModule, but refactored so that we use multiplication by attention weights (borrowed + from the attention module) in place of actual convolution. We also took out the second nonlinearity, the + one after the attention mechanism. + + Args: + channels (int): The number of channels of conv layers. + """ + + def __init__( + self, + channels: int, + hidden_channels: int, + ) -> None: + super().__init__() + + self.hidden_channels = hidden_channels + + self.in_proj = nn.Linear(channels, hidden_channels * 3, bias=True) + + # balancer that goes before the sigmoid. Have quite a large min_abs value, at 2.0, + # because we noticed that well-trained instances of this module have abs-value before the sigmoid + # starting from about 3, and poorly-trained instances of the module have smaller abs values + # before the sigmoid. + self.balancer = Balancer( + hidden_channels, channel_dim=-1, + min_positive=ScheduledFloat((0.0, 0.25), (20000.0, 0.05)), + max_positive=ScheduledFloat((0.0, 0.75), (20000.0, 0.95)), + min_abs=0.5, + max_abs=5.0, + ) + self.tanh = nn.Tanh() + + self.identity1 = Identity() # for diagnostics. + self.identity2 = Identity() # for diagnostics. + + + # ensure the activations after multiplication don't get too large. 
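+        # (hidden_penalty is an AbsValuePenalizer from scaling.py; judging by its
+        # arguments it appears to add a small training-time penalty, with
+        # probability `prob`, to activations whose absolute value exceeds `limit`.)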
+ self.hidden_penalty = AbsValuePenalizer( + limit=40.0, penalty=1.0e-04, prob=0.1) + + self.out_proj = ScaledLinear(hidden_channels, channels, + bias=True, + initial_scale=0.05) + + + + self.whiten1 = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(5.0), + prob=(0.025, 0.25), + grad_scale=0.01) + + self.whiten2 = Whiten(num_groups=1, + whitening_limit=_whitening_schedule(5.0, ratio=3.0), + prob=(0.025, 0.25), + grad_scale=0.01) + + + def forward(self, + x: Tensor, + attn_weights: Tensor, + ) -> Tensor: + """. + Args: + x: a Tensor of shape (seq_len, batch_size, num_channels) +attn_weights: a Tensor of shape (num_heads, batch_size, seq_len, seq_len) + Returns: + a Tensor with the same shape as x + """ + num_channels = x.shape[-1] + x = self.in_proj(x) + + (seq_len, batch_size, _) = x.shape + hidden_channels = self.hidden_channels + + s, x, y = x.chunk(3, dim=-1) + + # s will go through tanh. + + s = self.balancer(s) + s = self.tanh(s) + + s = s.unsqueeze(-1).reshape(seq_len, batch_size, hidden_channels) + x = self.whiten1(x) + x = x * s + x = self.identity1(x) # diagnostics only, it's the identity. + + (seq_len, batch_size, embed_dim) = x.shape + num_heads = attn_weights.shape[0] + assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len) + + x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = torch.matmul(attn_weights, x) + # now x: (num_heads, batch_size, seq_len, head_dim) + x = x.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1) + + + y = self.identity2(y) + x = x * y + x = self.hidden_penalty(x) + + x = self.out_proj(x) + x = self.whiten2(x) + return x + + + +class ScalarMultiply(nn.Module): + def __init__(self, scale: float): + super().__init__() + self.scale = scale + + def forward(self, x): + return x * self.scale + + +def _test_zipformer_main(causal: bool = False): + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + memory_dim = 100 + + c = Subformer( + structure = "S(S)S" if causal else "S(S(S", + encoder_dim=(64, 96, 64), + num_heads=(4, 4, 8), + causal=causal, + memory_dim=memory_dim, + ) + batch_size = 5 + seq_len = 128 + # Just make sure the forward pass runs. + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + memory=torch.randn(101, batch_size, memory_dim), + ) + f[0].sum().backward() + c.eval() + f = c( + torch.randn(seq_len, batch_size, 64), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + logging.getLogger().setLevel(logging.INFO) + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + _test_zipformer_main(False) + _test_zipformer_main(True) diff --git a/egs/librispeech/ASR/subformer/subsampling.py b/egs/librispeech/ASR/subformer/subsampling.py new file mode 120000 index 000000000..d178adc2e --- /dev/null +++ b/egs/librispeech/ASR/subformer/subsampling.py @@ -0,0 +1 @@ +../zipformer/subsampling.py \ No newline at end of file diff --git a/egs/librispeech/ASR/subformer/train.py b/egs/librispeech/ASR/subformer/train.py new file mode 100755 index 000000000..a81b39cf5 --- /dev/null +++ b/egs/librispeech/ASR/subformer/train.py @@ -0,0 +1,1395 @@ +#!/usr/bin/env python3 +# Copyright 2021-2023 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo, +# Zengwei Yao, +# Daniel Povey) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./zipformer/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir zipformer/exp \ + --full-libri 1 \ + --max-duration 550 + +""" + + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from zipformer import Zipformer2 +from subformer import Subformer +from mixformer import Mixformer +from scaling import ScheduledFloat +from decoder import Decoder +from joiner import Joiner +from subsampling import Conv2dSubsampling +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from model import Transducer +from optim import Eden, ScaledAdam +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.hooks import register_inf_check_hooks +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + MetricsTracker, + setup_logger, + str2bool, + get_parameter_groups_with_lrs +) + +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] + + +def get_adjusted_batch_count( + params: AttributeDict) -> float: + # returns the number of batches we would have used so far if we had used the reference + # duration. This is for purposes of set_batch_count(). 
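+    # E.g. with --max-duration 300, --world-size 4 and the default
+    # --ref-duration 600, this returns batch_idx_train * 2, so schedules keyed
+    # on batch count advance twice as fast as the raw batch index.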
+ return (params.batch_idx_train * (params.max_duration * params.world_size) / + params.ref_duration) + + +def set_batch_count( + model: Union[nn.Module, DDP], batch_count: float +) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + for name, module in model.named_modules(): + if hasattr(module, 'batch_count'): + module.batch_count = batch_count + if hasattr(module, 'name'): + module.name = name + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--encoder-structure", + type=str, + default="ZS(S(S(S)S)S)SZ", + help="Structure of encoder, determines order of encoder stacks and (downsampling/upsampling) " + "operations." + ) + + parser.add_argument( + "--num-encoder-layers", + type=str, + default="2,2,4,4,8,4,4,2,2", + help="Number of zipformer encoder layers per stack, comma separated.", + ) + + parser.add_argument( + "--downsampling-factor", + type=str, + default="1,2,2,2,1", + help="Downsampling factor for each stack of encoder layers.", + ) + + parser.add_argument( + "--feedforward-dim", + type=str, + default="512,1024,1536,2048,3072,2048,1536,1024,512", + help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.", + ) + + parser.add_argument( + "--num-heads", + type=str, + default="4,4,4,8,16,8,4,4,4", + help="Number of attention heads in the zipformer encoder layers: a single int or comma-separated list.", + ) + + parser.add_argument( + "--encoder-dim", + type=str, + default="192,256,384,512,768,512,384,256,192", + help="Embedding dimension in encoder stacks: a single int or comma-separated list." + ) + + parser.add_argument( + "--encoder-chunk-sizes", + type=str, + default="128,1024", + help="Base chunk size for attention in encoder stacks; alternate layers will use this value or " + "double this value." + ) + + + parser.add_argument( + "--query-head-dim", + type=str, + default="32", + help="Query/key dimension per head in encoder stacks: a single int or comma-separated list." + ) + + parser.add_argument( + "--value-head-dim", + type=str, + default="16", + help="Value dimension per head in encoder stacks: a single int or comma-separated list." + ) + + parser.add_argument( + "--pos-head-dim", + type=str, + default="4", + help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list." + ) + + parser.add_argument( + "--pos-dim", + type=str, + default="48", + help="Positional-encoding embedding dimension" + ) + + parser.add_argument( + "--encoder-unmasked-dim", + type=str, + default="192,192", + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "A single int or comma-separated list. Must be <= each corresponding encoder_dim." + ) + + parser.add_argument( + "--cnn-module-kernel", + type=str, + default="31,31", + help="Sizes of convolutional kernels in convolution modules in each encoder stack: " + "a single int or comma-separated list.", + ) + + parser.add_argument( + "--decoder-dim", + type=int, + default=512, + help="Embedding dimension in the decoder model.", + ) + + parser.add_argument( + "--joiner-dim", + type=int, + default=512, + help="""Dimension used in the joiner model. + Outputs from the encoder and decoder model are projected + to this dimension before adding. 
+ """, + ) + + parser.add_argument( + "--causal", + type=str2bool, + default=False, + help="If True, use causal version of model.", + ) + + parser.add_argument( + "--chunk-size", + type=str, + default="16,32,64,-1", + help="Chunk sizes (at 50Hz frame rate) will be chosen randomly from this list during training. " + " Must be just -1 if --causal=False" + ) + + parser.add_argument( + "--left-context-frames", + type=str, + default="64,128,256,-1", + help="Maximum left-contexts for causal training, measured in frames which will " + "be converted to a number of chunks. If splitting into chunks, " + "chunk left-context frames will be chosen randomly from this list; else not relevant." + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=30, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="zipformer/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--base-lr", + type=float, + default=0.045, + help="The base learning rate." + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=7500, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=3.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--warmup-start", + type=float, + default=0.5, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--warmup-batches", + type=int, + default=500, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--ref-duration", + type=float, + default=600, + help="Reference batch duration for purposes of adjusting batch counts for setting various " + "schedules inside the model" + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 
1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=30, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=200, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. 
+
+    - best_train_epoch:  It is the epoch that has the best training loss.
+
+    - best_valid_epoch:  It is the epoch that has the best validation loss.
+
+    - batch_idx_train:  Used to write statistics to tensorboard. It
+                        contains the number of batches trained so far across
+                        epochs.
+
+    - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+    - valid_interval:  Run validation if batch_idx % valid_interval is 0
+
+    - feature_dim: The model input dim. It has to match the one used
+                   in computing features.
+
+    - subsampling_factor:  The subsampling factor for the model.
+
+    - encoder_dim: Hidden dim for multi-head attention model.
+
+    - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+    - warm_step: The warmup period that dictates the decay of the
+              scale on "simple" (un-pruned) loss.
+    """
+    params = AttributeDict(
+        {
+            "best_train_loss": float("inf"),
+            "best_valid_loss": float("inf"),
+            "best_train_epoch": -1,
+            "best_valid_epoch": -1,
+            "batch_idx_train": 0,
+            "log_interval": 50,
+            "reset_interval": 200,
+            "valid_interval": 3000,  # For the 100h subset, use 800
+            # parameters for zipformer
+            "feature_dim": 80,
+            "subsampling_factor": 4,  # not passed in, this is fixed.
+            "warm_step": 2000,
+            "env_info": get_env_info(),
+        }
+    )
+
+    return params
+
+
+def _to_int_tuple(s: str):
+    return tuple(map(int, s.split(',')))
+
+
+def get_encoder_embed(params: AttributeDict) -> nn.Module:
+    # encoder_embed converts the input of shape (N, T, num_features)
+    # to the shape (N, (T - 7) // 2, encoder_dims).
+    # That is, it does two things simultaneously:
+    #   (1) subsampling: T -> (T - 7) // 2
+    #   (2) embedding: num_features -> encoder_dims
+    # In the normal configuration, we will downsample once more at the end
+    # by a factor of 2, and most of the encoder stacks will run at a lower
+    # sampling rate.
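+    # Illustrative (hypothetical) numbers: a 10 s utterance at 100 frames/s has
+    # T = 1000 input frames, so encoder_embed outputs (1000 - 7) // 2 = 496
+    # frames; the extra factor-of-2 downsampling mentioned above brings this to
+    # roughly 248 frames, matching the overall subsampling_factor of 4.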
+ encoder_embed = Conv2dSubsampling( + in_channels=params.feature_dim, + out_channels=_to_int_tuple(params.encoder_dim)[0], + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)) + ) + return encoder_embed + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = Mixformer( + structure=params.encoder_structure, + downsampling_factor=_to_int_tuple(params.downsampling_factor), + encoder_dim=_to_int_tuple(params.encoder_dim), + num_encoder_layers=_to_int_tuple(params.num_encoder_layers), + encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim), + encoder_chunk_sizes=(_to_int_tuple(params.encoder_chunk_sizes),), + query_head_dim=_to_int_tuple(params.query_head_dim), + value_head_dim=_to_int_tuple(params.value_head_dim), + num_heads=_to_int_tuple(params.num_heads), + pos_head_dim=_to_int_tuple(params.pos_head_dim), + pos_dim=_to_int_tuple(params.pos_dim), + feedforward_dim=_to_int_tuple(params.feedforward_dim), + cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel), + dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), + warmup_batches=4000.0, + causal=params.causal, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=max(_to_int_tuple(params.encoder_dim)), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder_embed = get_encoder_embed(params) + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder_embed=encoder_embed, + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=int(max(params.encoder_dim.split(','))), + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" 
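+    # load_checkpoint (from icefall.checkpoint) restores the model/optimizer/
+    # scheduler state dicts in place and returns the remaining saved entries
+    # as a dict; the bookkeeping keys are copied back into `params` below.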
+
+    saved_params = load_checkpoint(
+        filename,
+        model=model,
+        model_avg=model_avg,
+        optimizer=optimizer,
+        scheduler=scheduler,
+    )
+
+    keys = [
+        "best_train_epoch",
+        "best_valid_epoch",
+        "batch_idx_train",
+        "best_train_loss",
+        "best_valid_loss",
+    ]
+    for k in keys:
+        params[k] = saved_params[k]
+
+    if params.start_batch > 0:
+        if "cur_epoch" in saved_params:
+            params["start_epoch"] = saved_params["cur_epoch"]
+
+        if "cur_batch_idx" in saved_params:
+            params["cur_batch_idx"] = saved_params["cur_batch_idx"]
+
+    return saved_params
+
+
+def save_checkpoint(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    model_avg: Optional[nn.Module] = None,
+    optimizer: Optional[torch.optim.Optimizer] = None,
+    scheduler: Optional[LRSchedulerType] = None,
+    sampler: Optional[CutSampler] = None,
+    scaler: Optional[GradScaler] = None,
+    rank: int = 0,
+) -> None:
+    """Save model, optimizer, scheduler and training stats to file.
+
+    Args:
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The training model.
+      model_avg:
+        The stored model averaged from the start of training.
+      optimizer:
+        The optimizer used in the training.
+      scheduler:
+        The learning rate scheduler we are using.
+      sampler:
+        The sampler for the training dataset.
+      scaler:
+        The scaler used for mixed precision training.
+    """
+    if rank != 0:
+        return
+    filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+    save_checkpoint_impl(
+        filename=filename,
+        model=model,
+        model_avg=model_avg,
+        params=params,
+        optimizer=optimizer,
+        scheduler=scheduler,
+        sampler=sampler,
+        scaler=scaler,
+        rank=rank,
+    )
+
+    if params.best_train_epoch == params.cur_epoch:
+        best_train_filename = params.exp_dir / "best-train-loss.pt"
+        copyfile(src=filename, dst=best_train_filename)
+
+    if params.best_valid_epoch == params.cur_epoch:
+        best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+        copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+    params: AttributeDict,
+    model: Union[nn.Module, DDP],
+    sp: spm.SentencePieceProcessor,
+    batch: dict,
+    is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+    """
+    Compute the pruned RNN-T (transducer) loss given the model and its inputs.
+
+    Args:
+      params:
+        Parameters for training. See :func:`get_params`.
+      model:
+        The model for training. It is an instance of Transducer in our case.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        True for training. False for validation. When it is True, this
+        function enables autograd during computation; when it is False, it
+        disables autograd.
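+
+    Returns:
+      A tuple ``(loss, MetricsTracker)``.  During the first
+      ``params.warm_step`` batches the returned loss is a warm-up blend,
+      roughly
+
+          loss = simple_scale * simple_loss + pruned_scale * pruned_loss,
+
+      where ``simple_scale`` decays linearly from 1.0 to
+      ``params.simple_loss_scale`` and ``pruned_scale`` grows linearly from
+      0.1 to 1.0 (see the code below).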
+ """ + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + batch_idx_train = params.batch_idx_train + warm_step = params.warm_step + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + ) + + s = params.simple_loss_scale + # take down the scale on the simple loss from 1.0 at the start + # to params.simple_loss scale by warm_step. + simple_loss_scale = ( + s if batch_idx_train >= warm_step + else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) + ) + pruned_loss_scale = ( + 1.0 if batch_idx_train >= warm_step + else 0.1 + 0.9 * (batch_idx_train / warm_step) + ) + + loss = ( + simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. 
+ tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + cur_batch_idx = params.get("cur_batch_idx", 0) + + saved_bad_model = False + + def save_bad_model(suffix: str = ""): + save_checkpoint_impl(filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=0) + + for batch_idx, batch in enumerate(train_dl): + if batch_idx % 10 == 0: + set_batch_count(model, get_adjusted_batch_count(params)) + if batch_idx < cur_batch_idx: + continue + cur_batch_idx = batch_idx + + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + save_bad_model() + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 5: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + params.cur_batch_idx = batch_idx + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + del params.cur_batch_idx + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.use_fp16: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. 
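+            # Note on PyTorch internals (an assumption about the private API):
+            # `scaler._scale` is a one-element tensor holding the current grad
+            # scale, and `scaler.update(new_scale)` below overrides the
+            # dynamically managed scale with an explicit value, which is why
+            # the doubling is done by hand here.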
+ cur_grad_scale = scaler._scale.item() + + if cur_grad_scale < 8.0 or (cur_grad_scale < 32.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + if cur_grad_scale < 0.01: + if not saved_bad_model: + save_bad_model(suffix="-first-warning") + saved_bad_model = True + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + save_bad_model() + raise RuntimeError(f"grad_scale is too small, exiting: {cur_grad_scale}") + + if batch_idx % params.log_interval == 0: + cur_lr = max(scheduler.get_last_lr()) + cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}, " + + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + if params.use_fp16: + tb_writer.add_scalar( + "train/grad_scale", cur_grad_scale, params.batch_idx_train + ) + + if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. 
args:
+        The return value of get_parser().parse_args()
+    """
+    params = get_params()
+    params.update(vars(args))
+
+    fix_random_seed(params.seed)
+    if world_size > 1:
+        setup_dist(rank, world_size, params.master_port)
+
+    setup_logger(f"{params.exp_dir}/log/log-train")
+    logging.info("Training started")
+
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    logging.info(f"Device: {device}")
+
+    sp = spm.SentencePieceProcessor()
+    sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.vocab_size = sp.get_piece_size()
+
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_transducer_model(params)
+
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
+    assert params.save_every_n >= params.average_period
+    model_avg: Optional[nn.Module] = None
+    if rank == 0:
+        # model_avg is only used with rank 0
+        model_avg = copy.deepcopy(model).to(torch.float64)
+
+    assert params.start_epoch > 0, params.start_epoch
+    checkpoints = load_checkpoint_if_available(
+        params=params, model=model, model_avg=model_avg
+    )
+
+    model.to(device)
+    if world_size > 1:
+        logging.info("Using DDP")
+        model = DDP(model, device_ids=[rank],
+                    find_unused_parameters=True)
+
+    optimizer = ScaledAdam(
+        get_parameter_groups_with_lrs(
+            model, lr=params.base_lr, include_names=True),
+        lr=params.base_lr,   # should have no effect
+        clipping_scale=2.0,
+    )
+
+    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs,
+                     warmup_batches=params.warmup_batches,
+                     warmup_start=params.warmup_start)
+
+    if checkpoints and "optimizer" in checkpoints:
+        logging.info("Loading optimizer state dict")
+        optimizer.load_state_dict(checkpoints["optimizer"])
+
+    if (
+        checkpoints
+        and "scheduler" in checkpoints
+        and checkpoints["scheduler"] is not None
+    ):
+        logging.info("Loading scheduler state dict")
+        scheduler.load_state_dict(checkpoints["scheduler"])
+
+    if params.print_diagnostics:
+        opts = diagnostics.TensorDiagnosticOptions(
+            2 ** 22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+    if params.inf_check:
+        register_inf_check_hooks(model)
+
+    librispeech = LibriSpeechAsrDataModule(args)
+
+    train_cuts = librispeech.train_clean_100_cuts()
+    if params.full_libri:
+        train_cuts += librispeech.train_clean_360_cuts()
+        train_cuts += librispeech.train_other_500_cuts()
+
+    def remove_short_and_long_utt(c: Cut):
+        # Keep only utterances with duration between 1 second and 20 seconds
+        #
+        # Caution: There is a reason to select 20.0 here. Please see
+        # ../local/display_manifest_statistics.py
+        #
+        # You should use ../local/display_manifest_statistics.py to get
+        # an utterance duration distribution for your dataset to select
+        # the threshold
+        if c.duration < 1.0 or c.duration > 20.0:
+            logging.warning(
+                f"Exclude cut with ID {c.id} from training. 
Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./zipformer.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 7) // 2 + 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + pass + #scan_pessimistic_batches_for_oom( + # model=model, + # train_dl=train_dl, + # optimizer=optimizer, + # sp=sp, + # params=params, + #) + + scaler = GradScaler(enabled=params.use_fp16, + init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, + sp: spm.SentencePieceProcessor, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + sp: + The BPE model. 
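+
+    The batch is written with ``torch.save``; for offline debugging it can be
+    reloaded later with ``torch.load(filename)`` (illustrative usage, not part
+    of this script).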
+ """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + supervisions = batch["supervisions"] + features = batch["inputs"] + + logging.info(f"features shape: {features.shape}") + + y = sp.encode(supervisions["text"], out_type=int) + num_tokens = sum(len(i) for i in y) + logging.info(f"num tokens: {num_tokens}") + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params, sp=sp) + raise + logging.info(f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB") + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/subformer/zipformer.py b/egs/librispeech/ASR/subformer/zipformer.py new file mode 120000 index 000000000..a064749a4 --- /dev/null +++ b/egs/librispeech/ASR/subformer/zipformer.py @@ -0,0 +1 @@ +../zipformer/zipformer.py \ No newline at end of file diff --git a/egs/librispeech/ASR/zipformer/scaling.py b/egs/librispeech/ASR/zipformer/scaling.py index 908b60938..c5aa8c470 100644 --- a/egs/librispeech/ASR/zipformer/scaling.py +++ b/egs/librispeech/ASR/zipformer/scaling.py @@ -274,6 +274,24 @@ def softmax(x: Tensor, dim: int): return SoftmaxFunction.apply(x, dim) +class ClipGradFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: Tensor, + limit: float): + ctx.limit = limit + return x + + @staticmethod + def backward(ctx, x_grad, *args): + return x_grad.clamp(-ctx.limit, ctx.limit), None + + +def clip_grad(x: Tensor, limit: float): + return ClipGradFunction.apply(x, limit) + + class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( @@ -875,6 +893,40 @@ def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float, return x +class AbsValuePenalizer(nn.Module): + """ + This module adds a penalty to the loss function when ever the absolute value of + any element of the input tensor exceeds a certain limit. 
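+
+    The penalty is applied stochastically: only in training mode, only when the
+    input requires grad, and only with probability ``prob`` on a given forward
+    pass; otherwise the input is returned unchanged (see ``forward`` below).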
+ """ + def __init__(self, + limit: float, + prob: float = 0.1, + penalty: float = 1.0e-04): + super().__init__() + self.limit = limit + self.penalty = penalty + + self.prob = prob + self.name = None # will be set in training loop + + # 20% of the time we will return and do nothing because memory usage is + # too high. + self.mem_cutoff = CutoffEstimator(0.2) + + def forward(self, x: Tensor) -> Tensor: + if (torch.jit.is_scripting() or not x.requires_grad + or not self.training + or random.random() > self.prob): + # or (x.is_cuda and self.mem_cutoff(torch.cuda.memory_allocated())) + return _no_op(x) # the _no_op op is to make our diagnostics code work. + + x = penalize_abs_values_gt(x, + limit=self.limit, + penalty=self.penalty, + name=self.name) + return x + + def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. if x.ndim == 2: return x.diag()