From 42daafee4e2634c841fa3a1778e053b1cd3f09fb Mon Sep 17 00:00:00 2001 From: Desh Raj Date: Sun, 11 Jun 2023 16:32:29 -0400 Subject: [PATCH] clean commit for SURT recipe --- .../asr_datamodule.py | 372 --- .../beam_search.py | 885 ----- .../decode.py | 770 ----- .../decode_libricss.py | 791 ----- .../decode_stream.py | 151 - .../decoder.py | 102 - .../dprnn.py | 304 -- .../encoder_interface.py | 43 - .../export.py | 320 -- .../joiner.py | 65 - .../model.py | 304 -- .../optim.py | 1061 ------ .../pretrained.py | 355 -- .../scaling.py | 1533 --------- .../scaling_converter.py | 114 - .../streaming_beam_search.py | 282 -- .../streaming_decode.py | 615 ---- .../test_model.py | 150 - .../train.py | 1346 -------- .../zipformer.py | 2881 ----------------- .../SURT/local/compute_fbank_libricss.py | 6 + .../SURT/local/compute_fbank_librimix.py | 115 - .../SURT/local/compute_fbank_librispeech.py | 12 +- .../SURT/local/compute_fbank_lsmix.py | 188 ++ egs/libricss/SURT/prepare.sh | 148 +- 25 files changed, 267 insertions(+), 12646 deletions(-) delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/asr_datamodule.py delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/beam_search.py delete mode 100755 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode.py delete mode 100755 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_libricss.py delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_stream.py delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decoder.py delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/dprnn.py delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/encoder_interface.py delete mode 100755 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/export.py delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/joiner.py delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/model.py delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/optim.py delete mode 100755 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/pretrained.py delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/scaling.py delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/scaling_converter.py delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/streaming_beam_search.py delete mode 100755 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/streaming_decode.py delete mode 100755 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/test_model.py delete mode 100755 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/train.py delete mode 100644 egs/libricss/SURT/dprnn_pruned_transducer_stateless7/zipformer.py delete mode 100755 egs/libricss/SURT/local/compute_fbank_librimix.py create mode 100755 egs/libricss/SURT/local/compute_fbank_lsmix.py diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/asr_datamodule.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/asr_datamodule.py deleted file mode 100644 index f6f56cc6f..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/asr_datamodule.py +++ /dev/null @@ -1,372 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Callable, Dict, List, Optional - -import torch -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures - CutMix, - DynamicBucketingSampler, - K2SurtDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class LibrimixAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/manifests"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--max-cuts", - type=int, - default=100, - help="Maximum number of cuts in a single batch. You can " - "reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SurtDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - quadratic_duration=30.0, - max_cuts=self.args.max_cuts, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SingleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - max_cuts=self.args.max_cuts, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - - logging.info("About to create dev dataset") - validate = K2SurtDataset( - input_strategy=OnTheFlyFeatures( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - ) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - max_cuts=self.args.max_cuts, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.debug("About to create test dataset") - test = K2SurtDataset( - input_strategy=OnTheFlyFeatures( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - ) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - max_cuts=self.args.max_cuts, - shuffle=False, - ) - logging.debug("About to create test dataloader") - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self, reverberated: bool = False) -> CutSet: - logging.info("About to get train cuts") - rvb_affix = "_rvb" if reverberated else "_norvb" - cs = load_manifest_lazy( - self.args.manifest_dir / f"cuts_train{rvb_affix}_v1.jsonl.gz" - ) - # Trim to supervision groups - cs = cs.trim_to_supervision_groups(max_pause=1.0) - cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0) - return cs - - @lru_cache() - def dev_cuts(self, reverberated: bool = False) -> CutSet: - logging.info("About to get dev cuts") - rvb_affix = "_rvb" if reverberated else "_norvb" - cs = load_manifest_lazy( - self.args.manifest_dir / f"cuts_dev{rvb_affix}_v1.jsonl.gz" - ) - cs = cs.filter(lambda c: c.duration >= 0.1) - return cs - - @lru_cache() - def train_cuts_2spk(self, reverberated: bool = False) -> CutSet: - logging.info("About to get 2-spk train cuts") - rvb_affix = "_rvb" if reverberated else "_norvb" - cs = load_manifest_lazy( - self.args.manifest_dir / f"cuts_train_2spk{rvb_affix}.jsonl.gz" - ) - cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0) - return cs - - @lru_cache() - def libricss_cuts(self, split="dev", type="sdm") -> CutSet: - logging.info(f"About to get LibriCSS {split} {type} cuts") - cs = load_manifest_lazy( - self.args.manifest_dir / f"cuts_{split}_libricss-{type}.jsonl.gz" - ) - return cs diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/beam_search.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/beam_search.py deleted file mode 100644 index 5db5db996..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/beam_search.py +++ /dev/null @@ -1,885 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Union - -import k2 -import sentencepiece as spm -import torch -from model import SURT - -from icefall import NgramLm, NgramLmStateCost -from icefall.decode import Nbest, one_best_decoding -from icefall.lm_wrapper import LmScorer -from icefall.utils import ( - DecodingResults, - add_eos, - add_sos, - get_texts, - get_texts_with_timestamp, -) - - -def fast_beam_search_one_best( - model: SURT, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - the shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `SURT`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - best_path = one_best_decoding(lattice) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search( - model: SURT, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, -) -> k2.Fsa: - """It limits the maximum number of symbols per frame to 1. - - Args: - model: - An instance of `SURT`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - Returns: - Return an FsaVec with axes [utt][state][arc] containing the decoded - lattice. Note: When the input graph is a TrivialGraph, the returned - lattice is actually an acceptor. - """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - B, T, C = encoder_out.shape - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(k2.RnntDecodingStream(decoding_graph)) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = (logits / temperature).log_softmax(dim=-1) - decoding_streams.advance(log_probs) - decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output(encoder_out_lens.tolist()) - - return lattice - - -def greedy_search( - model: SURT, - encoder_out: torch.Tensor, - max_sym_per_frame: int, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """Greedy search for a single utterance. - Args: - model: - An instance of `SURT`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 4 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - unk_id = getattr(model, "unk_id", blank_id) - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # timestamp[i] is the frame index after subsampling - # on which hyp[i] is decoded - timestamp = [] - - # Maximum symbols per utterance. - max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits is (1, 1, 1, vocab_size) - - y = logits.argmax().item() - if y not in (blank_id, unk_id): - hyp.append(y) - timestamp.append(t) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - if not return_timestamps: - return hyp - else: - return DecodingResults( - hyps=[hyp], - timestamps=[timestamp], - ) - - -def greedy_search_batch( - model: SURT, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The SURT model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = next(model.parameters()).device - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] - - # timestamp[n][i] is the frame index after subsampling - # on which hyp[n][i] is decoded - timestamps = [[] for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out: (N, 1, decoder_out_dim) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits'shape (batch_size, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v not in (blank_id, unk_id): - hyps[i].append(v) - timestamps[i].append(t) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) - - -def modified_beam_search( - model: SURT, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The SURT model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[blank_id] * context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for (t, batch_size) in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) - - -def beam_search( - model: SURT, - encoder_out: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_SURT.py#L247 is used as a reference. - - Args: - model: - An instance of `SURT`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[])) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False, - ) - - # TODO(fangjun): Scale the blank posterior - log_prob = (logits / temperature).log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob - - # ys[:] returns a copy of ys - B.add( - Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - timestamp=y_star.timestamp[:], - ) - ) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i in (blank_id, unk_id): - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - new_timestamp = y_star.timestamp + [t] - A.add( - Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - ) - ) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - - # timestamp[i] is the frame index after subsampling - # on which ys[i] is decoded - timestamp: List[int] = field(default_factory=list) - - # the lm score for next token given the current ys - lm_score: Optional[torch.Tensor] = None - - # the RNNLM states (h and c in LSTM) - state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - # N-gram LM state - state_cost: Optional[NgramLmStateCost] = None - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": - """Return the top-k hypothesis.""" - hyps = list(self._data.items()) - - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: - """Return a ragged shape with axes [utt][num_hyps]. - - Args: - hyps: - len(hyps) == batch_size. It contains the current hypothesis for - each utterance in the batch. - Returns: - Return a ragged shape with 2 axes [utt][num_hyps]. Note that - the shape is on CPU. - """ - num_hyps = [len(h) for h in hyps] - - # torch.cumsum() is inclusive sum, so we put a 0 at the beginning - # to get exclusive sum later. - num_hyps.insert(0, 0) - - num_hyps = torch.tensor(num_hyps) - row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) - ans = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=row_splits[-1].item() - ) - return ans diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode.py deleted file mode 100755 index a2974b9a8..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode.py +++ /dev/null @@ -1,770 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./conv_lstm_transducer_stateless_ctc/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./conv_lstm_transducer_stateless_ctc/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./conv_lstm_transducer_stateless_ctc/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./conv_lstm_transducer_stateless_ctc/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./conv_lstm_transducer_stateless_ctc/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./conv_lstm_transducer_stateless_ctc/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./conv_lstm_transducer_stateless_ctc/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./conv_lstm_transducer_stateless_ctc/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -from collections import defaultdict -from itertools import chain, groupby, repeat -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibrimixAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.utils import EPSILON -from train import add_model_arguments, get_params, get_surt_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_surt_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conv_lstm_transducer_stateless_ctc/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--save-masks", - type=str2bool, - default=False, - help="""If true, save masks generated by unmixing module.""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - feature_lens = batch["input_lens"].to(device) - - # Apply the mask encoder - B, T, F = feature.shape - processed = model.mask_encoder(feature) # B,T,F*num_channels - masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1) - x_masked = [feature * m for m in masks] - - # To save the masks, we split them by batch and trim each mask to the length of - # the corresponding feature. We save them in a dict, where the key is the - # cut ID and the value is the mask. - masks_dict = {} - for i in range(B): - mask = torch.cat( - [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)], - dim=-1, - ) - mask = mask.cpu().numpy() - masks_dict[batch["cuts"][i].id] = mask - - # Recognition - # Stack the inputs along the batch axis - h = torch.cat(x_masked, dim=0) - h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0) - encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens) - - hyps = [] - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp).split()) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps}, masks_dict - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps}, masks_dict - else: - return {f"beam_size_{params.beam_size}": hyps}, masks_dict - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - masks = {} - for batch_idx, batch in enumerate(dl): - # The dataloader returns text as a list of cuts, each of which is a list of channel - # text. We flatten this to a list where all channels are together, i.e., it looks like - # [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ...., uttN,ch2]. - texts = [val for tup in zip(*batch["text"]) for val in tup] - cut_ids = [cut.id for cut in batch["cuts"]] - - # Repeat cut_ids list N times, where N is the number of channels. - cut_ids = list(chain.from_iterable(repeat(cut_ids, params.num_channels))) - - hyps_dict, masks_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - masks.update(masks_dict) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts), f"{len(hyps)} vs {len(texts)}" - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results, masks - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - # Combine results by cut_id. This means that we combine different channels for - # ref and hyp of the same cut into list. Example: - # (cut1, ref1, hyp1), (cut1, ref2, hyp2), (cut2, ref3, hyp3) -> - # (cut1, [ref1, ref2], [hyp1, hyp2]), (cut2, [ref3], [hyp3]) - # Also, each ref and hyp is currently a list of words. We join them into a string. - results_grouped = [] - for cut_id, items in groupby(results, lambda x: x[0]): - items = list(items) - refs = [" ".join(item[1]) for item in items] - hyps = [" ".join(item[2]) for item in items] - results_grouped.append((cut_id, refs, hyps)) - - store_transcripts(filename=recog_path, texts=results_grouped) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_surt_error_stats( - f, f"{test_set_name}-{key}", results_grouped, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -def save_masks( - params: AttributeDict, - test_set_name: str, - masks: List[torch.Tensor], -): - masks_path = params.res_dir / f"masks-{test_set_name}.txt" - torch.save(masks, masks_path) - logging.info(f"The masks are stored in {masks_path}") - - -@torch.no_grad() -def main(): - parser = get_parser() - LibrimixAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_surt_model(params) - assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( - model.encoder.decode_chunk_size, - params.decode_chunk_len, - ) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - librimix = LibrimixAsrDataModule(args) - - dev_cuts = librimix.dev_cuts(reverberated=False) - dev_dl = librimix.test_dataloaders(dev_cuts) - - test_sets = ["dev"] - test_dl = [dev_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict, masks = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - if params.save_masks: - save_masks( - params=params, - test_set_name=test_set, - masks=masks, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_libricss.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_libricss.py deleted file mode 100755 index 17a7ac8a6..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_libricss.py +++ /dev/null @@ -1,791 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./conv_lstm_transducer_stateless_ctc/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./conv_lstm_transducer_stateless_ctc/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./conv_lstm_transducer_stateless_ctc/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./conv_lstm_transducer_stateless_ctc/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./conv_lstm_transducer_stateless_ctc/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./conv_lstm_transducer_stateless_ctc/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search (one best) -./conv_lstm_transducer_stateless_ctc/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./conv_lstm_transducer_stateless_ctc/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -from collections import defaultdict -from itertools import chain, groupby, repeat -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibrimixAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from lhotse.utils import EPSILON -from train import add_model_arguments, get_params, get_surt_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_surt_error_stats, -) - -OVERLAP_RATIOS = ["0L", "0S", "OV10", "OV20", "OV30", "OV40"] - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conv_lstm_transducer_stateless_ctc/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_bpe_500", - help="The lang dir containing word table and LG graph", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--save-masks", - type=str2bool, - default=False, - help="""If true, save masks generated by unmixing module.""", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - feature_lens = batch["input_lens"].to(device) - - # Apply the mask encoder - B, T, F = feature.shape - processed = model.mask_encoder(feature) # B,T,F*num_channels - masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1) - x_masked = [feature * m for m in masks] - - # Recognition - # Stack the inputs along the batch axis - h = torch.cat(x_masked, dim=0) - h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0) - encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens) - - def _group_channels(hyps: List[str]) -> List[List[str]]: - """ - Currently we have a batch of size M*B, where M is the number of - channels and B is the batch size. We need to group the hypotheses - into B groups, each of which contains M hypotheses. - - Example: - hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2'] - _group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']] - """ - assert len(hyps) == B * params.num_channels - out_hyps = [] - for i in range(B): - out_hyps.append(hyps[i::B]) - return out_hyps - - hyps = [] - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.decode(hyp)) - - if params.decoding_method == "greedy_search": - return {"greedy_search": _group_channels(hyps)} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: _group_channels(hyps)} - else: - return {f"beam_size_{params.beam_size}": _group_channels(hyps)} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - cut_ids = [cut.id for cut in batch["cuts"]] - cuts_batch = batch["cuts"] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - for cut_id, hyp_words in zip(cut_ids, hyps): - # Reference is a list of supervision texts sorted by start time. - ref_words = [ - s.text.strip() - for s in sorted( - cuts_batch[cut_id].supervisions, key=lambda s: s.start - ) - ] - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(cut_ids) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_surt_error_stats( - f, - f"{test_set_name}-{key}", - results, - enable_log=True, - num_channels=params.num_channels, - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -def save_masks( - params: AttributeDict, - test_set_name: str, - masks: List[torch.Tensor], -): - masks_path = params.res_dir / f"masks-{test_set_name}.txt" - torch.save(masks, masks_path) - logging.info(f"The masks are stored in {masks_path}") - - -@torch.no_grad() -def main(): - parser = get_parser() - LibrimixAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and are defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_surt_model(params) - assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( - model.encoder.decode_chunk_size, - params.decode_chunk_len, - ) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - word_table = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - librimix = LibrimixAsrDataModule(args) - - dev_cuts = librimix.libricss_cuts(split="dev", type="ihm-mix").to_eager() - dev_cuts_grouped = [dev_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS] - test_cuts = librimix.libricss_cuts(split="test", type="ihm-mix").to_eager() - test_cuts_grouped = [ - test_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS - ] - - for dev_set, ol in zip(dev_cuts_grouped, OVERLAP_RATIOS): - dev_dl = librimix.test_dataloaders(dev_set) - results_dict = decode_dataset( - dl=dev_dl, - params=params, - model=model, - sp=sp, - word_table=word_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=f"dev_{ol}", - results_dict=results_dict, - ) - - # if params.save_masks: - # save_masks( - # params=params, - # test_set_name=f"dev_{ol}", - # masks=masks, - # ) - - # for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS): - # test_dl = librimix.test_dataloaders(test_set) - # results_dict = decode_dataset( - # dl=test_dl, - # params=params, - # model=model, - # sp=sp, - # word_table=word_table, - # decoding_graph=decoding_graph, - # ) - - # save_results( - # params=params, - # test_set_name=f"test_{ol}", - # results_dict=results_dict, - # ) - - # if params.save_masks: - # save_masks( - # params=params, - # test_set_name=f"test_{ol}", - # masks=masks, - # ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_stream.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_stream.py deleted file mode 100644 index 0d7e86fcf..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decode_stream.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math -from typing import List, Optional, Tuple - -import k2 -import torch -from beam_search import Hypothesis, HypothesisList - -from icefall.utils import AttributeDict - - -class DecodeStream(object): - def __init__( - self, - params: AttributeDict, - cut_id: str, - initial_states: List[torch.Tensor], - decoding_graph: Optional[k2.Fsa] = None, - device: torch.device = torch.device("cpu"), - ) -> None: - """ - Args: - initial_states: - Initial decode states of the model, e.g. the return value of - `get_init_state` in conformer.py - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a HLG. - Used only when decoding_method is fast_beam_search. - device: - The device to run this stream. - """ - if params.decoding_method == "fast_beam_search": - assert decoding_graph is not None - assert device == decoding_graph.device - - self.params = params - self.cut_id = cut_id - self.LOG_EPS = math.log(1e-10) - - self.states = initial_states - - # It contains a 2-D tensors representing the feature frames. - self.features: torch.Tensor = None - - self.num_frames: int = 0 - # how many frames have been processed. (before subsampling). - # we only modify this value in `func:get_feature_frames`. - self.num_processed_frames: int = 0 - - self._done: bool = False - - # The transcript of current utterance. - self.ground_truth: str = "" - - # The decoding result (partial or final) of current utterance. - self.hyp: List = [] - - # how many frames have been processed, after subsampling (i.e. a - # cumulative sum of the second return value of - # encoder.streaming_forward - self.done_frames: int = 0 - - # It has two steps of feature subsampling in zipformer: out_lens=((x_lens-7)//2+1)//2 - # 1) feature embedding: out_lens=(x_lens-7)//2 - # 2) output subsampling: out_lens=(out_lens+1)//2 - self.pad_length = 7 - - if params.decoding_method == "greedy_search": - self.hyp = [params.blank_id] * params.context_size - elif params.decoding_method == "modified_beam_search": - self.hyps = HypothesisList() - self.hyps.add( - Hypothesis( - ys=[params.blank_id] * params.context_size, - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - ) - ) - elif params.decoding_method == "fast_beam_search": - # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - @property - def done(self) -> bool: - """Return True if all the features are processed.""" - return self._done - - @property - def id(self) -> str: - return self.cut_id - - def set_features( - self, - features: torch.Tensor, - tail_pad_len: int = 0, - ) -> None: - """Set features tensor of current utterance.""" - assert features.dim() == 2, features.dim() - self.features = torch.nn.functional.pad( - features, - (0, 0, 0, self.pad_length + tail_pad_len), - mode="constant", - value=self.LOG_EPS, - ) - self.num_frames = self.features.size(0) - - def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: - """Consume chunk_size frames of features""" - chunk_length = chunk_size + self.pad_length - - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) - - ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa - ] - - self.num_processed_frames += chunk_size - if self.num_processed_frames >= self.num_frames: - self._done = True - - return ret_features, ret_length - - def decoding_result(self) -> List[int]: - """Obtain current decoding result.""" - if self.params.decoding_method == "greedy_search": - return self.hyp[self.params.context_size :] # noqa - elif self.params.decoding_method == "modified_beam_search": - best_hyp = self.hyps.get_most_probable(length_norm=True) - return best_hyp.ys[self.params.context_size :] # noqa - else: - assert self.params.decoding_method == "fast_beam_search" - return self.hyp diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decoder.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decoder.py deleted file mode 100644 index 5f90e6375..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/decoder.py +++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - decoder_dim: int, - blank_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - decoder_dim: - Dimension of the input embedding, and of the decoder output. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. - """ - super().__init__() - - self.embedding = nn.Embedding( - num_embeddings=vocab_size, - embedding_dim=decoder_dim, - padding_idx=blank_id, - ) - self.blank_id = blank_id - - assert context_size >= 1, context_size - self.context_size = context_size - self.vocab_size = vocab_size - if context_size > 1: - self.conv = nn.Conv1d( - in_channels=decoder_dim, - out_channels=decoder_dim, - kernel_size=context_size, - padding=0, - groups=decoder_dim // 4, # group size == 4 - bias=False, - ) - - def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, decoder_dim). - """ - y = y.to(torch.int64) - # this stuff about clamp() is a temporary fix for a mismatch - # at utterance start, we use negative ids in beam_search.py - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = F.relu(embedding_out) - return embedding_out diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/dprnn.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/dprnn.py deleted file mode 100644 index 361b1b385..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/dprnn.py +++ /dev/null @@ -1,304 +0,0 @@ -import random -from typing import Optional, Tuple - -import torch -import torch.nn as nn -from einops import rearrange -from scaling import ActivationBalancer, BasicNorm, DoubleSwish, ScaledLinear, ScaledLSTM -from torch.autograd import Variable - -EPS = torch.finfo(torch.get_default_dtype()).eps - - -def _pad_segment(input, segment_size): - # Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L342 - # input is the features: (B, N, T) - batch_size, dim, seq_len = input.shape - segment_stride = segment_size // 2 - - rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size - if rest > 0: - pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type()) - input = torch.cat([input, pad], 2) - - pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(input.type()) - input = torch.cat([pad_aux, input, pad_aux], 2) - - return input, rest - - -def split_feature(input, segment_size): - # Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L358 - # split the feature into chunks of segment size - # input is the features: (B, N, T) - - input, rest = _pad_segment(input, segment_size) - batch_size, dim, seq_len = input.shape - segment_stride = segment_size // 2 - - segments1 = ( - input[:, :, :-segment_stride] - .contiguous() - .view(batch_size, dim, -1, segment_size) - ) - segments2 = ( - input[:, :, segment_stride:] - .contiguous() - .view(batch_size, dim, -1, segment_size) - ) - segments = ( - torch.cat([segments1, segments2], 3) - .view(batch_size, dim, -1, segment_size) - .transpose(2, 3) - ) - - return segments.contiguous(), rest - - -def merge_feature(input, rest): - # Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L385 - # merge the splitted features into full utterance - # input is the features: (B, N, L, K) - - batch_size, dim, segment_size, _ = input.shape - segment_stride = segment_size // 2 - input = ( - input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size * 2) - ) # B, N, K, L - - input1 = ( - input[:, :, :, :segment_size] - .contiguous() - .view(batch_size, dim, -1)[:, :, segment_stride:] - ) - input2 = ( - input[:, :, :, segment_size:] - .contiguous() - .view(batch_size, dim, -1)[:, :, :-segment_stride] - ) - - output = input1 + input2 - if rest > 0: - output = output[:, :, :-rest] - - return output.contiguous() # B, N, T - - -class RNNEncoderLayer(nn.Module): - """ - RNNEncoderLayer is made up of lstm and feedforward networks. - Args: - input_size: - The number of expected features in the input (required). - hidden_size: - The hidden dimension of rnn layer. - dropout: - The dropout value (default=0.1). - layer_dropout: - The dropout value for model-level warmup (default=0.075). - """ - - def __init__( - self, - input_size: int, - hidden_size: int, - dropout: float = 0.1, - bidirectional: bool = False, - ) -> None: - super(RNNEncoderLayer, self).__init__() - self.input_size = input_size - self.hidden_size = hidden_size - - assert hidden_size >= input_size, (hidden_size, input_size) - self.lstm = ScaledLSTM( - input_size=input_size, - hidden_size=hidden_size // 2 if bidirectional else hidden_size, - proj_size=0, - num_layers=1, - dropout=0.0, - batch_first=True, - bidirectional=bidirectional, - ) - self.norm_final = BasicNorm(input_size) - - # try to ensure the output is close to zero-mean (or at least, zero-median). # noqa - self.balancer = ActivationBalancer( - num_channels=input_size, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - max_abs=6.0, - ) - self.dropout = nn.Dropout(dropout) - - def forward( - self, - src: torch.Tensor, - states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Pass the input through the encoder layer. - Args: - src: - The sequence to the encoder layer (required). - Its shape is (S, N, E), where S is the sequence length, - N is the batch size, and E is the feature number. - states: - A tuple of 2 tensors (optional). It is for streaming inference. - states[0] is the hidden states of all layers, - with shape of (1, N, input_size); - states[1] is the cell states of all layers, - with shape of (1, N, hidden_size). - """ - src_orig = src - - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean - # completely bypass it. - alpha = warmup if self.training else 1.0 - - # lstm module - src_lstm, new_states = self.lstm(src, states) - src = self.dropout(src_lstm) + src - src = self.norm_final(self.balancer(src)) - - if alpha != 1.0: - src = alpha * src + (1 - alpha) * src_orig - - return src - - -# dual-path RNN -class DPRNN(nn.Module): - """Deep dual-path RNN. - Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py - - args: - input_size: int, dimension of the input feature. The input should have shape - (batch, seq_len, input_size). - hidden_size: int, dimension of the hidden state. - output_size: int, dimension of the output size. - dropout: float, dropout ratio. Default is 0. - num_blocks: int, number of stacked RNN layers. Default is 1. - """ - - def __init__( - self, - feature_dim, - input_size, - hidden_size, - output_size, - dropout=0.1, - num_blocks=1, - segment_size=50, - chunk_width_randomization=False, - ): - super().__init__() - - self.input_size = input_size - self.output_size = output_size - self.hidden_size = hidden_size - - self.segment_size = segment_size - self.chunk_width_randomization = chunk_width_randomization - - self.input_embed = nn.Sequential( - ScaledLinear(feature_dim, input_size), - BasicNorm(input_size), - ActivationBalancer( - num_channels=input_size, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - ), - ) - - # dual-path RNN - self.row_rnn = nn.ModuleList([]) - self.col_rnn = nn.ModuleList([]) - for _ in range(num_blocks): - # intra-RNN is non-causal - self.row_rnn.append( - RNNEncoderLayer( - input_size, hidden_size, dropout=dropout, bidirectional=True - ) - ) - self.col_rnn.append( - RNNEncoderLayer( - input_size, hidden_size, dropout=dropout, bidirectional=False - ) - ) - - # output layer - self.out_embed = nn.Sequential( - ScaledLinear(input_size, output_size), - BasicNorm(output_size), - ActivationBalancer( - num_channels=output_size, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - ), - ) - - def forward(self, input): - # input shape: B, T, F - input = self.input_embed(input) - B, T, D = input.shape - - if self.chunk_width_randomization and self.training: - segment_size = random.randint(self.segment_size // 2, self.segment_size) - else: - segment_size = self.segment_size - input, rest = split_feature(input.transpose(1, 2), segment_size) - # input shape: batch, N, dim1, dim2 - # apply RNN on dim1 first and then dim2 - # output shape: B, output_size, dim1, dim2 - # input = input.to(device) - batch_size, _, dim1, dim2 = input.shape - output = input - for i in range(len(self.row_rnn)): - row_input = ( - output.permute(0, 3, 2, 1) - .contiguous() - .view(batch_size * dim2, dim1, -1) - ) # B*dim2, dim1, N - output = self.row_rnn[i](row_input) # B*dim2, dim1, H - output = ( - output.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous() - ) # B, N, dim1, dim2 - - col_input = ( - output.permute(0, 2, 3, 1) - .contiguous() - .view(batch_size * dim1, dim2, -1) - ) # B*dim1, dim2, N - output = self.col_rnn[i](col_input) # B*dim1, dim2, H - output = ( - output.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous() - ) # B, N, dim1, dim2 - - output = merge_feature(output, rest) - output = output.transpose(1, 2) - output = self.out_embed(output) - - # Apply ReLU to the output - output = torch.relu(output) - - return output - - -if __name__ == "__main__": - - model = DPRNN( - 80, - 256, - 256, - 160, - dropout=0.1, - num_blocks=3, - segment_size=20, - chunk_width_randomization=True, - ) - input = torch.randn(2, 1002, 80) - print(model(input).shape) diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/encoder_interface.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/encoder_interface.py deleted file mode 100644 index 257facce4..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/encoder_interface.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple - -import torch -import torch.nn as nn - - -class EncoderInterface(nn.Module): - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A tensor of shape (batch_size, input_seq_len, num_features) - containing the input features. - x_lens: - A tensor of shape (batch_size,) containing the number of frames - in `x` before padding. - Returns: - Return a tuple containing two tensors: - - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) - containing unnormalized probabilities, i.e., the output of a - linear layer. - - encoder_out_lens, a tensor of shape (batch_size,) containing - the number of frames in `encoder_out` before padding. - """ - raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/export.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/export.py deleted file mode 100755 index 5c06cc052..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/export.py +++ /dev/null @@ -1,320 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" - -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 30 \ - --avg 9 \ - --jit 1 - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Check -https://github.com/k2-fsa/sherpa -for how to use the exported models outside of icefall. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `pruned_transducer_stateless7_streaming/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/librispeech/ASR - ./pruned_transducer_stateless7_streaming/decode.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 600 \ - --decoding-method greedy_search \ - --bpe-model data/lang_bpe_500/bpe.model - -Check ./pretrained.py for its usage. - -Note: If you don't want to train a model from scratch, we have -provided one for you. You can get it at - -https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 - -with the following commands: - - sudo apt-get install git-lfs - git lfs install - git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 - # You will find the pre-trained model in icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp -""" - -import argparse -import logging -from pathlib import Path - -import sentencepiece as spm -import torch -import torch.nn as nn -from scaling_converter import convert_scaled_to_non_scaled -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/joiner.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/joiner.py deleted file mode 100644 index 3ddac2cf2..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/joiner.py +++ /dev/null @@ -1,65 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn - - -class Joiner(nn.Module): - def __init__( - self, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - super().__init__() - - self.encoder_proj = nn.Linear(encoder_dim, joiner_dim) - self.decoder_proj = nn.Linear(decoder_dim, joiner_dim) - self.output_linear = nn.Linear(joiner_dim, vocab_size) - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - project_input: bool = True, - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - project_input: - If true, apply input projections encoder_proj and decoder_proj. - If this is false, it is the user's responsibility to do this - manually. - Returns: - Return a tensor of shape (N, T, s_range, C). - """ - assert encoder_out.ndim == decoder_out.ndim - assert encoder_out.ndim in (2, 4) - assert encoder_out.shape[:-1] == decoder_out.shape[:-1] - - if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) - else: - logit = encoder_out + decoder_out - - logit = self.output_linear(torch.tanh(logit)) - - return logit diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/model.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/model.py deleted file mode 100644 index b5766a809..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/model.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Tuple - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface - -from icefall.utils import add_sos - - -class SURT(nn.Module): - """It implements Streaming Unmixing and Recognition Transducer (SURT). - https://arxiv.org/abs/2011.13148 - """ - - def __init__( - self, - mask_encoder: nn.Module, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - num_channels: int, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - mask_encoder: - It is the masking network. It generates a mask for each channel of the - encoder. These masks are applied to the input features, and then passed - to the transcription network. - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output contains - unnormalized probs, i.e., not processed by log-softmax. - num_channels: - It is the number of channels that the input features will be split into. - In general, it should be equal to the maximum number of simultaneously - active speakers. For most real scenarios, using 2 channels is sufficient. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.mask_encoder = mask_encoder - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - self.num_channels = num_channels - - self.simple_am_proj = nn.Linear( - encoder_dim, - vocab_size, - ) - self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) - - self.ctc_output = nn.Sequential( - nn.Dropout(p=0.1), - nn.Linear(encoder_dim, vocab_size), - nn.LogSoftmax(dim=-1), - ) - - def forward_helper( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - reduction: str = "sum", - beam_size: int = 10, - use_double_scores: bool = False, - subsampling_factor: int = 1, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Compute transducer loss for one branch of the SURT model. - """ - encoder_out, x_lens = self.encoder(x, x_lens) - assert torch.all(x_lens > 0) - - # compute ctc log-probs - ctc_output = self.ctc_output(encoder_out) - - # For the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction=reduction, - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - reduction=reduction, - ) - - # Compute ctc loss - supervision_segments = torch.stack( - ( - torch.arange(len(x_lens), device="cpu"), - torch.zeros_like(x_lens, device="cpu"), - torch.clone(x_lens).detach().cpu(), - ), - dim=1, - ).to(torch.int32) - # We need to sort supervision_segments in decreasing order of num_frames - indices = torch.argsort(supervision_segments[:, 2], descending=True) - supervision_segments = supervision_segments[indices] - - # Works with a BPE model - decoding_graph = k2.ctc_graph(y, modified=False, device=x.device) - dense_fsa_vec = k2.DenseFsaVec( - ctc_output, - supervision_segments, - allow_truncate=subsampling_factor - 1, - ) - ctc_loss = k2.ctc_loss( - decoding_graph=decoding_graph, - dense_fsa_vec=dense_fsa_vec, - output_beam=beam_size, - reduction="none", - use_double_scores=use_double_scores, - ) - - return (simple_loss, pruned_loss, ctc_loss) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - reduction: str = "sum", - beam_size: int = 10, - use_double_scores: bool = False, - subsampling_factor: int = 1, - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor of shape (N*num_channels, S). It contains the labels - of the N utterances. The labels are in the range [0, vocab_size). All - the channels are concatenated together one after another. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - reduction: - "sum" to sum the losses over all utterances in the batch. - "none" to return the loss in a 1-D tensor for each utterance - in the batch. - beam_size: - The beam size used in CTC decoding. - use_double_scores: - If True, use double precision for CTC decoding. - subsampling_factor: - The subsampling factor of the model. It is used to compute the - supervision segments for CTC loss. - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0), (x.size(), x_lens.size()) - - # Apply the mask encoder - B, T, F = x.shape - processed = self.mask_encoder(x) # B,T,F*num_channels - masks = processed.view(B, T, F, self.num_channels).unbind(dim=-1) - x_masked = [x * m for m in masks] - - # Recognition - # Stack the inputs along the batch axis - h = torch.cat(x_masked, dim=0) - h_lens = torch.cat([x_lens for _ in range(self.num_channels)], dim=0) - - simple_loss, pruned_loss, ctc_loss = self.forward_helper( - h, - h_lens, - y, - prune_range, - am_scale, - lm_scale, - reduction=reduction, - beam_size=beam_size, - use_double_scores=use_double_scores, - subsampling_factor=subsampling_factor, - ) - - # Chunks the outputs into 2 parts along batch axis and then stack them along a new axis. - simple_loss = torch.stack( - torch.chunk(simple_loss, self.num_channels, dim=0), dim=0 - ) - pruned_loss = torch.stack( - torch.chunk(pruned_loss, self.num_channels, dim=0), dim=0 - ) - ctc_loss = torch.stack(torch.chunk(ctc_loss, self.num_channels, dim=0), dim=0) - - return (simple_loss, pruned_loss, ctc_loss) diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/optim.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/optim.py deleted file mode 100644 index 374b78cb3..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/optim.py +++ /dev/null @@ -1,1061 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import contextlib -import logging -import random -from collections import defaultdict -from typing import List, Optional, Tuple, Union - -import torch -from lhotse.utils import fix_random_seed -from scaling import ActivationBalancer -from torch import Tensor -from torch.optim import Optimizer - - -class BatchedOptimizer(Optimizer): - """ - This class adds to class Optimizer the capability to optimize parameters in batches: - it will stack the parameters and their grads for you so the optimizer can work - on tensors with an extra leading dimension. This is intended for speed with GPUs, - as it reduces the number of kernels launched in the optimizer. - - Args: - params: - """ - - def __init__(self, params, defaults): - super(BatchedOptimizer, self).__init__(params, defaults) - - @contextlib.contextmanager - def batched_params(self, param_group, group_params_names): - """ - This function returns (technically, yields) a list of - of tuples (p, state), where - p is a `fake` parameter that is stacked (over axis 0) from real parameters - that share the same shape, and its gradient is also stacked; - `state` is the state corresponding to this batch of parameters - (it will be physically located in the "state" for one of the real - parameters, the last one that has any particular shape and dtype). - - This function is decorated as a context manager so that it can - write parameters back to their "real" locations. - - The idea is, instead of doing: - - for p in group["params"]: - state = self.state[p] - ... - - you can do: - - with self.batched_params(group["params"]) as batches: - for p, state, p_names in batches: - ... - - - Args: - group: a parameter group, which is a list of parameters; should be - one of self.param_groups. - group_params_names: name for each parameter in group, - which is List[str]. - """ - batches = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter - batches_names = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str - - assert len(param_group) == len(group_params_names) - for p, named_p in zip(param_group, group_params_names): - key = (str(p.dtype), *p.shape) - batches[key].append(p) - batches_names[key].append(named_p) - - batches_names_keys = list(batches_names.keys()) - sorted_idx = sorted( - range(len(batches_names)), key=lambda i: batches_names_keys[i] - ) - batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx] - batches = [batches[batches_names_keys[idx]] for idx in sorted_idx] - - stacked_params_dict = dict() - - # turn batches into a list, in deterministic order. - # tuples will contain tuples of (stacked_param, state, stacked_params_names), - # one for each batch in `batches`. - tuples = [] - - for batch, batch_names in zip(batches, batches_names): - p = batch[0] - # we arbitrarily store the state in the - # state corresponding to the 1st parameter in the - # group. class Optimizer will take care of saving/loading state. - state = self.state[p] - p_stacked = torch.stack(batch) - grad = torch.stack( - [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] - ) - p_stacked.grad = grad - stacked_params_dict[key] = p_stacked - tuples.append((p_stacked, state, batch_names)) - - yield tuples # <-- calling code will do the actual optimization here! - - for ((stacked_params, _state, _names), batch) in zip(tuples, batches): - for i, p in enumerate(batch): # batch is list of Parameter - p.copy_(stacked_params[i]) - - -class ScaledAdam(BatchedOptimizer): - """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) - - - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period - """ - - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, - parameters_names=None, - show_dominant_parameters=True, - ): - - assert parameters_names is not None, ( - "Please prepare parameters_names," - "which is a List[List[str]]. Each List[str] is for a group" - "and each str is for a parameter" - ) - defaults = dict( - lr=lr, - clipping_scale=clipping_scale, - betas=betas, - scalar_lr_scale=scalar_lr_scale, - eps=eps, - param_min_rms=param_min_rms, - param_max_rms=param_max_rms, - scalar_max=scalar_max, - size_update_period=size_update_period, - clipping_update_period=clipping_update_period, - ) - - super(ScaledAdam, self).__init__(params, defaults) - assert len(self.param_groups) == len(parameters_names) - self.parameters_names = parameters_names - self.show_dominant_parameters = show_dominant_parameters - - def __setstate__(self, state): - super(ScaledAdam, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - batch = True - - for group, group_params_names in zip(self.param_groups, self.parameters_names): - - with self.batched_params(group["params"], group_params_names) as batches: - - # batches is list of pairs (stacked_param, state). stacked_param is like - # a regular parameter, and will have a .grad, but the 1st dim corresponds to - # a stacking dim, it is not a real dim. - - if ( - len(batches[0][1]) == 0 - ): # if len(first state) == 0: not yet initialized - clipping_scale = 1 - else: - clipping_scale = self._get_clipping_scale(group, batches) - - for p, state, _ in batches: - # Perform optimization step. - # grad is not going to be None, we handled that when creating the batches. - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - # State initialization - if len(state) == 0: - self._init_state(group, p, state) - - self._step_one_batch(group, p, state, clipping_scale) - - return loss - - def _init_state(self, group: dict, p: Tensor, state: dict): - """ - Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p - is actually the batch dimension, corresponding to batched-together - parameters of a given shape. - - - Args: - group: Dict to look up configuration values. - p: The parameter that we are initializing the state for - state: Dict from string to whatever state we are initializing - """ - size_update_period = group["size_update_period"] - - state["step"] = 0 - - kwargs = {"device": p.device, "dtype": p.dtype} - - # 'delta' implements conventional momentum. There are - # several different kinds of update going on, so rather than - # compute "exp_avg" like in Adam, we store and decay a - # parameter-change "delta", which combines all forms of - # update. this is equivalent to how it's done in Adam, - # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - batch_size = p.shape[0] - numel = p.numel() // batch_size - numel = p.numel() - - if numel > 1: - # "param_rms" just periodically records the scalar root-mean-square value of - # the parameter tensor. - # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - state["param_rms"] = param_rms - - state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) - - # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) - - def _get_clipping_scale( - self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]] - ) -> float: - """ - Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients - by this amount before applying the rest of the update. - - Args: - group: the parameter group, an item in self.param_groups - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - """ - assert len(tuples) >= 1 - clipping_scale = group["clipping_scale"] - (first_p, first_state, _) = tuples[0] - step = first_state["step"] - if clipping_scale is None or step == 0: - # no clipping. return early on step == 0 because the other - # parameters' state won't have been initialized yet. - return 1.0 - clipping_update_period = group["clipping_update_period"] - - tot_sumsq = torch.tensor(0.0, device=first_p.device) - for (p, state, param_names) in tuples: - grad = p.grad - if grad.is_sparse: - raise RuntimeError( - "ScaledAdam optimizer does not support sparse gradients" - ) - if p.numel() == p.shape[0]: # a batch of scalars - tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] - else: - tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() - - tot_norm = tot_sumsq.sqrt() - if "model_norms" not in first_state: - first_state["model_norms"] = torch.zeros( - clipping_update_period, device=p.device - ) - first_state["model_norms"][step % clipping_update_period] = tot_norm - - if step % clipping_update_period == 0: - # Print some stats. - # We don't reach here if step == 0 because we would have returned - # above. - sorted_norms = first_state["model_norms"].sort()[0].to("cpu") - quartiles = [] - for n in range(0, 5): - index = min( - clipping_update_period - 1, (clipping_update_period // 4) * n - ) - quartiles.append(sorted_norms[index].item()) - - median = quartiles[2] - threshold = clipping_scale * median - first_state["model_norm_threshold"] = threshold - percent_clipped = ( - first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state - else 0.0 - ) - first_state["num_clipped"] = 0 - quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.info( - f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" - ) - - if step < clipping_update_period: - return 1.0 # We have not yet estimated a norm to clip to. - else: - try: - model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - logging.info( - "Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?" - ) - return 1.0 - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) - if ans < 1.0: - first_state["num_clipped"] += 1 - if ans < 0.1: - logging.warn( - f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}" - ) - if self.show_dominant_parameters: - assert p.shape[0] == len(param_names) - self._show_gradient_dominating_parameter(tuples, tot_sumsq) - return ans - - def _show_gradient_dominating_parameter( - self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor - ): - """ - Show information of parameter wihch dominanting tot_sumsq. - - Args: - tuples: a list of tuples of (param, state, param_names) - where param is a batched set of parameters, - with a .grad (1st dim is batch dim) - and state is the state-dict where optimization parameters are kept. - param_names is a List[str] while each str is name for a parameter - in batched set of parameters "param". - tot_sumsq: sumsq of all parameters. Though it's could be calculated - from tuples, we still pass it to save some time. - """ - all_sumsq_orig = {} - for (p, state, batch_param_names) in tuples: - # p is a stacked batch parameters. - batch_grad = p.grad - if p.numel() == p.shape[0]: # a batch of scalars - batch_sumsq_orig = batch_grad**2 - # Dummpy values used by following `zip` statement. - batch_rms_orig = torch.ones(p.shape[0]) - else: - batch_rms_orig = state["param_rms"] - batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum( - dim=list(range(1, batch_grad.ndim)) - ) - - for name, sumsq_orig, rms, grad in zip( - batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad - ): - - proportion_orig = sumsq_orig / tot_sumsq - all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad) - - assert torch.isclose( - sum([value[0] for value in all_sumsq_orig.values()]).cpu(), - torch.tensor(1.0), - ) - sorted_by_proportion = { - k: v - for k, v in sorted( - all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True - ) - } - dominant_param_name = next(iter(sorted_by_proportion)) - ( - dominant_proportion, - dominant_sumsq, - dominant_rms, - dominant_grad, - ) = sorted_by_proportion[dominant_param_name] - logging.info( - f"Parameter Dominanting tot_sumsq {dominant_param_name}" - f" with proportion {dominant_proportion:.2f}," - f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)" - f"={dominant_sumsq:.3e}," - f" grad_sumsq = {(dominant_grad**2).sum():.3e}," - f" orig_rms_sq={(dominant_rms**2).item():.3e}" - ) - - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): - """ - Do the step for one parameter, which is actually going to be a batch of - `real` parameters, with dim 0 as the batch dim. - Args: - group: dict to look up configuration values - p: parameter to update (actually multiple parameters stacked together - as a batch) - state: state-dict for p, to look up the optimizer state - """ - lr = group["lr"] - size_update_period = group["size_update_period"] - beta1 = group["betas"][0] - - grad = p.grad - if clipping_scale != 1.0: - grad = grad * clipping_scale - step = state["step"] - delta = state["delta"] - - delta.mul_(beta1) - batch_size = p.shape[0] - numel = p.numel() // batch_size - if numel > 1: - # Update the size/scale of p, and set param_rms - scale_grads = state["scale_grads"] - scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) - if step % size_update_period == size_update_period - 1: - param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) - if step > 0: - # self._size_update() learns the overall scale on the - # parameter, by shrinking or expanding it. - self._size_update(group, scale_grads, p, state) - - if numel == 1: - # For parameters with 1 element we just use regular Adam. - # Updates delta. - self._step_scalar(group, p, state) - else: - self._step(group, p, state) - - state["step"] = step + 1 - - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p - """ - - param_rms = state["param_rms"] - beta1, beta2 = group["betas"] - size_lr = group["lr"] * group["scalar_lr_scale"] - param_min_rms = group["param_min_rms"] - param_max_rms = group["param_max_rms"] - eps = group["eps"] - step = state["step"] - batch_size = p.shape[0] - - size_update_period = scale_grads.shape[0] - # correct beta2 for the size update period: we will have - # faster decay at this level. - beta2_corr = beta2**size_update_period - - scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) - scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) - - # The 1st time we reach here is when size_step == 1. - size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step - # we don't bother with bias_correction1; this will help prevent divergence - # at the start of training. - - denom = scale_exp_avg_sq.sqrt() + eps - - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) - - is_too_small = param_rms < param_min_rms - is_too_large = param_rms > param_max_rms - - # when the param gets too small, just don't shrink it any further. - scale_step.masked_fill_(is_too_small, 0.0) - # when it gets too large, stop it from getting any larger. - scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) - delta = state["delta"] - # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) - - def _step(self, group: dict, p: Tensor, state: dict): - """ - This function does the core update of self.step(), in the case where the members of - the batch have more than 1 element. - - Args: - group: A dict which will be used to look up configuration values - p: The parameter to be updated - grad: The grad of p - state: The state-dict corresponding to parameter p - - This function modifies p. - """ - grad = p.grad - lr = group["lr"] - beta1, beta2 = group["betas"] - eps = group["eps"] - param_min_rms = group["param_min_rms"] - step = state["step"] - - exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) - - this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) - bias_correction2 = 1 - beta2 ** (this_step + 1) - if bias_correction2 < 0.99: - # note: not in-place. - exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2) - - denom = exp_avg_sq.sqrt() - denom += eps - grad = grad / denom - - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) - - delta = state["delta"] - delta.add_(grad * alpha) - p.add_(delta) - - def _step_scalar(self, group: dict, p: Tensor, state: dict): - """ - A simplified form of the core update for scalar tensors, where we cannot get a good - estimate of the parameter rms. - """ - beta1, beta2 = group["betas"] - scalar_max = group["scalar_max"] - eps = group["eps"] - lr = group["lr"] * group["scalar_lr_scale"] - grad = p.grad - - exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - - # bias_correction2 is like in Adam. Don't bother with bias_correction1; - # slower update at the start will help stability anyway. - bias_correction2 = 1 - beta2 ** (state["step"] + 1) - denom = (exp_avg_sq / bias_correction2).sqrt() + eps - - delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) - p.clamp_(min=-scalar_max, max=scalar_max) - p.add_(delta) - - -class LRScheduler(object): - """ - Base-class for learning rate schedulers where the learning-rate depends on both the - batch and the epoch. - """ - - def __init__(self, optimizer: Optimizer, verbose: bool = False): - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) - self.optimizer = optimizer - self.verbose = verbose - - for group in optimizer.param_groups: - group.setdefault("base_lr", group["lr"]) - - self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] - - self.epoch = 0 - self.batch = 0 - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return { - "base_lrs": self.base_lrs, - "epoch": self.epoch, - "batch": self.batch, - } - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_lr(self) -> List[float]: - """Return last computed learning rate by current scheduler. Will be a list of float.""" - return self._last_lr - - def get_lr(self): - # Compute list of learning rates from self.epoch and self.batch and - # self.base_lrs; this must be overloaded by the user. - # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] - raise NotImplementedError - - def step_batch(self, batch: Optional[int] = None) -> None: - # Step the batch index, or just set it. If `batch` is specified, it - # must be the batch index from the start of training, i.e. summed over - # all epochs. - # You can call this in any order; if you don't provide 'batch', it should - # of course be called once per batch. - if batch is not None: - self.batch = batch - else: - self.batch = self.batch + 1 - self._set_lrs() - - def step_epoch(self, epoch: Optional[int] = None): - # Step the epoch index, or just set it. If you provide the 'epoch' arg, - # you should call this at the start of the epoch; if you don't provide the 'epoch' - # arg, you should call it at the end of the epoch. - if epoch is not None: - self.epoch = epoch - else: - self.epoch = self.epoch + 1 - self._set_lrs() - - def _set_lrs(self): - values = self.get_lr() - assert len(values) == len(self.optimizer.param_groups) - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group["lr"] = lr - self.print_lr(self.verbose, i, lr) - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] - - def print_lr(self, is_verbose, group, lr): - """Display the current learning rate.""" - if is_verbose: - logging.info( - f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" - f" of group {group} to {lr:.4e}." - ) - - -class Eden(LRScheduler): - """ - Eden scheduler. - The basic formula (before warmup) is: - lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * - (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup - where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches - and then stays constant at 1. - - - E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - lr_epochs: the number of epochs after which we start significantly - decreasing the learning rate, suggest 6 if you plan to do e.g. - 20 to 40 epochs, but may need smaller number if dataset is huge - and you will do few epochs. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - lr_epochs: Union[int, float], - warmup_batches: Union[int, float] = 500.0, - verbose: bool = False, - ): - super(Eden, self).__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.lr_epochs = lr_epochs - self.warmup_batches = warmup_batches - - def get_lr(self): - factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches) - ) - - return [x * factor * warmup_factor for x in self.base_lrs] - - -def _test_eden(): - m = torch.nn.Linear(100, 100) - optim = ScaledAdam(m.parameters(), lr=0.03) - - scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True) - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(20): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - - logging.info(f"last lr = {scheduler.get_last_lr()}") - logging.info(f"state dict = {scheduler.state_dict()}") - - -# This is included mostly as a baseline for ScaledAdam. -class Eve(Optimizer): - """ - Implements Eve algorithm. This is a modified version of AdamW with a special - way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular target_rms (default: 0.1). This is - for use with networks with 'scaled' versions of modules (see scaling.py), which - will be close to invariant to the absolute scale on the parameter matrix. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Eve is unpublished so far. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 3e-4; - this value means that the weight would decay significantly after - about 3k minibatches. Is not multiplied by learning rate, but - is conditional on RMS-value of parameter being > target_rms. - target_rms (float, optional): target root-mean-square value of - parameters, if they fall below this we will stop applying weight decay. - - - .. _Adam: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.98), - eps=1e-8, - weight_decay=1e-3, - target_rms=0.1, - ): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - if not 0 < target_rms <= 10.0: - raise ValueError("Invalid target_rms value: {}".format(target_rms)) - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - target_rms=target_rms, - ) - super(Eve, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Eve, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - - beta1, beta2 = group["betas"] - - state["step"] += 1 - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( - group["eps"] - ) - - step_size = group["lr"] / bias_correction1 - target_rms = group["target_rms"] - weight_decay = group["weight_decay"] - - if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" - # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) - p.mul_(1 - (weight_decay * is_above_target_rms)) - - p.addcdiv_(exp_avg, denom, value=-step_size) - - if random.random() < 0.0005: - step = (exp_avg / denom) * step_size - logging.info( - f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" - ) - - return loss - - -def _test_scaled_adam(hidden_dim: int): - import timeit - - from scaling import ScaledLinear - - E = 100 - B = 4 - T = 2 - logging.info("in test_eve_cain") - # device = torch.device('cuda') - device = torch.device("cpu") - dtype = torch.float32 - - fix_random_seed(42) - # these input_magnitudes and output_magnitudes are to test that - # Abel is working as we expect and is able to adjust scales of - # different dims differently. - input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp() - - for iter in [1, 0]: - fix_random_seed(42) - Linear = torch.nn.Linear if iter == 0 else ScaledLinear - - m = torch.nn.Sequential( - Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) - - train_pairs = [ - ( - 100.0 - * torch.randn(B, T, E, device=device, dtype=dtype) - * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, - ) - for _ in range(20) - ] - - if iter == 0: - optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) - scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) - - start = timeit.default_timer() - avg_loss = 0.0 - for epoch in range(180): - scheduler.step_epoch() - # if epoch == 100 and iter in [2,3]: - # optim.reset_speedup() # check it doesn't crash. - - # if epoch == 130: - # opts = diagnostics.TensorDiagnosticOptions( - # 2 ** 22 - # ) # allow 4 megabytes per sub-module - # diagnostic = diagnostics.attach_diagnostics(m, opts) - - for n, (x, y) in enumerate(train_pairs): - y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 - if epoch == 0 and n == 0: - avg_loss = loss.item() - else: - avg_loss = 0.98 * avg_loss + 0.02 * loss.item() - if n == 0 and epoch % 5 == 0: - # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) - lr = scheduler.get_last_lr()[0] - logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}" - ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} - loss.log().backward() - optim.step() - optim.zero_grad() - scheduler.step_batch() - - # diagnostic.print_diagnostics() - - stop = timeit.default_timer() - logging.info(f"Iter={iter}, Time taken: {stop - start}") - - logging.info(f"last lr = {scheduler.get_last_lr()}") - # logging.info("state dict = ", scheduler.state_dict()) - # logging.info("optim state_dict = ", optim.state_dict()) - logging.info(f"input_magnitudes = {input_magnitudes}") - logging.info(f"output_magnitudes = {output_magnitudes}") - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - logging.getLogger().setLevel(logging.INFO) - import subprocess - - s = subprocess.check_output( - "git status -uno .; git log -1; git diff HEAD .", shell=True - ) - logging.info(s) - import sys - - if len(sys.argv) > 1: - hidden_dim = int(sys.argv[1]) - else: - hidden_dim = 200 - - _test_scaled_adam(hidden_dim) - _test_eden() diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/pretrained.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/pretrained.py deleted file mode 100755 index fb77fdd42..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/pretrained.py +++ /dev/null @@ -1,355 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --bpe-model data/lang_bpe_500/bpe.model \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --bpe-model ./data/lang_bpe_500/bpe.model \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by -./pruned_transducer_stateless7_streaming/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import sentencepiece as spm -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--bpe-model", - type=str, - help="""Path to bpe.model.""", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(sp.decode(hyp).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/scaling.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/scaling.py deleted file mode 100644 index 835bf72ca..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/scaling.py +++ /dev/null @@ -1,1533 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import logging -import random -from typing import Optional, Tuple, Union - -import torch -import torch.backends.cudnn.rnn as rnn -import torch.nn as nn -import torch.nn.functional as F -from torch import _VF, Tensor - -from icefall.utils import is_jit_tracing - - -class ActivationBalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - xgt0 = x > 0 - if sign_factor is None: - ctx.save_for_backward(xgt0, scale_factor) - else: - ctx.save_for_backward(xgt0, scale_factor, sign_factor) - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: - if len(ctx.saved_tensors) == 3: - xgt0, scale_factor, sign_factor = ctx.saved_tensors - for _ in range(ctx.channel_dim, x_grad.ndim - 1): - scale_factor = scale_factor.unsqueeze(-1) - sign_factor = sign_factor.unsqueeze(-1) - factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) - else: - xgt0, scale_factor = ctx.saved_tensors - for _ in range(ctx.channel_dim, x_grad.ndim - 1): - scale_factor = scale_factor.unsqueeze(-1) - factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) - neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) - - -def _compute_scale_factor( - x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float, -) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - x_abs_mean = torch.mean(x.abs(), dim=sum_dims).to(torch.float32) - - if min_abs == 0.0: - below_threshold = 0.0 - else: - # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if - # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( - min=0, max=max_factor - ) - - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( - min=0, max=max_factor - ) - - return below_threshold - above_threshold - - -def _compute_sign_factor( - x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float, -) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) - if min_positive == 0.0: - factor1 = 0.0 - else: - # 0 if proportion_positive >= min_positive, else can be - # as large as max_factor. - factor1 = ( - (min_positive - proportion_positive) * (gain_factor / min_positive) - ).clamp_(min=0, max=max_factor) - - if max_positive == 1.0: - factor2 = 0.0 - else: - # 0 if self.proportion_positive <= max_positive, else can be - # as large as -max_factor. - factor2 = ( - (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) - ).clamp_(min=0, max=max_factor) - sign_factor = factor1 - factor2 - # require min_positive != 0 or max_positive != 1: - assert not isinstance(sign_factor, float) - return sign_factor - - -class ActivationScaleBalancerFunction(torch.autograd.Function): - """ - This object is used in class ActivationBalancer when the user specified - min_positive=0, max_positive=1, so there are no constraints on the signs - of the activations and only the absolute value has a constraint. - """ - - @staticmethod - def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, - ) -> Tensor: - if channel_dim < 0: - channel_dim += x.ndim - ctx.channel_dim = channel_dim - xgt0 = x > 0 - ctx.save_for_backward(xgt0, sign_factor, scale_factor) - return x - - @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: - xgt0, sign_factor, scale_factor = ctx.saved_tensors - for _ in range(ctx.channel_dim, x_grad.ndim - 1): - sign_factor = sign_factor.unsqueeze(-1) - scale_factor = scale_factor.unsqueeze(-1) - - factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) - neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) - - -class RandomClampFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float, - ) -> Tensor: - x_clamped = torch.clamp(x, min=min, max=max) - mask = torch.rand_like(x) < prob - ans = torch.where(mask, x_clamped, x) - if x.requires_grad: - ctx.save_for_backward(ans == x) - ctx.reflect = reflect - if reflect != 0.0: - ans = ans * (1.0 + reflect) - (x * reflect) - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - (is_same,) = ctx.saved_tensors - x_grad = ans_grad * is_same.to(ans_grad.dtype) - reflect = ctx.reflect - if reflect != 0.0: - x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) - return x_grad, None, None, None, None - - -def random_clamp( - x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0, -): - return RandomClampFunction.apply(x, min, max, prob, reflect) - - -def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: - """ - A randomized way of casting a floating point value to half precision. - """ - if x.dtype == torch.float16: - return x - x_abs = x.abs() - is_too_small = x_abs < min_abs - # for elements where is_too_small is true, random_val will contain +-min_abs with - # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, - # for those elements]. - random_val = min_abs * x.sign() * (torch.rand_like(x) * min_abs < x_abs) - return torch.where(is_too_small, random_val, x).to(torch.float16) - - -class RandomGradFunction(torch.autograd.Function): - """ - Does nothing in forward pass; in backward pass, gets rid of very small grads using - randomized approach that preserves expectations (intended to reduce roundoff). - """ - - @staticmethod - def forward(ctx, x: Tensor, min_abs: float) -> Tensor: - ctx.min_abs = min_abs - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: - if ans_grad.dtype == torch.float16: - return ( - random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), - None, - ) - else: - return ans_grad, None - - -class RandomGrad(torch.nn.Module): - """ - Gets rid of very small gradients using an expectation-preserving method, intended to increase - accuracy of training when using amp (automatic mixed precision) - """ - - def __init__(self, min_abs: float = 5.0e-06): - super(RandomGrad, self).__init__() - self.min_abs = min_abs - - def forward(self, x: Tensor): - if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing(): - return x - else: - return RandomGradFunction.apply(x, self.min_abs) - - -class SoftmaxFunction(torch.autograd.Function): - """ - Tries to handle half-precision derivatives in a randomized way that should - be more accurate for training than the default behavior. - """ - - @staticmethod - def forward(ctx, x: Tensor, dim: int): - ans = x.softmax(dim=dim) - # if x dtype is float16, x.softmax() returns a float32 because - # (presumably) that op does not support float16, and autocast - # is enabled. - if torch.is_autocast_enabled(): - ans = ans.to(torch.float16) - ctx.save_for_backward(ans) - ctx.x_dtype = x.dtype - ctx.dim = dim - return ans - - @staticmethod - def backward(ctx, ans_grad: Tensor): - (ans,) = ctx.saved_tensors - with torch.cuda.amp.autocast(enabled=False): - ans_grad = ans_grad.to(torch.float32) - ans = ans.to(torch.float32) - x_grad = ans_grad * ans - x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True) - return x_grad, None - - -def softmax(x: Tensor, dim: int): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x.softmax(dim) - - return SoftmaxFunction.apply(x, dim) - - -class MaxEigLimiterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float, - ) -> Tensor: - ctx.channel_dim = channel_dim - ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) - return x - - @staticmethod - def backward(ctx, x_grad, *args): - with torch.enable_grad(): - (x_orig, coeffs, new_direction) = ctx.saved_tensors - x_orig.requires_grad = True - num_channels = x_orig.shape[ctx.channel_dim] - x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) - new_direction.requires_grad = False - x = x - x.mean(dim=0) - x_var = (x**2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. This is to be minimized. - variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - variance_proportion.backward() - x_orig_grad = x_orig.grad - x_extra_grad = ( - x_orig.grad - * ctx.grad_scale - * x_grad.norm() - / (x_orig_grad.norm() + 1.0e-20) - ) - return x_grad + x_extra_grad.detach(), None, None, None, None - - -class GradientFilterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - batch_dim: int, # e.g., 1 - threshold: float, # e.g., 10.0 - *params: Tensor, # module parameters - ) -> Tuple[Tensor, ...]: - if x.requires_grad: - if batch_dim < 0: - batch_dim += x.ndim - ctx.batch_dim = batch_dim - ctx.threshold = threshold - return (x,) + params - - @staticmethod - def backward( - ctx, - x_grad: Tensor, - *param_grads: Tensor, - ) -> Tuple[Tensor, ...]: - eps = 1.0e-20 - dim = ctx.batch_dim - norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() - median_norm = norm_of_batch.median() - - cutoff = median_norm * ctx.threshold - inv_mask = (cutoff + norm_of_batch) / (cutoff + eps) - mask = 1.0 / (inv_mask + eps) - x_grad = x_grad * mask - - avg_mask = 1.0 / (inv_mask.mean() + eps) - param_grads = [avg_mask * g for g in param_grads] - - return (x_grad, None, None) + tuple(param_grads) - - -class GradientFilter(torch.nn.Module): - """This is used to filter out elements that have extremely large gradients - in batch and the module parameters with soft masks. - Args: - batch_dim (int): - The batch dimension. - threshold (float): - For each element in batch, its gradient will be - filtered out if the gradient norm is larger than - `grad_norm_threshold * median`, where `median` is the median - value of gradient norms of all elememts in batch. - """ - - def __init__(self, batch_dim: int = 1, threshold: float = 10.0): - super(GradientFilter, self).__init__() - self.batch_dim = batch_dim - self.threshold = threshold - - def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]: - if torch.jit.is_scripting() or is_jit_tracing(): - return (x,) + params - else: - return GradientFilterFunction.apply( - x, - self.batch_dim, - self.threshold, - *params, - ) - - -class BasicNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - So the idea is to introduce this large constant value as an explicit - parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. We make the "eps" learnable. - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - eps: the initial "epsilon" that we add as ballast in: - scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with conventional LayerNorm. - learn_eps: if true, we learn epsilon; if false, we keep it - at the initial value. - eps_min: float - eps_max: float - """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - eps_min: float = -3.0, - eps_max: float = 3.0, - ) -> None: - super(BasicNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - if learn_eps: - self.eps = nn.Parameter(torch.tensor(eps).log().detach()) - else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) - self.eps_min = eps_min - self.eps_max = eps_max - - def forward(self, x: Tensor) -> Tensor: - assert x.shape[self.channel_dim] == self.num_channels - eps = self.eps - if self.training and random.random() < 0.25: - # with probability 0.25, in training mode, clamp eps between the min - # and max; this will encourage it to learn parameters within the - # allowed range by making parameters that are outside the allowed - # range noisy. - - # gradients to allow the parameter to get back into the allowed - # region if it happens to exit it. - eps = eps.clamp(min=self.eps_min, max=self.eps_max) - scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() - ) ** -0.5 - return x * scales - - -def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: - """ - Behaves like a constructor of a modified version of nn.Linear - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Linear(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: - """ - Behaves like a constructor of a modified version of nn.Conv1d - that gives an easy way to set the default initial parameter scale. - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - """ - ans = nn.Conv1d(*args, **kwargs) - with torch.no_grad(): - ans.weight[:] *= initial_scale - if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) - return ans - - -class ScaledConv2d(nn.Conv2d): - # See docs for ScaledLinear - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs, - ): - super(ScaledConv2d, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter("bias_scale", None) - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in base class - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3**0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += torch.tensor(scale / std).log() - - def get_weight(self): - return self.weight * self.weight_scale.exp() - - def get_bias(self): - # see https://github.com/pytorch/pytorch/issues/24135 - bias = self.bias - bias_scale = self.bias_scale - if bias is None or bias_scale is None: - return None - else: - return bias * bias_scale.exp() - - def _conv_forward(self, input, weight): - if self.padding_mode != "zeros": - return F.conv2d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - weight, - self.get_bias(), - self.stride, - (0, 0), - self.dilation, - self.groups, - ) - return F.conv2d( - input, - weight, - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) - - def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.get_weight()) - - -class ScaledLSTM(nn.LSTM): - # See docs for ScaledLinear. - # This class implements LSTM with scaling mechanism, using `torch._VF.lstm` - # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - grad_norm_threshold: float = 10.0, - **kwargs, - ): - super(ScaledLSTM, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self._scales_names = [] - self._scales = [] - self.batch_dim = int(not self.batch_first) - for name in self._flat_weights_names: - scale_name = name + "_scale" - self._scales_names.append(scale_name) - param = nn.Parameter(initial_scale.clone().detach()) - setattr(self, scale_name, param) - self._scales.append(param) - - self.grad_filter = GradientFilter( - batch_dim=self.batch_dim, threshold=grad_norm_threshold - ) - - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in base class - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3**0.5) * std - scale = self.hidden_size**-0.5 - v = scale / std - for idx, name in enumerate(self._flat_weights_names): - if "weight" in name: - nn.init.uniform_(self._flat_weights[idx], -a, a) - with torch.no_grad(): - self._scales[idx] += torch.tensor(v).log() - elif "bias" in name: - nn.init.constant_(self._flat_weights[idx], 0.0) - - def _flatten_parameters(self, flat_weights) -> None: - """Resets parameter data pointer so that they can use faster code paths. - - Right now, this works only if the module is on the GPU and cuDNN is enabled. - Otherwise, it's a no-op. - - This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa - """ - # Short-circuits if _flat_weights is only partially instantiated - if len(flat_weights) != len(self._flat_weights_names): - return - - for w in flat_weights: - if not isinstance(w, Tensor): - return - # Short-circuits if any tensor in flat_weights is not acceptable to cuDNN - # or the tensors in flat_weights are of different dtypes - - first_fw = flat_weights[0] - dtype = first_fw.dtype - for fw in flat_weights: - if ( - not isinstance(fw.data, Tensor) - or not (fw.data.dtype == dtype) - or not fw.data.is_cuda - or not torch.backends.cudnn.is_acceptable(fw.data) - ): - return - - # If any parameters alias, we fall back to the slower, copying code path. This is - # a sufficient check, because overlapping parameter buffers that don't completely - # alias would break the assumptions of the uniqueness check in - # Module.named_parameters(). - unique_data_ptrs = set(p.data_ptr() for p in flat_weights) - if len(unique_data_ptrs) != len(flat_weights): - return - - with torch.cuda.device_of(first_fw): - - # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is - # an inplace operation on self._flat_weights - with torch.no_grad(): - if torch._use_cudnn_rnn_flatten_weight(): - num_weights = 4 if self.bias else 2 - if self.proj_size > 0: - num_weights += 1 - torch._cudnn_rnn_flatten_weight( - flat_weights, - num_weights, - self.input_size, - rnn.get_cudnn_mode(self.mode), - self.hidden_size, - self.proj_size, - self.num_layers, - self.batch_first, - bool(self.bidirectional), - ) - - def _get_flat_weights(self): - """Get scaled weights, and resets their data pointer.""" - flat_weights = [] - for idx in range(len(self._flat_weights_names)): - flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) - self._flatten_parameters(flat_weights) - return flat_weights - - def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): - # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa - # The change for calling `_VF.lstm()` is: - # self._flat_weights -> self._get_flat_weights() - if hx is None: - num_directions = 2 if self.bidirectional else 1 - h_zeros = torch.zeros( - self.num_layers * num_directions, - input.size(self.batch_dim), - self.proj_size if self.proj_size > 0 else self.hidden_size, - dtype=input.dtype, - device=input.device, - ) - c_zeros = torch.zeros( - self.num_layers * num_directions, - input.size(self.batch_dim), - self.hidden_size, - dtype=input.dtype, - device=input.device, - ) - hx = (h_zeros, c_zeros) - - self.check_forward_args(input, hx, None) - - flat_weights = self._get_flat_weights() - input, *flat_weights = self.grad_filter(input, *flat_weights) - - result = _VF.lstm( - input, - hx, - flat_weights, - self.bias, - self.num_layers, - self.dropout, - self.training, - self.bidirectional, - self.batch_first, - ) - - output = result[0] - hidden = result[1:] - return output, hidden - - -class ActivationBalancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - max_factor: the maximum factor by which we modify the derivatives for - either the sign constraint or the magnitude constraint; - e.g. with max_factor=0.02, the the derivatives would be multiplied by - values in the range [0.98..1.02]. - sign_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_positive and max_positive - are violated. - scale_gain_factor: determines the 'gain' with which we increase the - change in gradient once the constraints on min_abs and max_abs - are violated. - min_abs: the minimum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - max_abs: the maximum average-absolute-value difference from the mean - value per channel, which we allow, before we start to modify - the derivatives to prevent this. - min_prob: determines the minimum probability with which we modify the - gradients for the {min,max}_positive and {min,max}_abs constraints, - on each forward(). This is done randomly to prevent all layers - from doing it at the same time. Early in training we may use - higher probabilities than this; it will decay to this value. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, - ): - super(ActivationBalancer, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.max_factor = max_factor - self.min_abs = min_abs - self.max_abs = max_abs - self.min_prob = min_prob - self.sign_gain_factor = sign_gain_factor - self.scale_gain_factor = scale_gain_factor - - # count measures how many times the forward() function has been called. - # We occasionally sync this to a tensor called `count`, that exists to - # make sure it is synced to disk when we load and save the model. - self.cpu_count = 0 - self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) - - def forward(self, x: Tensor) -> Tensor: - if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing(): - return _no_op(x) - - count = self.cpu_count - self.cpu_count += 1 - - if random.random() < 0.01: - # Occasionally sync self.cpu_count with self.count. - # count affects the decay of 'prob'. don't do this on every iter, - # because syncing with the GPU is slow. - self.cpu_count = max(self.cpu_count, self.count.item()) - self.count.fill_(self.cpu_count) - - # the prob of doing some work exponentially decreases from 0.5 till it hits - # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) - - if random.random() < prob: - sign_gain_factor = 0.5 - if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor, - ) - else: - sign_factor = None - - scale_factor = _compute_scale_factor( - x.detach(), - self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor, - ) - return ActivationBalancerFunction.apply( - x, - scale_factor, - sign_factor, - self.channel_dim, - ) - else: - return _no_op(x) - - -def penalize_abs_values_gt(x: Tensor, limit: float, penalty: float) -> Tensor: - """ - Returns x unmodified, but in backprop will put a penalty for the excess of - the absolute values of elements of x over the limit "limit". E.g. if - limit == 10.0, then if x has any values over 10 it will get a penalty. - - Caution: the value of this penalty will be affected by grad scaling used - in automatic mixed precision training. For this reasons we use this, - it shouldn't really matter, or may even be helpful; we just use this - to disallow really implausible values of scores to be given to softmax. - """ - x_sign = x.sign() - over_limit = (x.abs() - limit) > 0 - # The following is a memory efficient way to penalize the absolute values of - # x that's over the limit. (The memory efficiency comes when you think - # about which items torch needs to cache for the autograd, and which ones it - # can throw away). The numerical value of aux_loss as computed here will - # actually be larger than it should be, by limit * over_limit.sum(), but it - # has the same derivative as the real aux_loss which is penalty * (x.abs() - - # limit).relu(). - aux_loss = penalty * ((x_sign * over_limit).to(torch.int8) * x) - # note: we don't do sum() here on aux)_loss, but it's as if we had done - # sum() due to how with_loss() works. - x = with_loss(x, aux_loss) - # you must use x for something, or this will be ineffective. - return x - - -def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. - if x.ndim == 2: - return x.diag() - else: - (batch, dim, dim) = x.shape - x = x.reshape(batch, dim * dim) - x = x[:, :: dim + 1] - assert x.shape == (batch, dim) - return x - - -def _whitening_metric(x: Tensor, num_groups: int): - """ - Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of - of the centered feature covariance are the same within each group's covariance matrix - and also between groups. - Args: - x: a Tensor of shape (*, num_channels) - num_groups: the number of groups of channels, a number >=1 that divides num_channels - Returns: - Returns a scalar Tensor that will be 1.0 if the data is "perfectly white" and - greater than 1.0 otherwise. - """ - assert x.dtype != torch.float16 - x = x.reshape(-1, x.shape[-1]) - (num_frames, num_channels) = x.shape - assert num_channels % num_groups == 0 - channels_per_group = num_channels // num_groups - x = x.reshape(num_frames, num_groups, channels_per_group).transpose(0, 1) - # x now has shape (num_groups, num_frames, channels_per_group) - # subtract the mean so we use the centered, not uncentered, covariance. - # My experience has been that when we "mess with the gradients" like this, - # it's better not do anything that tries to move the mean around, because - # that can easily cause instability. - x = x - x.mean(dim=1, keepdim=True) - # x_covar: (num_groups, channels_per_group, channels_per_group) - x_covar = torch.matmul(x.transpose(1, 2), x) - x_covar_mean_diag = _diag(x_covar).mean() - # the following expression is what we'd get if we took the matrix product - # of each covariance and measured the mean of its trace, i.e. - # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) - # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) - return metric - - -class WhiteningPenaltyFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, x: Tensor, num_groups: int, whitening_limit: float, grad_scale: float - ) -> Tensor: - ctx.save_for_backward(x) - ctx.num_groups = num_groups - ctx.whitening_limit = whitening_limit - ctx.grad_scale = grad_scale - return x - - @staticmethod - def backward(ctx, x_grad: Tensor): - (x_orig,) = ctx.saved_tensors - with torch.enable_grad(): - with torch.cuda.amp.autocast(enabled=False): - x_detached = x_orig.to(torch.float32).detach() - x_detached.requires_grad = True - - metric = _whitening_metric(x_detached, ctx.num_groups) - - if random.random() < 0.005 or __name__ == "__main__": - logging.info( - f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " - f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" - ) - - (metric - ctx.whitening_limit).relu().backward() - penalty_grad = x_detached.grad - scale = ctx.grad_scale * ( - x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) - ) - penalty_grad = penalty_grad * scale - return x_grad + penalty_grad.to(x_grad.dtype), None, None, None - - -class Whiten(nn.Module): - def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float, float]], - grad_scale: float, - ): - """ - Args: - num_groups: the number of groups to divide the channel dim into before - whitening. We will attempt to make the feature covariance - within each group, after mean subtraction, as "white" as possible, - while having the same trace across all groups. - whitening_limit: a value greater than 1.0, that dictates how much - freedom we have to violate the constraints. 1.0 would mean perfectly - white, with exactly the same trace across groups; larger values - give more freedom. E.g. 2.0. - prob: the probability with which we apply the gradient modification - (also affects the grad scale). May be supplied as a float, - or as a pair (min_prob, max_prob) - - grad_scale: determines the scale on the gradient term from this object, - relative to the rest of the gradient on the attention weights. - E.g. 0.02 (you may want to use smaller values than this if prob is large) - """ - super(Whiten, self).__init__() - assert num_groups >= 1 - assert whitening_limit >= 1 - assert grad_scale >= 0 - self.num_groups = num_groups - self.whitening_limit = whitening_limit - if isinstance(prob, float): - assert 0 < prob <= 1 - self.prob = prob - else: - (self.min_prob, self.max_prob) = prob - assert 0 < self.min_prob < self.max_prob <= 1 - self.prob = self.max_prob - - self.grad_scale = grad_scale - - def forward(self, x: Tensor) -> Tensor: - """ - In the forward pass, this function just returns the input unmodified. - In the backward pass, it will modify the gradients to ensure that the - distribution in each group has close to (lambda times I) as the covariance - after mean subtraction, with the same lambda across groups. - For whitening_limit > 1, there will be more freedom to violate this - constraint. - - Args: - x: the input of shape (*, num_channels) - - Returns: - x, unmodified. You should make sure - you use the returned value, or the graph will be freed - and nothing will happen in backprop. - """ - if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: - return _no_op(x) - else: - if hasattr(self, "min_prob") and random.random() < 0.25: - # occasionally switch between min_prob and max_prob, based on whether - # we are above or below the threshold. - if ( - _whitening_metric(x.to(torch.float32), self.num_groups) - > self.whitening_limit - ): - # there would be a change to the grad. - self.prob = self.max_prob - else: - self.prob = self.min_prob - - return WhiteningPenaltyFunction.apply( - x, self.num_groups, self.whitening_limit, self.grad_scale - ) - - -class WithLoss(torch.autograd.Function): - @staticmethod - def forward(ctx, x: Tensor, y: Tensor): - ctx.y_shape = y.shape - return x - - @staticmethod - def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones( - ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device - ) - - -def with_loss(x, y): - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x - # returns x but adds y.sum() to the loss function. - return WithLoss.apply(x, y) - - -def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x - else: - # a no-op function that will have a node in the autograd graph, - # to avoid certain bugs relating to backward hooks - return x.chunk(1, dim=-1)[0] - - -class Identity(torch.nn.Module): - def __init__(self): - super(Identity, self).__init__() - - def forward(self, x): - return _no_op(x) - - -class MaxEig(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to discourage - that any given direction in activation space accounts for more than - a specified proportion of the covariance (e.g. 0.2). - - - Args: - num_channels: the number of channels - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - max_var_per_eig: the maximum proportion of the variance of the - features/channels, after mean subtraction, that can come from - any given eigenvalue. - min_prob: the minimum probability with which we apply this during any invocation - of forward(), assuming last time we applied the constraint it was - not active; supplied for speed. - scale: determines the scale with which we modify the gradients, relative - to the existing / unmodified gradients - """ - - def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, - ): - super(MaxEig, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.scale = scale - assert max_var_per_eig == 0.0 or max_var_per_eig > 1.0 / num_channels - self.max_var_per_eig = max_var_per_eig - - # we figure out the dominant direction using the power method: starting with - # a random vector, keep multiplying by the covariance and renormalizing. - with torch.no_grad(): - # arbitrary.. would use randn() but want to leave the rest of the model's - # random parameters unchanged for comparison - direction = torch.arange(num_channels).to(torch.float) - direction = direction / direction.norm() - self.register_buffer("max_eig_direction", direction) - - self.min_prob = min_prob - # cur_prob is the current probability we'll use to apply the ActivationBalancer. - # We'll regress this towards prob, each tiem we try to apply it and it is not - # active. - self.cur_prob = 1.0 - - def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or self.max_var_per_eig <= 0 - or random.random() > self.cur_prob - or torch.jit.is_tracing() - ): - return _no_op(x) - - with torch.cuda.amp.autocast(enabled=False): - eps = 1.0e-20 - orig_x = x - x = x.to(torch.float32) - with torch.no_grad(): - x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) - x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs( - x, self.max_eig_direction - ) - x_var = (x**2).mean() - x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() - - # `variance_proportion` is the proportion of the variance accounted for - # by the top eigen-direction. - variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) - - # ensure new direction is nonzero even if x == 0, by including `direction`. - self._set_direction(0.1 * self.max_eig_direction + new_direction) - - if random.random() < 0.01 or __name__ == "__main__": - logging.info( - f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" - ) - - if variance_proportion >= self.max_var_per_eig: - # The constraint is active. Note, we should quite rarely - # reach here, only near the beginning of training if we are - # starting to diverge, should this constraint be active. - cur_prob = self.cur_prob - self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply( - orig_x, coeffs, new_direction, self.channel_dim, self.scale - ) - else: - # let self.cur_prob exponentially approach self.min_prob, as - # long as the constraint is inactive. - self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob - return orig_x - - def _set_direction(self, direction: Tensor): - """ - Sets self.max_eig_direction to a normalized version of `direction` - """ - direction = direction.detach() - direction = direction / direction.norm() - direction_sum = direction.sum().item() - if direction_sum - direction_sum == 0: # no inf/nan - self.max_eig_direction[:] = direction - else: - logging.info( - f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}" - ) - - def _find_direction_coeffs( - self, x: Tensor, prev_direction: Tensor - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. - - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() - """ - (num_frames, num_channels) = x.shape - assert num_channels > 1 and num_frames > 1 - assert prev_direction.shape == (num_channels,) - # `coeffs` are the coefficients of `prev_direction` in x. - # actually represent the coeffs up to a constant positive factor. - coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) - return cur_direction, coeffs - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - requires_grad = x.requires_grad - x_dtype = x.dtype - if x.dtype == torch.float16: - x = x.to(torch.float32) - - s = torch.sigmoid(x - 1.0) - y = x * s - - if requires_grad: - deriv = y * (1 - s) + s - # notes on derivative of x * sigmoid(x - 1): - # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 - # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund - # max \simeq 1.1990. Take ceil to be 1.2 so it's an upper bound. - # the combination of "+ torch.rand_like(deriv)" and casting to torch.uint8 (which - # floors), should be expectation-preserving. - floor = -0.043637 - ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv - ) - if __name__ == "__main__": - # for self-testing only. - assert d_scaled.min() >= 0.0 - assert d_scaled.max() < 256.0 - d_int = d_scaled.to(torch.uint8) - ctx.save_for_backward(d_int) - if x.dtype == torch.float16 or torch.is_autocast_enabled(): - y = y.to(torch.float16) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors - # the same constants as used in forward pass. - floor = -0.043637 - ceil = 1.2 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d - - -class DoubleSwish(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting() or torch.jit.is_tracing(): - return x * torch.sigmoid(x - 1.0) - return DoubleSwishFunction.apply(x) - - -def _test_max_eig(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - num_channels = 128 - m = MaxEig( - num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig - ) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad, atol=1.0e-02) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_whiten(): - for proportion in [0.1, 0.5, 10.0]: - logging.info(f"_test_whiten(): proportion = {proportion}") - x = torch.randn(100, 128) - direction = torch.randn(128) - coeffs = torch.randn(100, 1) - x += proportion * direction * coeffs - - x.requires_grad = True - - num_channels = 128 - m = Whiten( - 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, - ) # grad_scale - - for _ in range(4): - y = m(x) - - y_grad = torch.randn_like(x) - y.backward(gradient=y_grad) - - if proportion < 0.2: - assert torch.allclose(x.grad, y_grad) - elif proportion > 1.0: - assert not torch.allclose(x.grad, y_grad) - - -def _test_activation_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * ((2.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1))) - 1.0) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - probs.numel(), - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - max_factor=0.2, - min_abs=0.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_sign: x = ", x) - print("_test_activation_balancer_sign: y grad = ", y_grad) - print("_test_activation_balancer_sign: x grad = ", x.grad) - - -def _test_activation_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - magnitudes.numel(), - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - max_factor=0.2, - min_abs=0.2, - max_abs=0.8, - min_prob=1.0, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_magnitude: x = ", x) - print("_test_activation_balancer_magnitude: y grad = ", y_grad) - print("_test_activation_balancer_magnitude: x grad = ", x.grad) - - -def _test_basic_norm(): - num_channels = 128 - m = BasicNorm(num_channels=num_channels, channel_dim=1) - - x = torch.randn(500, num_channels) - - y = m(x) - - assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() - print("x rms = ", x_rms) - print("y rms = ", y_rms) - assert y_rms < x_rms - assert y_rms > 0.5 * x_rms - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 3.0 - x.requires_grad = True - m = DoubleSwish() - - tol = (1.2 - (-0.043637)) / 255.0 - torch.autograd.gradcheck(m, x, atol=tol) - - # for self-test. - x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 - x.requires_grad = True - y = m(x) - - -def _test_softmax(): - a = torch.randn(2, 10, dtype=torch.float64) - b = a.clone() - a.requires_grad = True - b.requires_grad = True - a.softmax(dim=1)[:, 0].sum().backward() - print("a grad = ", a.grad) - softmax(b, dim=1)[:, 0].sum().backward() - print("b grad = ", b.grad) - assert torch.allclose(a.grad, b.grad) - - -def _test_scaled_lstm(): - N, L = 2, 30 - dim_in, dim_hidden = 10, 20 - m = ScaledLSTM(input_size=dim_in, hidden_size=dim_hidden, bias=True) - x = torch.randn(L, N, dim_in) - h0 = torch.randn(1, N, dim_hidden) - c0 = torch.randn(1, N, dim_hidden) - y, (h, c) = m(x, (h0, c0)) - assert y.shape == (L, N, dim_hidden) - assert h.shape == (1, N, dim_hidden) - assert c.shape == (1, N, dim_hidden) - - -def _test_grad_filter(): - threshold = 50.0 - time, batch, channel = 200, 5, 128 - grad_filter = GradientFilter(batch_dim=1, threshold=threshold) - - for i in range(2): - x = torch.randn(time, batch, channel, requires_grad=True) - w = nn.Parameter(torch.ones(5)) - b = nn.Parameter(torch.zeros(5)) - - x_out, w_out, b_out = grad_filter(x, w, b) - - w_out_grad = torch.randn_like(w) - b_out_grad = torch.randn_like(b) - x_out_grad = torch.rand_like(x) - if i % 2 == 1: - # The gradient norm of the first element must be larger than - # `threshold * median`, where `median` is the median value - # of gradient norms of all elements in batch. - x_out_grad[:, 0, :] = torch.full((time, channel), threshold) - - torch.autograd.backward( - [x_out, w_out, b_out], [x_out_grad, w_out_grad, b_out_grad] - ) - - print( - "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa - i % 2 == 1, - ) - - print( - "_test_grad_filter: x_out_grad norm = ", - (x_out_grad**2).mean(dim=(0, 2)).sqrt(), - ) - print( - "_test_grad_filter: x.grad norm = ", - (x.grad**2).mean(dim=(0, 2)).sqrt(), - ) - print("_test_grad_filter: w_out_grad = ", w_out_grad) - print("_test_grad_filter: w.grad = ", w.grad) - print("_test_grad_filter: b_out_grad = ", b_out_grad) - print("_test_grad_filter: b.grad = ", b.grad) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_softmax() - _test_whiten() - _test_max_eig() - _test_activation_balancer_sign() - _test_activation_balancer_magnitude() - _test_basic_norm() - _test_double_swish_deriv() - _test_scaled_lstm() - _test_grad_filter() diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/scaling_converter.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/scaling_converter.py deleted file mode 100644 index 56165d1f9..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/scaling_converter.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file replaces various modules in a model. -Specifically, ActivationBalancer is replaced with an identity operator; -Whiten is also replaced with an identity operator; -BasicNorm is replaced by a module with `exp` removed. -""" - -import copy -from typing import List - -import torch -import torch.nn as nn -from scaling import ActivationBalancer, BasicNorm, Whiten - - -class NonScaledNorm(nn.Module): - """See BasicNorm for doc""" - - def __init__( - self, - num_channels: int, - eps_exp: float, - channel_dim: int = -1, # CAUTION: see documentation. - ): - super().__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.eps_exp = eps_exp - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if not torch.jit.is_tracing(): - assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp - ).pow(-0.5) - return x * scales - - -def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: - assert isinstance(basic_norm, BasicNorm), type(BasicNorm) - norm = NonScaledNorm( - num_channels=basic_norm.num_channels, - eps_exp=basic_norm.eps.data.exp().item(), - channel_dim=basic_norm.channel_dim, - ) - return norm - - -# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa -# get_submodule was added to nn.Module at v1.9.0 -def get_submodule(model, target): - if target == "": - return model - atoms: List[str] = target.split(".") - mod: torch.nn.Module = model - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") - return mod - - -def convert_scaled_to_non_scaled( - model: nn.Module, - inplace: bool = False, -): - """ - Args: - model: - The model to be converted. - inplace: - If True, the input model is modified inplace. - If False, the input model is copied and we modify the copied version. - Return: - Return a model without scaled layers. - """ - if not inplace: - model = copy.deepcopy(model) - - d = {} - for name, m in model.named_modules(): - if isinstance(m, BasicNorm): - d[name] = convert_basic_norm(m) - elif isinstance(m, (ActivationBalancer, Whiten)): - d[name] = nn.Identity() - - for k, v in d.items(): - if "." in k: - parent, child = k.rsplit(".", maxsplit=1) - setattr(get_submodule(model, parent), child, v) - else: - setattr(model, k, v) - - return model diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/streaming_beam_search.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/streaming_beam_search.py deleted file mode 100644 index e6e0fb1c8..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/streaming_beam_search.py +++ /dev/null @@ -1,282 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from typing import List - -import k2 -import torch -import torch.nn as nn -from beam_search import Hypothesis, HypothesisList, get_hyps_shape -from decode_stream import DecodeStream - -from icefall.decode import one_best_decoding -from icefall.utils import get_texts - - -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], -) -> None: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - streams: - A list of Stream objects. - """ - assert len(streams) == encoder_out.size(0) - assert encoder_out.ndim == 3 - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = model.device - T = encoder_out.size(1) - - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - # decoder_out is of shape (N, 1, decoder_out_dim) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - for t in range(T): - # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) - current_encoder_out = encoder_out[:, t : t + 1, :] # noqa - - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - # logits'shape (batch_size, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - streams[i].hyp.append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = torch.tensor( - [stream.hyp[-context_size:] for stream in streams], - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder( - decoder_input, - need_pad=False, - ) - decoder_out = model.joiner.decoder_proj(decoder_out) - - -def modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - streams: List[DecodeStream], - num_active_paths: int = 4, -) -> None: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The RNN-T model. - encoder_out: - A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of - the encoder model. - streams: - A list of stream objects. - num_active_paths: - Number of active paths during the beam search. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert len(streams) == encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - device = next(model.parameters()).device - batch_size = len(streams) - T = encoder_out.size(1) - - B = [stream.hyps for stream in streams] - - for t in range(T): - current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.stack( - [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, encoder_out_dim) - - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) - # logits is of shape (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token != blank_id: - new_ys.append(new_token) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) - B[i].add(new_hyp) - - for i in range(batch_size): - streams[i].hyps = B[i] - - -def fast_beam_search_one_best( - model: nn.Module, - encoder_out: torch.Tensor, - processed_lens: torch.Tensor, - streams: List[DecodeStream], - beam: float, - max_states: int, - max_contexts: int, -) -> None: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first generated by Fsa-based beam search, then we get the - recognition by applying shortest path on the lattice. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - processed_lens: - A tensor of shape (N,) containing the number of processed frames - in `encoder_out` before padding. - streams: - A list of stream objects. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - """ - assert encoder_out.ndim == 3 - B, T, C = encoder_out.shape - assert B == len(streams) - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(streams[i].rnnt_decoding_stream) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - log_probs = logits.log_softmax(dim=-1) - decoding_streams.advance(log_probs) - - decoding_streams.terminate_and_flush_to_streams() - - lattice = decoding_streams.format_output(processed_lens.tolist()) - best_path = one_best_decoding(lattice) - hyp_tokens = get_texts(best_path) - - for i in range(B): - streams[i].hyp = hyp_tokens[i] diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/streaming_decode.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/streaming_decode.py deleted file mode 100755 index 7a349ecb2..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/streaming_decode.py +++ /dev/null @@ -1,615 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./pruned_transducer_stateless7_streaming/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --decode-chunk-len 32 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --decoding_method greedy_search \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import sentencepiece as spm -import torch -import torch.nn as nn -from asr_datamodule import LibriSpeechAsrDataModule -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import stack_states, unstack_states - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - - features = [] - feature_lens = [] - states = [] - processed_lens = [] - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling - # factor in encoders is 8. - # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. - tail_length = 23 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - processed_lens = torch.tensor(processed_lens, device=device) - - encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( - x=features, - x_lens=feature_lens, - states=states, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: spm.SentencePieceProcessor, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 50 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. - initial_states = model.encoder.get_init_state(device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) - decode_stream.ground_truth = cut.supervisions[0].text - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - decode_streams[i].ground_truth.split(), - sp.decode(decode_streams[i].decoding_result()).split(), - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = ( - params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" - ) - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = ( - params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" - ) - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - LibriSpeechAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - # for streaming - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # and is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - librispeech = LibriSpeechAsrDataModule(args) - - test_clean_cuts = librispeech.test_clean_cuts() - test_other_cuts = librispeech.test_other_cuts() - - test_sets = ["test-clean", "test-other"] - test_cuts = [test_clean_cuts, test_other_cuts] - - for test_set, test_cut in zip(test_sets, test_cuts): - results_dict = decode_dataset( - cuts=test_cut, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/test_model.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/test_model.py deleted file mode 100755 index 5400df804..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/test_model.py +++ /dev/null @@ -1,150 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -To run this file, do: - - cd icefall/egs/librispeech/ASR - python ./pruned_transducer_stateless7_streaming/test_model.py -""" - -import torch -from scaling_converter import convert_scaled_to_non_scaled -from train import get_params, get_transducer_model - - -def test_model(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = "2,4,3,2,4" - params.feedforward_dims = "1024,1024,2048,2048,1024" - params.nhead = "8,8,8,8,8" - params.encoder_dims = "384,384,384,384,384" - params.attention_dims = "192,192,192,192,192" - params.encoder_unmasked_dims = "256,256,256,256,256" - params.zipformer_downsampling_factors = "1,2,4,8,2" - params.cnn_module_kernels = "31,31,31,31,31" - params.decoder_dim = 512 - params.joiner_dim = 512 - params.num_left_chunks = 4 - params.short_chunk_size = 50 - params.decode_chunk_len = 32 - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - # Test jit script - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - print("Using torch.jit.script") - model = torch.jit.script(model) - - -def test_model_jit_trace(): - params = get_params() - params.vocab_size = 500 - params.blank_id = 0 - params.context_size = 2 - params.num_encoder_layers = "2,4,3,2,4" - params.feedforward_dims = "1024,1024,2048,2048,1024" - params.nhead = "8,8,8,8,8" - params.encoder_dims = "384,384,384,384,384" - params.attention_dims = "192,192,192,192,192" - params.encoder_unmasked_dims = "256,256,256,256,256" - params.zipformer_downsampling_factors = "1,2,4,8,2" - params.cnn_module_kernels = "31,31,31,31,31" - params.decoder_dim = 512 - params.joiner_dim = 512 - params.num_left_chunks = 4 - params.short_chunk_size = 50 - params.decode_chunk_len = 32 - model = get_transducer_model(params) - model.eval() - - num_param = sum([p.numel() for p in model.parameters()]) - print(f"Number of model parameters: {num_param}") - - convert_scaled_to_non_scaled(model, inplace=True) - - # Test encoder - def _test_encoder(): - encoder = model.encoder - assert encoder.decode_chunk_size == params.decode_chunk_len // 2, ( - encoder.decode_chunk_size, - params.decode_chunk_len, - ) - T = params.decode_chunk_len + 7 - - x = torch.zeros(1, T, 80, dtype=torch.float32) - x_lens = torch.full((1,), T, dtype=torch.int32) - states = encoder.get_init_state(device=x.device) - encoder.__class__.forward = encoder.__class__.streaming_forward - traced_encoder = torch.jit.trace(encoder, (x, x_lens, states)) - - states1 = encoder.get_init_state(device=x.device) - states2 = traced_encoder.get_init_state(device=x.device) - for i in range(5): - x = torch.randn(1, T, 80, dtype=torch.float32) - x_lens = torch.full((1,), T, dtype=torch.int32) - y1, _, states1 = encoder.streaming_forward(x, x_lens, states1) - y2, _, states2 = traced_encoder(x, x_lens, states2) - assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean()) - - # Test decoder - def _test_decoder(): - decoder = model.decoder - y = torch.zeros(10, decoder.context_size, dtype=torch.int64) - need_pad = torch.tensor([False]) - - traced_decoder = torch.jit.trace(decoder, (y, need_pad)) - d1 = decoder(y, need_pad) - d2 = traced_decoder(y, need_pad) - assert torch.equal(d1, d2), (d1 - d2).abs().mean() - - # Test joiner - def _test_joiner(): - joiner = model.joiner - encoder_out_dim = joiner.encoder_proj.weight.shape[1] - decoder_out_dim = joiner.decoder_proj.weight.shape[1] - encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) - decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) - - traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out)) - j1 = joiner(encoder_out, decoder_out) - j2 = traced_joiner(encoder_out, decoder_out) - assert torch.equal(j1, j2), (j1 - j2).abs().mean() - - _test_encoder() - _test_decoder() - _test_joiner() - - -def main(): - test_model() - test_model_jit_trace() - - -if __name__ == "__main__": - main() diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/train.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/train.py deleted file mode 100755 index 670ade470..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/train.py +++ /dev/null @@ -1,1346 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -cd egs/librispeech/ASR/ -./prepare.sh -./prepare_giga_speech.sh - -./lstm_transducer_stateless2/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir lstm_transducer_stateless2/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./lstm_transducer_stateless2/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir lstm_transducer_stateless2/exp \ - --full-libri 1 \ - --max-duration 550 -""" - -import argparse -import copy -import logging -import warnings -from itertools import chain -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import sentencepiece as spm -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import LibrimixAsrDataModule -from decoder import Decoder -from dprnn import DPRNN -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import SURT -from optim import Eden, ScaledAdam -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-mask-encoder-layers", - type=int, - default=4, - help="Number of layers in the DPRNN based mask encoder.", - ) - - parser.add_argument( - "--mask-encoder-dim", - type=int, - default=256, - help="Hidden dimension of the LSTM blocks in DPRNN.", - ) - - parser.add_argument( - "--mask-encoder-segment-size", - type=int, - default=32, - help="Segment size of the SegLSTM in DPRNN. Ideally, this should be equal to the " - "decode-chunk-length of the zipformer encoder.", - ) - - parser.add_argument( - "--chunk-width-randomization", - type=bool, - default=False, - help="Whether to randomize the chunk width in DPRNN.", - ) - - # Zipformer config is based on: - # https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740 - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,2,2,2,2", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="768,768,768,768,768", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="256,256,256,256,256", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="192,192,192,192,192", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=50, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="conv_lstm_transducer_stateless_ctc/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--encoder-init-ckpt", - type=str, - default=None, - help="""The encoder checkpoint to initialize the encoder (recognition module). - If not specified, the encoder is randomly initialized. - """, - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--base-lr", type=float, default=0.004, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=10, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--ctc-loss-scale", - type=float, - default=0.2, - help="Scale for CTC loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=10, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=100, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 2000, - # parameters for SURT - "num_channels": 2, - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed - # parameters for Noam - "model_warm_step": 5000, # arg given to model, not for lrate - # parameters for ctc loss - "beam_size": 10, - "use_double_scores": True, - "env_info": get_env_info(), - } - ) - - return params - - -def get_mask_encoder_model(params: AttributeDict) -> nn.Module: - mask_encoder = DPRNN( - feature_dim=params.feature_dim, - input_size=params.mask_encoder_dim, - hidden_size=params.mask_encoder_dim, - output_size=params.feature_dim * params.num_channels, - segment_size=params.mask_encoder_segment_size, - num_blocks=params.num_mask_encoder_layers, - chunk_width_randomization=params.chunk_width_randomization, - ) - return mask_encoder - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_surt_model( - params: AttributeDict, -) -> nn.Module: - mask_encoder = get_mask_encoder_model(params) - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = SURT( - mask_encoder=mask_encoder, - encoder=encoder, - decoder=decoder, - joiner=joiner, - num_channels=params.num_channels, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"].to(device) - feature_lens = batch["input_lens"].to(device) - - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - - # The dataloader returns text as a list of cuts, each of which is a list of channel - # text. We flatten this to a list where all channels are together, i.e., it looks like - # [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ...., uttN,ch2]. - text = [val for tup in zip(*batch["text"]) for val in tup] - assert len(text) == len(feature) * params.num_channels - - # Convert all channel texts to token IDs and create a ragged tensor. - y = sp.encode(text, out_type=int) - y = k2.RaggedTensor(y).to(device) - - batch_idx_train = params.batch_idx_train - warm_step = params.model_warm_step - - with torch.set_grad_enabled(is_training): - (simple_loss, pruned_loss, ctc_loss,) = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - reduction="none", - subsampling_factor=params.subsampling_factor, - ) - simple_loss_is_finite = torch.isfinite(simple_loss) - pruned_loss_is_finite = torch.isfinite(pruned_loss) - ctc_loss_is_finite = torch.isfinite(ctc_loss) - is_finite = simple_loss_is_finite & pruned_loss_is_finite & ctc_loss_is_finite - if not torch.all(is_finite): - logging.info( - "Not all losses are finite!\n" - f"simple_losses: {simple_loss}\n" - f"pruned_losses: {pruned_loss}\n" - f"ctc_losses: {ctc_loss}\n" - ) - display_and_save_batch(batch, params=params, sp=sp) - simple_loss = simple_loss[simple_loss_is_finite] - pruned_loss = pruned_loss[pruned_loss_is_finite] - ctc_loss = ctc_loss[ctc_loss_is_finite] - - # If either all simple_loss or pruned_loss is inf or nan, - # we stop the training process by raising an exception - if ( - torch.all(~simple_loss_is_finite) - or torch.all(~pruned_loss_is_finite) - or torch.all(~ctc_loss_is_finite) - ): - raise ValueError( - "There are too many utterances in this batch " - "leading to inf or nan losses." - ) - - simple_loss_sum = simple_loss.sum() - pruned_loss_sum = pruned_loss.sum() - ctc_loss_sum = ctc_loss.sum() - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - loss = ( - simple_loss_scale * simple_loss_sum - + pruned_loss_scale * pruned_loss_sum - + params.ctc_loss_scale * ctc_loss_sum - ) - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - # info["frames"] is an approximate number for two reasons: - # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 - # (2) If some utterances in the batch lead to inf/nan loss, they - # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa - info["utterances"] = feature.size(0) - # averaged input duration in frames over utterances - info["utt_duration"] = feature_lens.sum().item() - # averaged padding proportion over utterances - info["utt_pad_proportion"] = ( - ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() - ) - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss_sum.detach().cpu().item() - info["pruned_loss"] = pruned_loss_sum.detach().cpu().item() - info["ctc_loss"] = ctc_loss_sum.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: spm.SentencePieceProcessor, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: spm.SentencePieceProcessor, - train_dl: torch.utils.data.DataLoader, - train_dl_warmup: Optional[torch.utils.data.DataLoader], - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - train_dl_warmup: - Dataloader for the training dataset with 2 speakers. This is used during the - warmup stage. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - torch.cuda.empty_cache() - model.train() - - tot_loss = MetricsTracker() - - iter_train = iter(train_dl) - iter_train_warmup = iter(train_dl_warmup) if train_dl_warmup is not None else None - - batch_idx = 0 - - while True: - # We first sample a batch from the main dataset. This is because we want to - # make sure all epochs have the same number of batches. - try: - batch = next(iter_train) - except StopIteration: - break - - # If we are in warmup stage, get the batch from the 2spk dataset. - if ( - params.batch_idx_train <= params.model_warm_step - and iter_train_warmup is not None - ): - try: - batch = next(iter_train_warmup) - except StopIteration: - iter_train_warmup = iter(train_dl_warmup) - batch = next(iter_train_warmup) - - batch_idx += 1 - - params.batch_idx_train += 1 - batch_size = batch["inputs"].shape[0] - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, sp=sp) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - params.cur_batch_idx = batch_idx - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - del params.cur_batch_idx - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", cur_grad_scale, params.batch_idx_train - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = spm.SentencePieceProcessor() - sp.load(params.bpe_model) - - # is defined in local/train_bpe_model.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_surt_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - - if checkpoints is None and params.encoder_init_ckpt is not None: - logging.info("Initializing encoder with checkpoint") - init_ckpt = torch.load(params.encoder_init_ckpt, map_location=device) - model.load_state_dict(init_ckpt["model"], strict=False) - - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - diagnostic = diagnostics.attach_diagnostics(model) - - librimix = LibrimixAsrDataModule(args) - - train_cuts = librimix.train_cuts(reverberated=False) - train_cuts_2spk = librimix.train_cuts_2spk(reverberated=False) - # dev_cuts = librimix.dev_cuts(reverberated=False) - dev_cuts = librimix.libricss_cuts(split="dev", type="ihm-mix") - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = librimix.train_dataloaders( - train_cuts, - sampler_state_dict=sampler_state_dict, - ) - train_dl_2spk = librimix.train_dataloaders(train_cuts_2spk) - valid_dl = librimix.valid_dataloaders(dev_cuts) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - train_dl_warmup=train_dl_2spk, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: spm.SentencePieceProcessor, - params: AttributeDict, - warmup: float, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - raise - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: spm.SentencePieceProcessor, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = [sp.encode(text_ch) for text_ch in batch["text"]] - num_tokens = [sum(len(yi) for yi in y_ch) for y_ch in y] - logging.info(f"num tokens: {num_tokens}") - - -def main(): - parser = get_parser() - LibrimixAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/zipformer.py b/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/zipformer.py deleted file mode 100644 index e13629384..000000000 --- a/egs/libricss/SURT/dprnn_pruned_transducer_stateless7/zipformer.py +++ /dev/null @@ -1,2881 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import itertools -import logging -import math -import random -import warnings -from typing import List, Optional, Tuple, Union - -import torch -from encoder_interface import EncoderInterface -from scaling import ( - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. -) -from scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - Identity, - MaxEig, - ScaledConv1d, - Whiten, - _diag, - penalize_abs_values_gt, - random_clamp, - softmax, -) -from torch import Tensor, nn - -from icefall.dist import get_rank -from icefall.utils import make_pad_mask, subsequent_chunk_mask - - -def stack_states(state_list: List[List[Tensor]]) -> List[Tensor]: - """Stack list of zipformer states that correspond to separate utterances - into a single emformer state, so that it can be used as an input for - zipformer when those utterances are formed into a batch. - - Note: - It is the inverse of :func:`unstack_states`. - - Args: - state_list: - Each element in state_list corresponding to the internal state - of the zipformer model for a single utterance. - ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. - ``states[i][0:num_encoders]`` is the cached numbers of past frames. - ``states[i][num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[i][2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[i][3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[i][4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[i][5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[i][6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - A new state corresponding to a batch of utterances. - See the input argument of :func:`unstack_states` for the meaning - of the returned tensor. - """ - batch_size = len(state_list) - assert len(state_list[0]) % 7 == 0, len(state_list[0]) - num_encoders = len(state_list[0]) // 7 - - cached_len = [] - cached_avg = [] - cached_key = [] - cached_val = [] - cached_val2 = [] - cached_conv1 = [] - cached_conv2 = [] - - # For cached_len - len_list = [state_list[n][0:num_encoders] for n in range(batch_size)] - for i in range(num_encoders): - # len_avg: (num_layers, batch_size) - len_avg = torch.cat([len_list[n][i] for n in range(batch_size)], dim=1) - cached_len.append(len_avg) - - # For cached_avg - avg_list = [ - state_list[n][num_encoders : 2 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # avg: (num_layers, batch_size, D) - avg = torch.cat([avg_list[n][i] for n in range(batch_size)], dim=1) - cached_avg.append(avg) - - # For cached_key - key_list = [ - state_list[n][2 * num_encoders : 3 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # key: (num_layers, left_context_size, batch_size, D) - key = torch.cat([key_list[n][i] for n in range(batch_size)], dim=2) - cached_key.append(key) - - # For cached_val - val_list = [ - state_list[n][3 * num_encoders : 4 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # val: (num_layers, left_context_size, batch_size, D) - val = torch.cat([val_list[n][i] for n in range(batch_size)], dim=2) - cached_val.append(val) - - # For cached_val2 - val2_list = [ - state_list[n][4 * num_encoders : 5 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # val2: (num_layers, left_context_size, batch_size, D) - val2 = torch.cat([val2_list[n][i] for n in range(batch_size)], dim=2) - cached_val2.append(val2) - - # For cached_conv1 - conv1_list = [ - state_list[n][5 * num_encoders : 6 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # conv1: (num_layers, batch_size, D, kernel-1) - conv1 = torch.cat([conv1_list[n][i] for n in range(batch_size)], dim=1) - cached_conv1.append(conv1) - - # For cached_conv2 - conv2_list = [ - state_list[n][6 * num_encoders : 7 * num_encoders] for n in range(batch_size) - ] - for i in range(num_encoders): - # conv2: (num_layers, batch_size, D, kernel-1) - conv2 = torch.cat([conv2_list[n][i] for n in range(batch_size)], dim=1) - cached_conv2.append(conv2) - - states = ( - cached_len - + cached_avg - + cached_key - + cached_val - + cached_val2 - + cached_conv1 - + cached_conv2 - ) - return states - - -def unstack_states(states: List[Tensor]) -> List[List[Tensor]]: - """Unstack the zipformer state corresponding to a batch of utterances - into a list of states, where the i-th entry is the state from the i-th - utterance in the batch. - - Note: - It is the inverse of :func:`stack_states`. - - Args: - states: - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - A list of states. - ``states[i]`` is a list of 7 * num_encoders elements of i-th utterance. - """ - assert len(states) % 7 == 0, len(states) - num_encoders = len(states) // 7 - ( - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) = (states[i * num_encoders : (i + 1) * num_encoders] for i in range(7)) - - batch_size = cached_len[0].shape[1] - - len_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_len[i]: (num_layers, batch_size) - len_avg = cached_len[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - len_list[n].append(len_avg[n]) - - avg_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_avg[i]: (num_layers, batch_size, D) - avg = cached_avg[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - avg_list[n].append(avg[n]) - - key_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_key[i]: (num_layers, left_context, batch_size, D) - key = cached_key[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - key_list[n].append(key[n]) - - val_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_val[i]: (num_layers, left_context, batch_size, D) - val = cached_val[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - val_list[n].append(val[n]) - - val2_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_val2[i]: (num_layers, left_context, batch_size, D) - val2 = cached_val2[i].chunk(chunks=batch_size, dim=2) - for n in range(batch_size): - val2_list[n].append(val2[n]) - - conv1_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_conv1[i]: (num_layers, batch_size, D, kernel-1) - conv1 = cached_conv1[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - conv1_list[n].append(conv1[n]) - - conv2_list = [[] for _ in range(batch_size)] - for i in range(num_encoders): - # cached_conv2[i]: (num_layers, batch_size, D, kernel-1) - conv2 = cached_conv2[i].chunk(chunks=batch_size, dim=1) - for n in range(batch_size): - conv2_list[n].append(conv2[n]) - - state_list = [ - ( - len_list[i] - + avg_list[i] - + key_list[i] - + val_list[i] - + val2_list[i] - + conv1_list[i] - + conv2_list[i] - ) - for i in range(batch_size) - ] - return state_list - - -class Zipformer(EncoderInterface): - """ - Args: - num_features (int): Number of input features - d_model: (int,int): embedding dimension of 2 encoder stacks - attention_dim: (int,int): attention dimension of 2 encoder stacks - nhead (int, int): number of heads - dim_feedforward (int, int): feedforward dimension in 2 encoder stacks - num_encoder_layers (int): number of encoder layers - dropout (float): dropout rate - cnn_module_kernel (int): Kernel size of convolution module - vgg_frontend (bool): whether to use vgg frontend. - warmup_batches (float): number of batches to warm up over - """ - - def __init__( - self, - num_features: int, - output_downsampling_factor: int = 2, - encoder_dims: Tuple[int] = (384, 384), - attention_dim: Tuple[int] = (256, 256), - encoder_unmasked_dims: Tuple[int] = (256, 256), - zipformer_downsampling_factors: Tuple[int] = (2, 4), - nhead: Tuple[int] = (8, 8), - feedforward_dim: Tuple[int] = (1536, 2048), - num_encoder_layers: Tuple[int] = (12, 12), - dropout: float = 0.1, - cnn_module_kernels: Tuple[int] = (31, 31), - pos_dim: int = 4, - num_left_chunks: int = 4, - short_chunk_threshold: float = 0.75, - short_chunk_size: int = 50, - decode_chunk_size: int = 16, - warmup_batches: float = 4000.0, - ) -> None: - super(Zipformer, self).__init__() - - self.num_features = num_features - assert 0 < encoder_dims[0] <= encoder_dims[1] - self.encoder_dims = encoder_dims - self.encoder_unmasked_dims = encoder_unmasked_dims - self.zipformer_downsampling_factors = zipformer_downsampling_factors - self.output_downsampling_factor = output_downsampling_factor - - self.num_left_chunks = num_left_chunks - self.short_chunk_threshold = short_chunk_threshold - self.short_chunk_size = short_chunk_size - - # Used in decoding - self.decode_chunk_size = decode_chunk_size - - # will be written to, see set_batch_count() - self.batch_count = 0 - self.warmup_end = warmup_batches - - for u, d in zip(encoder_unmasked_dims, encoder_dims): - assert u <= d, (u, d) - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, (T - 7)//2, encoder_dims). - # That is, it does two things simultaneously: - # (1) subsampling: T -> (T - 7)//2 - # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling( - num_features, encoder_dims[0], dropout=dropout - ) - - # each one will be ZipformerEncoder or DownsampledZipformerEncoder - encoders = [] - - self.num_encoders = len(encoder_dims) - for i in range(self.num_encoders): - encoder_layer = ZipformerEncoderLayer( - encoder_dims[i], - attention_dim[i], - nhead[i], - feedforward_dim[i], - dropout, - cnn_module_kernels[i], - pos_dim, - ) - - # For the segment of the warmup period, we let the Conv2dSubsampling - # layer learn something. Then we start to warm up the other encoders. - encoder = ZipformerEncoder( - encoder_layer, - num_encoder_layers[i], - dropout, - warmup_begin=warmup_batches * (i + 1) / (self.num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (self.num_encoders + 1), - ) - - if zipformer_downsampling_factors[i] != 1: - encoder = DownsampledZipformerEncoder( - encoder, - input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], - output_dim=encoder_dims[i], - downsample=zipformer_downsampling_factors[i], - ) - encoders.append(encoder) - self.encoders = nn.ModuleList(encoders) - - # initializes self.skip_layers and self.skip_modules - self._init_skip_modules() - - self.downsample_output = AttentionDownsample( - encoder_dims[-1], encoder_dims[-1], downsample=output_downsampling_factor - ) - - def _get_layer_skip_dropout_prob(self): - if not self.training: - return 0.0 - batch_count = self.batch_count - min_dropout_prob = 0.025 - - if batch_count > self.warmup_end: - return min_dropout_prob - else: - return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob) - - def _init_skip_modules(self): - """ - If self.zipformer_downampling_factors = (1, 2, 4, 8, 4, 2), then at the input of layer - indexed 4 (in zero indexing), with has subsapling_factor=4, we combine the output of - layers 2 and 3; and at the input of layer indexed 5, which which has subsampling_factor=2, - we combine the outputs of layers 1 and 5. - """ - skip_layers = [] - skip_modules = [] - z = self.zipformer_downsampling_factors - for i in range(len(z)): - if i <= 1 or z[i - 1] <= z[i]: - skip_layers.append(None) - skip_modules.append(SimpleCombinerIdentity()) - else: - # TEMP - for j in range(i - 2, -1, -1): - if z[j] <= z[i] or j == 0: - # TEMP logging statement. - logging.info( - f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " - f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}." - ) - skip_layers.append(j) - skip_modules.append( - SimpleCombiner( - self.encoder_dims[j], - self.encoder_dims[i - 1], - min_weight=(0.0, 0.25), - ) - ) - break - self.skip_layers = skip_layers - self.skip_modules = nn.ModuleList(skip_modules) - - def get_feature_masks(self, x: torch.Tensor) -> List[float]: - # Note: The actual return type is Union[List[float], List[Tensor]], - # but to make torch.jit.script() work, we use List[float] - """ - In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of - randomized feature masks, one per encoder. - On e.g. 15% of frames, these masks will zero out all encoder dims larger than - some supplied number, e.g. >256, so in effect on those frames we are using - a smaller encoder dim. - - We generate the random masks at this level because we want the 2 masks to 'agree' - all the way up the encoder stack. This will mean that the 1st mask will have - mask values repeated self.zipformer_downsampling_factors times. - - Args: - x: the embeddings (needed for the shape and dtype and device), of shape - (num_frames, batch_size, encoder_dims0) - """ - num_encoders = len(self.encoder_dims) - if torch.jit.is_scripting() or not self.training: - return [1.0] * num_encoders - - (num_frames0, batch_size, _encoder_dims0) = x.shape - - assert self.encoder_dims[0] == _encoder_dims0, ( - self.encoder_dims, - _encoder_dims0, - ) - - max_downsampling_factor = max(self.zipformer_downsampling_factors) - - num_frames_max = num_frames0 + max_downsampling_factor - 1 - - feature_mask_dropout_prob = 0.15 - - # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = ( - torch.rand(num_frames_max, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype) - - feature_masks = [] - for i in range(num_encoders): - ds = self.zipformer_downsampling_factors[i] - upsample_factor = max_downsampling_factor // ds - - frame_mask = ( - frame_mask_max.unsqueeze(1) - .expand(num_frames_max, upsample_factor, batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1) - ) - num_frames = (num_frames0 + ds - 1) // ds - frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones( - num_frames, - batch_size, - self.encoder_dims[i], - dtype=x.dtype, - device=x.device, - ) - u = self.encoder_unmasked_dims[i] - feature_mask[:, :, u:] *= frame_mask - feature_masks.append(feature_mask) - - return feature_masks - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - chunk_size: - The chunk size used in evaluation mode. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - """ - x = self.encoder_embed(x) - - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - mask = make_pad_mask(lengths) - - outputs = [] - feature_masks = self.get_feature_masks(x) - - if self.training: - # Training mode - max_ds = max(self.zipformer_downsampling_factors) - # Generate dynamic chunk-wise attention mask during training - max_len = x.size(0) // max_ds - short_chunk_size = self.short_chunk_size // max_ds - chunk_size = torch.randint(1, max_len, (1,)).item() - if chunk_size > (max_len * self.short_chunk_threshold): - # Full attention - chunk_size = x.size(0) - else: - # Chunk-wise attention - chunk_size = chunk_size % short_chunk_size + 1 - chunk_size *= max_ds - else: - chunk_size = self.decode_chunk_size - # Evaluation mode - for ds in self.zipformer_downsampling_factors: - assert chunk_size % ds == 0, (chunk_size, ds) - - attn_mask = ~subsequent_chunk_mask( - size=x.size(0), - chunk_size=chunk_size, - num_left_chunks=self.num_left_chunks, - device=x.device, - ) - - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): - ds = self.zipformer_downsampling_factors[i] - k = self.skip_layers[i] - if isinstance(k, int): - layer_skip_dropout_prob = self._get_layer_skip_dropout_prob() - if torch.jit.is_scripting(): - x = skip_module(outputs[k], x) - elif (not self.training) or random.random() > layer_skip_dropout_prob: - x = skip_module(outputs[k], x) - x = module( - x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[..., ::ds], - attn_mask=attn_mask[::ds, ::ds], - ) - outputs.append(x) - - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - lengths = (lengths + 1) >> 1 - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return x, lengths - - def streaming_forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - states: List[Tensor], - ) -> Tuple[Tensor, Tensor, List[Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - seq_len is the input chunk length. - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - - Returns: - Return a tuple containing 3 tensors: - - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1]) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - - updated states. - """ - assert len(states) == 7 * self.num_encoders, (len(states), self.num_encoders) - - cached_len = states[: self.num_encoders] - cached_avg = states[self.num_encoders : 2 * self.num_encoders] - cached_key = states[2 * self.num_encoders : 3 * self.num_encoders] - cached_val = states[3 * self.num_encoders : 4 * self.num_encoders] - cached_val2 = states[4 * self.num_encoders : 5 * self.num_encoders] - cached_conv1 = states[5 * self.num_encoders : 6 * self.num_encoders] - cached_conv2 = states[6 * self.num_encoders : 7 * self.num_encoders] - - x = self.encoder_embed(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) - - outputs = [] - new_cached_len = [] - new_cached_avg = [] - new_cached_key = [] - new_cached_val = [] - new_cached_val2 = [] - new_cached_conv1 = [] - new_cached_conv2 = [] - - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): - k = self.skip_layers[i] - if isinstance(k, int): - x = skip_module(outputs[k], x) - x, len_avg, avg, key, val, val2, conv1, conv2 = module.streaming_forward( - x, - cached_len=cached_len[i], - cached_avg=cached_avg[i], - cached_key=cached_key[i], - cached_val=cached_val[i], - cached_val2=cached_val2[i], - cached_conv1=cached_conv1[i], - cached_conv2=cached_conv2[i], - ) - outputs.append(x) - # Update caches - new_cached_len.append(len_avg) - new_cached_avg.append(avg) - new_cached_key.append(key) - new_cached_val.append(val) - new_cached_val2.append(val2) - new_cached_conv1.append(conv1) - new_cached_conv2.append(conv2) - - x = self.downsample_output(x) - # class Downsample has this rounding behavior.. - assert self.output_downsampling_factor == 2, self.output_downsampling_factor - lengths = (lengths + 1) >> 1 - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - new_states = ( - new_cached_len - + new_cached_avg - + new_cached_key - + new_cached_val - + new_cached_val2 - + new_cached_conv1 - + new_cached_conv2 - ) - return x, lengths, new_states - - @torch.jit.export - def get_init_state( - self, - device: torch.device = torch.device("cpu"), - ) -> List[Tensor]: - """Get initial states. - A list of 7 * num_encoders elements: - ``states[0:num_encoders]`` is the cached numbers of past frames. - ``states[num_encoders:2*num_encoders]`` is the cached average tensors. - ``states[2*num_encoders:3*num_encoders]`` is the cached key tensors of the first attention modules. - ``states[3*num_encoders:4*num_encoders]`` is the cached value tensors of the first attention modules. - ``states[4*num_encoders:5*num_encoders]`` is the cached value tensors of the second attention modules. - ``states[5*num_encoders:6*num_encoders]`` is the cached left contexts of the first convolution modules. - ``states[6*num_encoders:7*num_encoders]`` is the cached left contexts of the second convolution modules. - """ - cached_len = [] - cached_avg = [] - cached_key = [] - cached_val = [] - cached_val2 = [] - cached_conv1 = [] - cached_conv2 = [] - - left_context_len = self.decode_chunk_size * self.num_left_chunks - - for i, encoder in enumerate(self.encoders): - num_layers = encoder.num_layers - ds = self.zipformer_downsampling_factors[i] - - len_avg = torch.zeros(num_layers, 1, dtype=torch.int32, device=device) - cached_len.append(len_avg) - - avg = torch.zeros(num_layers, 1, encoder.d_model, device=device) - cached_avg.append(avg) - - key = torch.zeros( - num_layers, - left_context_len // ds, - 1, - encoder.attention_dim, - device=device, - ) - cached_key.append(key) - - val = torch.zeros( - num_layers, - left_context_len // ds, - 1, - encoder.attention_dim // 2, - device=device, - ) - cached_val.append(val) - - val2 = torch.zeros( - num_layers, - left_context_len // ds, - 1, - encoder.attention_dim // 2, - device=device, - ) - cached_val2.append(val2) - - conv1 = torch.zeros( - num_layers, - 1, - encoder.d_model, - encoder.cnn_module_kernel - 1, - device=device, - ) - cached_conv1.append(conv1) - - conv2 = torch.zeros( - num_layers, - 1, - encoder.d_model, - encoder.cnn_module_kernel - 1, - device=device, - ) - cached_conv2.append(conv2) - - states = ( - cached_len - + cached_avg - + cached_key - + cached_val - + cached_val2 - + cached_conv1 - + cached_conv2 - ) - return states - - -class ZipformerEncoderLayer(nn.Module): - """ - ZipformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - feedforward_dim: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - - Examples:: - >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, - ) -> None: - super(ZipformerEncoderLayer, self).__init__() - - self.d_model = d_model - self.attention_dim = attention_dim - self.cnn_module_kernel = cnn_module_kernel - - # will be written to, see set_batch_count() - self.batch_count = 0 - - self.self_attn = RelPositionMultiheadAttention( - d_model, - attention_dim, - nhead, - pos_dim, - dropout=0.0, - ) - - self.pooling = PoolingModule(d_model) - - self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) - - self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - - self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) - - self.norm_final = BasicNorm(d_model) - - self.bypass_scale = nn.Parameter(torch.tensor(0.5)) - - # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = ActivationBalancer( - d_model, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, - max_abs=6.0, - ) - self.whiten = Whiten( - num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01 - ) - - def get_bypass_scale(self): - if torch.jit.is_scripting() or not self.training: - return self.bypass_scale - if random.random() < 0.1: - # ensure we get grads if self.bypass_scale becomes out of range - return self.bypass_scale - # hardcode warmup period for bypass scale - warmup_period = 20000.0 - initial_clamp_min = 0.75 - final_clamp_min = 0.25 - if self.batch_count > warmup_period: - clamp_min = final_clamp_min - else: - clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( - initial_clamp_min - final_clamp_min - ) - return self.bypass_scale.clamp(min=clamp_min, max=1.0) - - def get_dynamic_dropout_rate(self): - # return dropout rate for the dynamic modules (self_attn, pooling, convolution); this - # starts at 0.2 and rapidly decreases to 0. Its purpose is to keep the training stable - # at the beginning, by making the network focus on the feedforward modules. - if torch.jit.is_scripting() or not self.training: - return 0.0 - warmup_period = 2000.0 - initial_dropout_rate = 0.2 - final_dropout_rate = 0.0 - if self.batch_count > warmup_period: - return final_dropout_rate - else: - return initial_dropout_rate - ( - initial_dropout_rate * final_dropout_rate - ) * (self.batch_count / warmup_period) - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - batch_split: if not None, this layer will only be applied to - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - src_orig = src - - # macaron style feed forward module - src = src + self.feed_forward1(src) - - # dropout rate for submodules that interact with time. - dynamic_dropout = self.get_dynamic_dropout_rate() - - # pooling module - if torch.jit.is_scripting(): - src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) - elif random.random() >= dynamic_dropout: - src = src + self.pooling(src, src_key_padding_mask=src_key_padding_mask) - - if torch.jit.is_scripting(): - src_att, attn_weights = self.self_attn( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - src = src + src_att - - src = src + self.conv_module1( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward2(src) - - src = src + self.self_attn.forward2(src, attn_weights) - - src = src + self.conv_module2( - src, src_key_padding_mask=src_key_padding_mask - ) - else: - use_self_attn = random.random() >= dynamic_dropout - if use_self_attn: - src_att, attn_weights = self.self_attn( - src, - pos_emb=pos_emb, - attn_mask=attn_mask, - key_padding_mask=src_key_padding_mask, - ) - src = src + src_att - - if random.random() >= dynamic_dropout: - src = src + self.conv_module1( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward2(src) - - if use_self_attn: - src = src + self.self_attn.forward2(src, attn_weights) - - if random.random() >= dynamic_dropout: - src = src + self.conv_module2( - src, src_key_padding_mask=src_key_padding_mask - ) - - src = src + self.feed_forward3(src) - - src = self.norm_final(self.balancer(src)) - - delta = src - src_orig - - src = src_orig + delta * self.get_bypass_scale() - - return self.whiten(src) - - def streaming_forward( - self, - src: Tensor, - pos_emb: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - cached_len: processed number of past frames. - cached_avg: cached average of past frames. - cached_key: cached key tensor of left context for the first attention module. - cached_val: cached value tensor of left context for the first attention module. - cached_val2: cached value tensor of left context for the second attention module. - cached_conv1: cached left context for the first convolution module. - cached_conv2: cached left context for the second convolution module. - - Shape: - src: (S, N, E). - pos_emb: (N, left_context_len+2*S-1, E) - cached_len: (N,) - N is the batch size. - cached_avg: (N, C). - N is the batch size, C is the feature dimension. - cached_key: (left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - """ - src_orig = src - - # macaron style feed forward module - src = src + self.feed_forward1(src) - - src_pool, cached_len, cached_avg = self.pooling.streaming_forward( - src, - cached_len=cached_len, - cached_avg=cached_avg, - ) - src = src + src_pool - - ( - src_attn, - attn_weights, - cached_key, - cached_val, - ) = self.self_attn.streaming_forward( - src, - pos_emb=pos_emb, - cached_key=cached_key, - cached_val=cached_val, - ) - src = src + src_attn - - src_conv, cached_conv1 = self.conv_module1.streaming_forward( - src, - cache=cached_conv1, - ) - src = src + src_conv - - src = src + self.feed_forward2(src) - - src_attn, cached_val2 = self.self_attn.streaming_forward2( - src, - attn_weights, - cached_val=cached_val2, - ) - src = src + src_attn - - src_conv, cached_conv2 = self.conv_module2.streaming_forward( - src, - cache=cached_conv2, - ) - src = src + src_conv - - src = src + self.feed_forward3(src) - - src = self.norm_final(self.balancer(src)) - - delta = src - src_orig - - src = src_orig + delta * self.bypass_scale - - return ( - src, - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class ZipformerEncoder(nn.Module): - r"""ZipformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the ZipformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - - Examples:: - >>> encoder_layer = ZipformerEncoderLayer(d_model=512, nhead=8) - >>> zipformer_encoder = ZipformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> out = zipformer_encoder(src) - """ - - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float, - ) -> None: - super().__init__() - # will be written to, see set_batch_count() Note: in inference time this - # may be zero but should be treated as large, we can check if - # self.training is true. - self.batch_count = 0 - self.warmup_begin = warmup_begin - self.warmup_end = warmup_end - # module_seed is for when we need a random number that is unique to the module but - # shared across jobs. It's used to randomly select how many layers to drop, - # so that we can keep this consistent across worker tasks (for efficiency). - self.module_seed = torch.randint(0, 1000, ()).item() - - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) - - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - self.d_model = encoder_layer.d_model - self.attention_dim = encoder_layer.attention_dim - self.cnn_module_kernel = encoder_layer.cnn_module_kernel - - assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) - cur_begin = warmup_begin - for i in range(num_layers): - self.layers[i].warmup_begin = cur_begin - cur_begin += delta - self.layers[i].warmup_end = cur_begin - - def get_layers_to_drop(self, rnd_seed: int): - ans = set() - if not self.training: - return ans - - batch_count = self.batch_count - num_layers = len(self.layers) - - def get_layerdrop_prob(layer: int) -> float: - layer_warmup_begin = self.layers[layer].warmup_begin - layer_warmup_end = self.layers[layer].warmup_end - - initial_layerdrop_prob = 0.5 - final_layerdrop_prob = 0.05 - - if batch_count == 0: - # As a special case, if batch_count == 0, return 0 (drop no - # layers). This is rather ugly, I'm afraid; it is intended to - # enable our scan_pessimistic_batches_for_oom() code to work correctly - # so if we are going to get OOM it will happen early. - # also search for 'batch_count' with quotes in this file to see - # how we initialize the warmup count to a random number between - # 0 and 10. - return 0.0 - elif batch_count < layer_warmup_begin: - return initial_layerdrop_prob - elif batch_count > layer_warmup_end: - return final_layerdrop_prob - else: - # linearly interpolate - t = (batch_count - layer_warmup_begin) / layer_warmup_end - assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * ( - final_layerdrop_prob - initial_layerdrop_prob - ) - - shared_rng = random.Random(batch_count + self.module_seed) - independent_rng = random.Random(rnd_seed) - - layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] - tot = sum(layerdrop_probs) - # Instead of drawing the samples independently, we first randomly decide - # how many layers to drop out, using the same random number generator between - # jobs so that all jobs drop out the same number (this is for speed). - # Then we use an approximate approach to drop out the individual layers - # with their specified probs while reaching this exact target. - num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot))) - - layers = list(range(num_layers)) - independent_rng.shuffle(layers) - - # go through the shuffled layers until we get the required number of samples. - if num_to_drop > 0: - for layer in itertools.cycle(layers): - if independent_rng.random() < layerdrop_probs[layer]: - ans.add(layer) - if len(ans) == num_to_drop: - break - if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info( - f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " - f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}" - ) - return ans - - def forward( - self, - src: Tensor, - # Note: The type of feature_mask should be Union[float, Tensor], - # but to make torch.jit.script() work, we use `float` here - feature_mask: float = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer. - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - Returns: (x, x_no_combine), both of shape (S, N, E) - """ - pos_emb = self.encoder_pos(src) - output = src - - if torch.jit.is_scripting(): - layers_to_drop = [] - else: - rnd_seed = src.numel() + random.randint(0, 1000) - layers_to_drop = self.get_layers_to_drop(rnd_seed) - - output = output * feature_mask - - for i, mod in enumerate(self.layers): - if not torch.jit.is_scripting(): - if i in layers_to_drop: - continue - output = mod( - output, - pos_emb, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - - output = output * feature_mask - - return output - - @torch.jit.export - def streaming_forward( - self, - src: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - cached_len: number of past frames. - cached_avg: cached average of past frames. - cached_key: cached key tensor for first attention module. - cached_val: cached value tensor for first attention module. - cached_val2: cached value tensor for second attention module. - cached_conv1: cached left contexts for the first convolution module. - cached_conv2: cached left contexts for the second convolution module. - - Shape: - src: (S, N, E). - cached_len: (N,) - N is the batch size. - cached_avg: (num_layers, N, C). - N is the batch size, C is the feature dimension. - cached_key: (num_layers, left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - - Returns: A tuple of 8 tensors: - - output tensor - - updated cached number of past frmaes. - - updated cached average of past frmaes. - - updated cached key tensor of of the first attention module. - - updated cached value tensor of of the first attention module. - - updated cached value tensor of of the second attention module. - - updated cached left contexts of the first convolution module. - - updated cached left contexts of the second convolution module. - """ - assert cached_len.size(0) == self.num_layers, ( - cached_len.size(0), - self.num_layers, - ) - assert cached_avg.size(0) == self.num_layers, ( - cached_avg.size(0), - self.num_layers, - ) - assert cached_key.size(0) == self.num_layers, ( - cached_key.size(0), - self.num_layers, - ) - assert cached_val.size(0) == self.num_layers, ( - cached_val.size(0), - self.num_layers, - ) - assert cached_val2.size(0) == self.num_layers, ( - cached_val2.size(0), - self.num_layers, - ) - assert cached_conv1.size(0) == self.num_layers, ( - cached_conv1.size(0), - self.num_layers, - ) - assert cached_conv2.size(0) == self.num_layers, ( - cached_conv2.size(0), - self.num_layers, - ) - - left_context_len = cached_key.shape[1] - pos_emb = self.encoder_pos(src, left_context_len) - output = src - - new_cached_len = [] - new_cached_avg = [] - new_cached_key = [] - new_cached_val = [] - new_cached_val2 = [] - new_cached_conv1 = [] - new_cached_conv2 = [] - for i, mod in enumerate(self.layers): - output, len_avg, avg, key, val, val2, conv1, conv2 = mod.streaming_forward( - output, - pos_emb, - cached_len=cached_len[i], - cached_avg=cached_avg[i], - cached_key=cached_key[i], - cached_val=cached_val[i], - cached_val2=cached_val2[i], - cached_conv1=cached_conv1[i], - cached_conv2=cached_conv2[i], - ) - # Update caches - new_cached_len.append(len_avg) - new_cached_avg.append(avg) - new_cached_key.append(key) - new_cached_val.append(val) - new_cached_val2.append(val2) - new_cached_conv1.append(conv1) - new_cached_conv2.append(conv2) - - return ( - output, - torch.stack(new_cached_len, dim=0), - torch.stack(new_cached_avg, dim=0), - torch.stack(new_cached_key, dim=0), - torch.stack(new_cached_val, dim=0), - torch.stack(new_cached_val2, dim=0), - torch.stack(new_cached_conv1, dim=0), - torch.stack(new_cached_conv2, dim=0), - ) - - -class DownsampledZipformerEncoder(nn.Module): - r""" - DownsampledZipformerEncoder is a zipformer encoder evaluated at a reduced frame rate, - after convolutional downsampling, and then upsampled again at the output, and combined - with the origin input, so that the output has the same shape as the input. - """ - - def __init__( - self, encoder: nn.Module, input_dim: int, output_dim: int, downsample: int - ): - super(DownsampledZipformerEncoder, self).__init__() - self.downsample_factor = downsample - self.downsample = AttentionDownsample(input_dim, output_dim, downsample) - self.encoder = encoder - self.num_layers = encoder.num_layers - self.d_model = encoder.d_model - self.attention_dim = encoder.attention_dim - self.cnn_module_kernel = encoder.cnn_module_kernel - self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner( - input_dim, output_dim, min_weight=(0.0, 0.25) - ) - - def forward( - self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - attn_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required). - feature_mask: something that broadcasts with src, that we'll multiply `src` - by at every layer. feature_mask is expected to be already downsampled by - self.downsample_factor. - attn_mask: attention mask (optional). Should be downsampled already. - src_key_padding_mask: the mask for the src keys per batch (optional). Should be downsampled already. - - Shape: - src: (S, N, E). - attn_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - Returns: output of shape (S, N, F) where F is the number of output features - (output_dim to constructor) - """ - src_orig = src - src = self.downsample(src) - - src = self.encoder( - src, - feature_mask=feature_mask, - attn_mask=attn_mask, - src_key_padding_mask=src_key_padding_mask, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return self.out_combiner(src_orig, src) - - def streaming_forward( - self, - src: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - cached_key: Tensor, - cached_val: Tensor, - cached_val2: Tensor, - cached_conv1: Tensor, - cached_conv2: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: - r"""Downsample, go through encoder, upsample. - - Args: - src: the sequence to the encoder (required). - cached_avg: cached average value of past frames. - cached_len: length of past frames. - cached_key: cached key tensor for the first attention module. - cached_val: cached value tensor for the first attention module. - cached_val2: cached value tensor for the second attention module. - cached_conv1: cached left context for the first convolution module. - cached_conv2: cached left context for the second convolution module. - - Shape: - src: (S, N, E). - cached_len: (N,) - N is the batch size. - cached_avg: (num_layers, N, C). - N is the batch size, C is the feature dimension. - cached_key: (num_layers, left_context_len, N, K). - N is the batch size, K is the key dimension. - cached_val: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_val2: (num_layers, left_context_len, N, V). - N is the batch size, V is the key dimension. - cached_conv1: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - cached_conv2: (num_layers, N, C, kernel_size-1). - N is the batch size, C is the convolution channels. - Returns: output of shape (S, N, F) where F is the number of output features - (output_dim to constructor) - """ - src_orig = src - src = self.downsample(src) - - ( - src, - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) = self.encoder.streaming_forward( - src, - cached_len=cached_len, - cached_avg=cached_avg, - cached_key=cached_key, - cached_val=cached_val, - cached_val2=cached_val2, - cached_conv1=cached_conv1, - cached_conv2=cached_conv2, - ) - src = self.upsample(src) - # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] - - return ( - self.out_combiner(src_orig, src), - cached_len, - cached_avg, - cached_key, - cached_val, - cached_val2, - cached_conv1, - cached_conv2, - ) - - -class AttentionDownsample(torch.nn.Module): - """ - Does downsampling with attention, by weighted sum, and a projection.. - """ - - def __init__(self, in_channels: int, out_channels: int, downsample: int): - """ - Require out_channels > in_channels. - """ - super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) - - # fill in the extra dimensions with a projection of the input - if out_channels > in_channels: - self.extra_proj = nn.Linear( - in_channels * downsample, out_channels - in_channels, bias=False - ) - else: - self.extra_proj = None - self.downsample = downsample - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, 1, in_channels) - Returns a tensor of shape - ( (seq_len+downsample-1)//downsample, batch_size, out_channels) - """ - (seq_len, batch_size, in_channels) = src.shape - ds = self.downsample - d_seq_len = (seq_len + ds - 1) // ds - - # Pad to an exact multiple of self.downsample - if seq_len != d_seq_len * ds: - # right-pad src, repeating the last element. - pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) - src = torch.cat((src, src_extra), dim=0) - assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) - - src = src.reshape(d_seq_len, ds, batch_size, in_channels) - scores = (src * self.query).sum(dim=-1, keepdim=True) - - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) - - weights = scores.softmax(dim=1) - - # ans1 is the first `in_channels` channels of the output - ans = (src * weights).sum(dim=1) - src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels) - - if self.extra_proj is not None: - ans2 = self.extra_proj(src) - ans = torch.cat((ans, ans2), dim=2) - return ans - - -class SimpleUpsample(torch.nn.Module): - """ - A very simple form of upsampling that mostly just repeats the input, but - also adds a position-specific bias. - """ - - def __init__(self, num_channels: int, upsample: int): - super(SimpleUpsample, self).__init__() - self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - - def forward(self, src: Tensor) -> Tensor: - """ - x: (seq_len, batch_size, num_channels) - Returns a tensor of shape - ( (seq_len*upsample), batch_size, num_channels) - """ - upsample = self.bias.shape[0] - (seq_len, batch_size, num_channels) = src.shape - src = src.unsqueeze(1).expand(seq_len, upsample, batch_size, num_channels) - src = src + self.bias.unsqueeze(1) - src = src.reshape(seq_len * upsample, batch_size, num_channels) - return src - - -class SimpleCombinerIdentity(nn.Module): - def __init__(self, *args, **kwargs): - super().__init__() - - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: - return src1 - - -class SimpleCombiner(torch.nn.Module): - """ - A very simple way of combining 2 vectors of 2 different dims, via a - learned weighted combination in the shared part of the dim. - Args: - dim1: the dimension of the first input, e.g. 256 - dim2: the dimension of the second input, e.g. 384. - The output will have the same dimension as dim2. - """ - - def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): - super(SimpleCombiner, self).__init__() - assert dim2 >= dim1, (dim2, dim1) - self.weight1 = nn.Parameter(torch.zeros(())) - self.min_weight = min_weight - - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: - """ - src1: (*, dim1) - src2: (*, dim2) - - Returns: a tensor of shape (*, dim2) - """ - assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape) - - weight1 = self.weight1 - if not torch.jit.is_scripting(): - if ( - self.training - and random.random() < 0.25 - and self.min_weight != (0.0, 0.0) - ): - weight1 = weight1.clamp( - min=self.min_weight[0], max=1.0 - self.min_weight[1] - ) - - src1 = src1 * weight1 - src2 = src2 * (1.0 - weight1) - - src1_dim = src1.shape[-1] - src2_dim = src2.shape[-1] - if src1_dim != src2_dim: - if src1_dim < src2_dim: - src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim)) - else: - src1 = src1[:src2_dim] - - return src1 + src2 - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. - - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. - - """ - - def __init__( - self, - d_model: int, - dropout_rate: float, - max_len: int = 5000, - ) -> None: - """Construct a PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - self.d_model = d_model - self.dropout = torch.nn.Dropout(dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(max_len)) - - def extend_pe(self, x: Tensor, left_context_len: int = 0) -> None: - """Reset the positional encodings.""" - x_size_left = x.size(0) + left_context_len - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x_size_left * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vector and `j` means the - # position of key vector. We use positive relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tensor: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (time, batch, `*`). - left_context_len: (int): Length of cached left context. - - Returns: - torch.Tensor: Encoded tensor (batch, left_context_len + 2*time-1, `*`). - - """ - self.extend_pe(x, left_context_len) - x_size_left = x.size(0) + left_context_len - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x_size_left - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(0), - ] - return self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context", - we have to write up the differences. - - - Args: - embed_dim: total dimension of the model. - attention_dim: dimension in the attention module, may be less or more than embed_dim - but must be a multiple of num_heads. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - attention_dim: int, - num_heads: int, - pos_dim: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.attention_dim = attention_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = attention_dim // num_heads - self.pos_dim = pos_dim - assert self.head_dim % 2 == 0, self.head_dim - assert self.head_dim * num_heads == attention_dim, ( - self.head_dim, - num_heads, - attention_dim, - ) - - # the initial_scale is supposed to take over the "scaling" factor of - # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = ( - 2 * attention_dim # query, key - + attention_dim // 2 # value - + pos_dim * num_heads # positional encoding query - ) - - self.in_proj = ScaledLinear( - embed_dim, in_proj_dim, bias=True, initial_scale=self.head_dim**-0.25 - ) - - # self.whiten_values is applied on the values in forward(); - # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 - ) - - # the following are for diagnosics only, see --print-diagnostics option. - # they only copy their inputs. - self.copy_pos_query = Identity() - self.copy_query = Identity() - - self.out_proj = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) - - self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) - # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - - def forward( - self, - x: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x: input to be projected to query, key, value - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Returns: (attn_output, attn_weights) - - - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads - and S is the sequence length. - """ - x, weights = self.multi_head_attention_forward( - self.in_proj(x), - self.linear_pos(pos_emb), - self.attention_dim, - self.num_heads, - self.dropout, - self.out_proj.weight, - self.out_proj.bias, - training=self.training, - key_padding_mask=key_padding_mask, - attn_mask=attn_mask, - ) - return x, weights - - def streaming_forward( - self, - x: Tensor, - pos_emb: Tensor, - cached_key: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - r""" - Args: - x: input to be projected to query, key, value - pos_emb: Positional embedding tensor - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - - Inputs: - - x: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - cached_key: :math:`(left_context_len, N, K)`, where N is the batch size, K is the key dimension. - - cached_val: :math:`(left_context_len, N, V)`, where N is the batch size, V is the value dimension. - - - Returns: (attn_output, attn_weights, cached_key, cached_val) - - - attn_output: :math:`(S, N, E)` where S is the sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * N, S, S)` where N is the batch size, H is the num-heads - and S is the sequence length. - - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of - left context - - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of - """ - ( - x, - weights, - cached_key, - cached_val, - ) = self.streaming_multi_head_attention_forward( - self.in_proj(x), - self.linear_pos(pos_emb), - self.attention_dim, - self.num_heads, - self.out_proj.weight, - self.out_proj.bias, - cached_key=cached_key, - cached_val=cached_val, - ) - return x, weights, cached_key, cached_val - - def multi_head_attention_forward( - self, - x_proj: Tensor, - pos: Tensor, - attention_dim: int, - num_heads: int, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - attn_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - r""" - Args: - x_proj: the projected input, to be split into query, key, value. - pos: head-specific biases arising from the positional embeddings. - attention_dim: dimension inside attention mechanism - num_heads: parallel attention heads. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - - Shape: - Inputs: - - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is - the attention dimension. Will be split into (query, key, value, pos). - - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence - length, N is the batch size, and A is the attention dim. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * H, S, S)` where N is the batch size, - H is the num-heads, S is the sequence length. - """ - - seq_len, bsz, _ = x_proj.size() - - head_dim = attention_dim // num_heads - pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - - # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] - value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] - # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] - - k = self.whiten_keys(k) # does nothing in the forward pass. - v = self.whiten_values(v) # does nothing in the forward pass. - q = self.copy_query(q) # for diagnostics only, does nothing. - p = self.copy_pos_query(p) # for diagnostics only, does nothing. - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - seq_len, - seq_len, - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." - ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = q.reshape(seq_len, bsz, num_heads, head_dim) - p = p.reshape(seq_len, bsz, num_heads, pos_dim) - k = k.reshape(seq_len, bsz, num_heads, head_dim) - v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - if key_padding_mask is not None: - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == seq_len, "{} == {}".format( - key_padding_mask.size(1), seq_len - ) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) - p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - - seq_len2 = 2 * seq_len - 1 - pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) - # pos shape now: (batch, head, pos_dim, seq_len2) - - # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_weights = torch.matmul(p, pos) - # the following .as_strided() expression converts the last axis of pos_weights from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) - - # caution: they are really scores at this point. - attn_output_weights = torch.matmul(q, k) + pos_weights - - if not torch.jit.is_scripting(): - if training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be too large. - # It incurs a penalty if any of them has an absolute value greater than 50.0. - # this should be outside the normal range of the attention scores. We use - # this mechanism instead of, say, a limit on entropy, because once the entropy - # gets very small gradients through the softmax can become very small, and - # some mechanisms like that become ineffective. - attn_output_weights = penalize_abs_values_gt( - attn_output_weights, limit=25.0, penalty=1.0e-04 - ) - - # attn_output_weights: (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights = attn_output_weights.masked_fill( - attn_mask, float("-inf") - ) - else: - attn_output_weights = attn_output_weights + attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, seq_len, seq_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - # Using this version of softmax, defined in scaling.py, - # should save a little of the memory used in backprop by, if - # we are in automatic mixed precision mode (amp) == autocast, - # only storing the half-precision output for backprop purposes. - attn_output_weights = softmax(attn_output_weights, dim=-1) - - # If we are using chunk-wise attention mask and setting a limited - # num_left_chunks, the attention may only see the padding values which - # will also be masked out by `key_padding_mask`. At this circumstances, - # the whole column of `attn_output_weights` will be `-inf` - # (i.e. be `nan` after softmax). So we fill `0.0` at the masking - # positions to avoid invalid loss value below. - if ( - attn_mask is not None - and attn_mask.dtype == torch.bool - and key_padding_mask is not None - ): - if attn_mask.size(0) != 1: - attn_mask = attn_mask.view(bsz, num_heads, seq_len, seq_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) - - attn_output_weights = attn_output_weights.view( - bsz, num_heads, seq_len, seq_len - ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, seq_len, seq_len - ) - - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, attention_dim // 2) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - return attn_output, attn_output_weights - - def streaming_multi_head_attention_forward( - self, - x_proj: Tensor, - pos: Tensor, - attention_dim: int, - num_heads: int, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - cached_key: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: - r""" - Args: - x_proj: the projected input, to be split into query, key, value. - pos: head-specific biases arising from the positional embeddings. - attention_dim: dimension inside attention mechanism - num_heads: parallel attention heads. - out_proj_weight, out_proj_bias: the output projection weight and bias. - cached_key: cached attention key tensor of left context. - cached_val: cached attention value tensor of left context. - - Shape: - Inputs: - - x: :math:`(L, N, 7 * A // 2)` where L is the target sequence length, N is the batch size, A is - the attention dimension. Will be split into (query, key, value, pos). - - pos: :math:`(N, 2*L-1, A//2)` or :math:`(1, 2*L-1, A//2)` where L is the sequence - length, N is the batch size, and A is the attention dim. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_weights: :math:`(N * H, S, S)` where N is the batch size, - H is the num-heads, S is the sequence length. - - cached_key: :math:`(left_context_len, N, K)`, updated cached attention key tensor of left context. - - cached_val: :math:`(left_context_len, N, K)`, updated cached attention value tensor of left context. - """ - - seq_len, bsz, _ = x_proj.size() - - head_dim = attention_dim // num_heads - pos_dim = self.pos_dim # positional-encoding dim per head - assert ( - head_dim * num_heads == attention_dim - ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" - - # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] - value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] - # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] - - left_context_len = cached_key.shape[0] - assert left_context_len > 0, left_context_len - assert cached_key.shape[0] == cached_val.shape[0], ( - cached_key.shape, - cached_val.shape, - ) - # Pad cached left contexts - k = torch.cat([cached_key, k], dim=0) - v = torch.cat([cached_val, v], dim=0) - # Update cached left contexts - cached_key = k[-left_context_len:, ...] - cached_val = v[-left_context_len:, ...] - - # The length of key and value - kv_len = k.shape[0] - - q = q.reshape(seq_len, bsz, num_heads, head_dim) - p = p.reshape(seq_len, bsz, num_heads, pos_dim) - k = k.reshape(kv_len, bsz, num_heads, head_dim) - v = v.reshape(kv_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) - p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - - seq_len2 = 2 * seq_len - 1 + left_context_len - pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) - # pos shape now: (batch, head, pos_dim, seq_len2) - - # (batch, head, time1, pos_dim) x (1, head, pos_dim, seq_len2) -> (batch, head, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_weights = torch.matmul(p, pos) - # the following .as_strided() expression converts the last axis of pos_weights from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, kv_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) - - # caution: they are really scores at this point. - attn_output_weights = torch.matmul(q, k) + pos_weights - - # attn_output_weights: (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, seq_len, kv_len) - - # Using this version of softmax, defined in scaling.py, - # should save a little of the memory used in backprop by, if - # we are in automatic mixed precision mode (amp) == autocast, - # only storing the half-precision output for backprop purposes. - attn_output_weights = softmax(attn_output_weights, dim=-1) - - attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [bsz * num_heads, seq_len, head_dim // 2] - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, attention_dim // 2) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - return attn_output, attn_output_weights, cached_key, cached_val - - def forward2( - self, - x: Tensor, - attn_weights: Tensor, - ) -> Tensor: - """ - Second forward function, where we re-use the attn_weights returned by the first forward function - but with different input. - Args: - x: input, of shape (seq_len, batch_size, embed_dim) - attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) - Returns: - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) - """ - num_heads = self.num_heads - (seq_len, bsz, embed_dim) = x.shape - head_dim = self.attention_dim // num_heads - # v: (tgt_len, bsz, embed_dim // 2) - v = self.in_proj2(x) - v = self.whiten_values2(v) # does nothing in the forward pass. - v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) - - # now v: (bsz * num_heads, seq_len, head_dim // 2) - attn_output = torch.bmm(attn_weights, v) - - if not torch.jit.is_scripting(): - if random.random() < 0.001 or __name__ == "__main__": - self._print_attn_stats(attn_weights, attn_output) - - # attn_output: (bsz * num_heads, seq_len, head_dim) - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, self.attention_dim // 2) - ) - # returned value is of shape (seq_len, bsz, embed_dim), like x. - return self.out_proj2(attn_output) - - def streaming_forward2( - self, - x: Tensor, - attn_weights: Tensor, - cached_val: Tensor, - ) -> Tuple[Tensor, Tensor]: - """ - Second forward function, where we re-use the attn_weights returned by the first forward function - but with different input. - Args: - x: input, of shape (seq_len, batch_size, embed_dim) - attn_weights: attention weights returned by forward(), of shape (batch_size * num_heads, seq_len, seq_len) - cached_val: cached attention value tensor of left context. - Returns: - - output of the same shape as x, i.e. (seq_len, batch_size, embed_dim) - - updated cached attention value tensor of left context. - """ - num_heads = self.num_heads - (seq_len, bsz, embed_dim) = x.shape - head_dim = self.attention_dim // num_heads - # v: (tgt_len, bsz, embed_dim // 2) - v = self.in_proj2(x) - - left_context_len = cached_val.shape[0] - assert left_context_len > 0, left_context_len - v = torch.cat([cached_val, v], dim=0) - cached_val = v[-left_context_len:] - - seq_len2 = left_context_len + seq_len - v = v.reshape(seq_len2, bsz * num_heads, head_dim // 2).transpose(0, 1) - - # now v: (bsz * num_heads, seq_len, head_dim // 2) - attn_output = torch.bmm(attn_weights, v) - - # attn_output: (bsz * num_heads, seq_len, head_dim) - attn_output = ( - attn_output.transpose(0, 1) - .contiguous() - .view(seq_len, bsz, self.attention_dim // 2) - ) - # returned value is of shape (seq_len, bsz, embed_dim), like x. - return self.out_proj2(attn_output), cached_val - - def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): - # attn_weights: (batch_size * num_heads, seq_len, seq_len) - # attn_output: (bsz * num_heads, seq_len, head_dim) - (n, seq_len, head_dim) = attn_output.shape - num_heads = self.num_heads - bsz = n // num_heads - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=False): - attn_weights = attn_weights.to(torch.float32) - attn_output = attn_output.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .reshape(bsz, num_heads, seq_len) - .mean(dim=(0, 2)) - ) - attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape( - num_heads, bsz * seq_len, head_dim - ) - attn_output_mean = attn_output.mean(dim=1, keepdim=True) - attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( - bsz * seq_len - ) - # attn_covar: (num_heads, head_dim, head_dim) - # eigs, _ = torch.symeig(attn_covar) - # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") - - attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) - embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = ( - self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 - ).mean(dim=(1, 2)) - out_proj_covar = ( - self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 - ).mean(dim=(0, 2)) - logging.info( - f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}" - ) - - -class PoolingModule(nn.Module): - """ - Averages the input over the time dimension and project with a square matrix. - """ - - def __init__(self, d_model: int): - super().__init__() - self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """ - Args: - x: a Tensor of shape (T, N, C) - src_key_padding_mask: a Tensor of bool, of shape (N, T), with True in masked - positions. - - Returns: - - output, a Tensor of shape (T, N, C). - """ - if src_key_padding_mask is not None: - # False in padding positions - padding_mask = src_key_padding_mask.logical_not().to(x.dtype) # (N, T) - # Cumulated numbers of frames from start - cum_mask = padding_mask.cumsum(dim=1) # (N, T) - x = x.cumsum(dim=0) # (T, N, C) - pooling_mask = padding_mask / cum_mask - pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask # (T, N, C) - else: - num_frames = x.shape[0] - cum_mask = torch.arange(1, num_frames + 1).unsqueeze(1) # (T, 1) - x = x.cumsum(dim=0) # (T, N, C) - pooling_mask = (1.0 / cum_mask).unsqueeze(2) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask - - x = self.proj(x) - return x - - def streaming_forward( - self, - x: Tensor, - cached_len: Tensor, - cached_avg: Tensor, - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Args: - x: a Tensor of shape (T, N, C) - cached_len: a Tensor of int, of shape (N,), containing the number of - past frames in batch. - cached_avg: a Tensor of shape (N, C), the average over all past frames - in batch. - - Returns: - A tuple of 2 tensors: - - output, a Tensor of shape (T, N, C). - - updated cached_avg, a Tensor of shape (N, C). - """ - x = x.cumsum(dim=0) # (T, N, C) - x = x + (cached_avg * cached_len.unsqueeze(1)).unsqueeze(0) - # Cumulated numbers of frames from start - cum_mask = torch.arange(1, x.size(0) + 1, device=x.device) - cum_mask = cum_mask.unsqueeze(1) + cached_len.unsqueeze(0) # (T, N) - pooling_mask = (1.0 / cum_mask).unsqueeze(2) - # now pooling_mask: (T, N, 1) - x = x * pooling_mask # (T, N, C) - - cached_len = cached_len + x.size(0) - cached_avg = x[-1] - - x = self.proj(x) - return x, cached_len, cached_avg - - -class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model.""" - - def __init__(self, d_model: int, feedforward_dim: int, dropout: float): - super(FeedforwardModule, self).__init__() - self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer( - feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 - ) - self.activation = DoubleSwish() - self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) - - def forward(self, x: Tensor): - x = self.in_proj(x) - x = self.balancer(x) - x = self.activation(x) - x = self.dropout(x) - x = self.out_proj(x) - return x - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Zipformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - - """ - - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0, kernel_size - - self.pointwise_conv1 = nn.Conv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - - # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.deriv_balancer1 = ActivationBalancer( - 2 * channels, - channel_dim=1, - max_abs=10.0, - min_positive=0.05, - max_positive=1.0, - ) - - # Will pad cached left context - self.lorder = kernel_size - 1 - self.depthwise_conv = nn.Conv1d( - channels, - channels, - kernel_size, - stride=1, - padding=0, - groups=channels, - bias=bias, - ) - - self.deriv_balancer2 = ActivationBalancer( - channels, - channel_dim=1, - min_positive=0.05, - max_positive=1.0, - max_abs=20.0, - ) - - self.activation = DoubleSwish() - - self.pointwise_conv2 = ScaledConv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - initial_scale=0.05, - ) - - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch (optional): - (batch, #time), contains bool in masked positions. - - Returns: - - Output tensor (#time, batch, channels). - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - if src_key_padding_mask is not None: - x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - - # 1D Depthwise Conv - # Make depthwise_conv causal by - # manualy padding self.lorder zeros to the left - x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1) - - def streaming_forward( - self, - x: Tensor, - cache: Tensor, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - src_key_padding_mask: the mask for the src keys per batch: - (batch, #time), contains bool in masked positions. - cache: Cached left context for depthwise_conv, with shape of - (batch, channels, #kernel_size-1). Only used in real streaming decoding. - - Returns: - A tuple of 2 tensors: - - Output tensor (#time, batch, channels). - - New cached left context, with shape of (batch, channels, #kernel_size-1). - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - assert cache.shape == (x.size(0), x.size(1), self.lorder), ( - cache.shape, - (x.size(0), x.size(1), self.lorder), - ) - x = torch.cat([cache, x], dim=2) - # Update cache - cache = x[:, :, -self.lorder :] - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - return x.permute(2, 0, 1), cache - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = (T-3)//2 - 2 == (T-7)//2 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - dropout: float = 0.1, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, (T-7)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer2_channels: - Number of channels in layer2 - layer3_channels: - Number of channels in layer3 - """ - assert in_channels >= 7, in_channels - super().__init__() - - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=(0, 1), # (time, freq) - ), - ActivationBalancer(layer1_channels, channel_dim=1), - DoubleSwish(), - nn.Conv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - padding=0, - ), - ActivationBalancer(layer2_channels, channel_dim=1), - DoubleSwish(), - nn.Conv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=(1, 2), # (time, freq) - ), - ActivationBalancer(layer3_channels, channel_dim=1), - DoubleSwish(), - ) - out_height = (((in_channels - 1) // 2) - 1) // 2 - self.out = ScaledLinear(out_height * layer3_channels, out_channels) - self.dropout = nn.Dropout(dropout) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, (T-7)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - # Now x is of shape (N, odim, (T-7)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).reshape(b, t, c * f)) - # Now x is of shape (N, (T-7)//2, odim) - x = self.dropout(x) - return x - - -def _test_zipformer_main(): - feature_dim = 50 - batch_size = 5 - seq_len = 47 - feature_dim = 50 - # Just make sure the forward pass runs. - - c = Zipformer( - num_features=feature_dim, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), - decode_chunk_size=4, - ) - # Just make sure the forward pass runs. - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - assert ((seq_len - 7) // 2 + 1) // 2 == f[0].shape[1], (seq_len, f.shape[1]) - f[0].sum().backward() - c.eval() - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - ) - f # to remove flake8 warnings - - -def _test_conv2d_subsampling(): - num_features = 80 - encoder_dims = 384 - dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) - for i in range(20, 40): - x = torch.rand(2, i, num_features) - y = encoder_embed(x) - assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) - - -def _test_pooling_module(): - N, S, C = 2, 12, 32 - chunk_len = 4 - m = PoolingModule(d_model=C) - - # test chunk-wise forward with padding_mask - x = torch.randn(S, N, C) - y = m(x) - cached_len = torch.zeros(N, dtype=torch.int32) - cached_avg = torch.zeros(N, C) - for i in range(S // chunk_len): - start = i * chunk_len - end = start + chunk_len - x_chunk = x[start:end] - y_chunk, cached_len, cached_avg = m.streaming_forward( - x_chunk, - cached_len=cached_len, - cached_avg=cached_avg, - ) - assert torch.allclose(y_chunk, y[start:end]), (y_chunk, y[start:end]) - - -def _test_state_stack_unstack(): - m = Zipformer( - num_features=80, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), - zipformer_downsampling_factors=(4, 8), - num_left_chunks=2, - decode_chunk_size=8, - ) - s1 = m.get_init_state() - s2 = m.get_init_state() - states = stack_states([s1, s2]) - new_s1, new_s2 = unstack_states(states) - for i in range(m.num_encoders * 7): - for x, y in zip(s1[i], new_s1[i]): - assert torch.equal(x, y) - for x, y in zip(s2[i], new_s2[i]): - assert torch.equal(x, y) - - -if __name__ == "__main__": - logging.getLogger().setLevel(logging.INFO) - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - _test_zipformer_main() - _test_conv2d_subsampling() - _test_pooling_module() - _test_state_stack_unstack() diff --git a/egs/libricss/SURT/local/compute_fbank_libricss.py b/egs/libricss/SURT/local/compute_fbank_libricss.py index 5c6c853ee..afd66899c 100755 --- a/egs/libricss/SURT/local/compute_fbank_libricss.py +++ b/egs/libricss/SURT/local/compute_fbank_libricss.py @@ -25,6 +25,7 @@ The generated fbank features are saved in data/fbank. import logging from pathlib import Path +import pyloudnorm as pyln import torch import torch.multiprocessing from lhotse import LilcomChunkyWriter, load_manifest_lazy @@ -69,6 +70,11 @@ def compute_fbank_libricss(): dev_cuts = cuts.filter(lambda c: "session0" in c.id) test_cuts = cuts.filter(lambda c: "session0" not in c.id) + # If SDM cuts, apply loudness normalization + if name == "sdm": + dev_cuts = dev_cuts.normalize_loudness(target=-23.0) + test_cuts = test_cuts.normalize_loudness(target=-23.0) + logging.info(f"Extracting fbank features for {name} dev cuts") _ = dev_cuts.compute_and_store_features_batch( extractor=extractor, diff --git a/egs/libricss/SURT/local/compute_fbank_librimix.py b/egs/libricss/SURT/local/compute_fbank_librimix.py deleted file mode 100755 index aeed3c25b..000000000 --- a/egs/libricss/SURT/local/compute_fbank_librimix.py +++ /dev/null @@ -1,115 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Johns Hopkins University (authors: Desh Raj) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -""" -This file computes fbank features of the synthetically mixed LibriSpeech -train and dev sets. -It looks for manifests in the directory data/manifests. - -The generated fbank features are saved in data/fbank. -""" -import logging -from pathlib import Path - -import torch -import torch.multiprocessing -from lhotse import LilcomChunkyWriter -from lhotse.features.kaldifeat import ( - KaldifeatFbank, - KaldifeatFbankConfig, - KaldifeatFrameOptions, - KaldifeatMelOptions, -) -from lhotse.recipes.utils import read_manifests_if_cached - -# Torch's multithreaded behavior needs to be disabled or -# it wastes a lot of CPU and slow things down. -# Do this outside of main() in case it needs to take effect -# even when we are not invoking the main (e.g. when spawning subprocesses). -torch.set_num_threads(1) -torch.set_num_interop_threads(1) -torch.multiprocessing.set_sharing_strategy("file_system") - - -def compute_fbank_librimix(): - src_dir = Path("data/manifests") - output_dir = Path("data/fbank") - - sampling_rate = 16000 - num_mel_bins = 80 - - extractor = KaldifeatFbank( - KaldifeatFbankConfig( - frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), - mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), - device="cuda", - ) - ) - - logging.info("Reading manifests") - manifests = read_manifests_if_cached( - dataset_parts=["train_norvb_v1", "dev_norvb_v1"], - types=["cuts"], - output_dir=src_dir, - prefix="libri-mix", - suffix="jsonl.gz", - lazy=True, - ) - - train_cuts = manifests["train_norvb_v1"]["cuts"] - dev_cuts = manifests["dev_norvb_v1"]["cuts"] - # train_2spk_cuts = manifests["train_2spk_norvb"]["cuts"] - - logging.info("Extracting fbank features for training cuts") - _ = train_cuts.compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / "librimix_feats_train_norvb_v1", - manifest_path=src_dir / "cuts_train_norvb_v1.jsonl.gz", - batch_duration=5000, - num_workers=4, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - logging.info("Extracting fbank features for dev cuts") - _ = dev_cuts.compute_and_store_features_batch( - extractor=extractor, - storage_path=output_dir / "librimix_feats_dev_norvb_v1", - manifest_path=src_dir / "cuts_dev_norvb_v1.jsonl.gz", - batch_duration=5000, - num_workers=4, - storage_type=LilcomChunkyWriter, - overwrite=True, - ) - - # logging.info("Extracting fbank features for 2-spk train cuts") - # _ = train_2spk_cuts.compute_and_store_features_batch( - # extractor=extractor, - # storage_path=output_dir / "librimix_feats_train_2spk_norvb", - # manifest_path=src_dir / "cuts_train_2spk_norvb.jsonl.gz", - # batch_duration=5000, - # num_workers=4, - # storage_type=LilcomChunkyWriter, - # overwrite=True, - # ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - logging.basicConfig(format=formatter, level=logging.INFO) - compute_fbank_librimix() diff --git a/egs/libricss/SURT/local/compute_fbank_librispeech.py b/egs/libricss/SURT/local/compute_fbank_librispeech.py index 8dfe12e85..5c8aece9c 100755 --- a/egs/libricss/SURT/local/compute_fbank_librispeech.py +++ b/egs/libricss/SURT/local/compute_fbank_librispeech.py @@ -25,7 +25,6 @@ The generated fbank features are saved in data/fbank. import logging from pathlib import Path -from typing import Optional import torch from lhotse import CutSet, LilcomChunkyWriter @@ -43,17 +42,17 @@ from lhotse.recipes.utils import read_manifests_if_cached # even when we are not invoking the main (e.g. when spawning subprocesses). torch.set_num_threads(1) torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") -def compute_fbank_librispeech(bpe_model: Optional[str] = None): +def compute_fbank_librispeech(): src_dir = Path("data/manifests") output_dir = Path("data/fbank") num_mel_bins = 80 dataset_parts = ( - # "dev-clean", - # "train-clean-100", - # "train-clean-360", + "train-clean-100", + "train-clean-360", "train-other-500", ) prefix = "librispeech" @@ -92,8 +91,7 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): supervisions=m["supervisions"], ) - if "train" in partition: - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, diff --git a/egs/libricss/SURT/local/compute_fbank_lsmix.py b/egs/libricss/SURT/local/compute_fbank_lsmix.py new file mode 100755 index 000000000..da42f8ba1 --- /dev/null +++ b/egs/libricss/SURT/local/compute_fbank_lsmix.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +# Copyright 2022 Johns Hopkins University (authors: Desh Raj) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file computes fbank features of the synthetically mixed LibriSpeech +train and dev sets. +It looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. +""" +import logging +import random +import warnings +from pathlib import Path + +import torch +import torch.multiprocessing +from lhotse import LilcomChunkyWriter, load_manifest +from lhotse.cut import MixedCut, MixTrack, MultiCut +from lhotse.features.kaldifeat import ( + KaldifeatFbank, + KaldifeatFbankConfig, + KaldifeatFrameOptions, + KaldifeatMelOptions, +) +from lhotse.recipes.utils import read_manifests_if_cached +from lhotse.utils import fix_random_seed, uuid4 + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") + + +def compute_fbank_lsmix(): + src_dir = Path("data/manifests") + output_dir = Path("data/fbank") + + sampling_rate = 16000 + num_mel_bins = 80 + + extractor = KaldifeatFbank( + KaldifeatFbankConfig( + frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate), + mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins), + device="cuda", + ) + ) + + logging.info("Reading manifests") + manifests = read_manifests_if_cached( + dataset_parts=["train_clean_full", "train_clean_ov40"], + types=["cuts"], + output_dir=src_dir, + prefix="lsmix", + suffix="jsonl.gz", + lazy=True, + ) + + cs = {} + cs["clean_full"] = manifests["train_clean_full"]["cuts"] + cs["clean_ov40"] = manifests["train_clean_ov40"]["cuts"] + + # only uses RIRs and noises from REVERB challenge + real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter( + lambda r: "RVB2014" in r.id + ) + noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter( + lambda r: "RVB2014" in r.id + ) + + # Apply perturbation to the training cuts + logging.info("Applying perturbation to the training cuts") + cs["rvb_full"] = cs["clean_full"].map( + lambda c: augment( + c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True + ) + ) + cs["rvb_ov40"] = cs["clean_ov40"].map( + lambda c: augment( + c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True + ) + ) + + for type_affix in ["full", "ov40"]: + for rvb_affix in ["clean", "rvb"]: + logging.info( + f"Extracting fbank features for {type_affix} {rvb_affix} training cuts" + ) + cuts = cs[f"{rvb_affix}_{type_affix}"] + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + _ = cuts.compute_and_store_features_batch( + extractor=extractor, + storage_path=output_dir + / f"lsmix_feats_train_{rvb_affix}_{type_affix}", + manifest_path=src_dir + / f"cuts_train_{rvb_affix}_{type_affix}.jsonl.gz", + batch_duration=5000, + num_workers=4, + storage_type=LilcomChunkyWriter, + overwrite=True, + ) + + +def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False): + """ + Given a mixed cut, this function optionally applies the following augmentations: + - Perturbing the SNRs of the tracks (in range [-5, 5] dB) + - Reverberation using a randomly selected RIR + - Adding noise + - Perturbing the loudness (in range [-20, -25] dB) + """ + out_cut = cut.drop_features() + + # Perturb the SNRs (optional) + if perturb_snr: + snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))] + for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)): + if i == 0: + # Skip the first track since it is the reference + continue + track.snr = snr + + # Reverberate the cut (optional) + if rirs is not None: + # Select an RIR at random + rir = random.choice(rirs) + # Select a channel at random + rir_channel = random.choice(list(range(rir.num_channels))) + # Reverberate the cut + out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel]) + + # Add noise (optional) + if noises is not None: + # Select a noise recording at random + noise = random.choice(noises).to_cut() + if isinstance(noise, MultiCut): + noise = noise.to_mono()[0] + # Select an SNR at random + snr = random.uniform(10, 30) + # Repeat the noise to match the duration of the cut + noise = repeat_cut(noise, out_cut.duration) + out_cut = MixedCut( + id=out_cut.id, + tracks=[ + MixTrack(cut=out_cut, type="MixedCut"), + MixTrack(cut=noise, type="DataCut", snr=snr), + ], + ) + + # Perturb the loudness (optional) + if perturb_loudness: + target_loudness = random.uniform(-20, -25) + out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True) + return out_cut + + +def repeat_cut(cut, duration): + while cut.duration < duration: + cut = cut.mix(cut, offset_other_by=cut.duration) + return cut.truncate(duration=duration) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + fix_random_seed(42) + compute_fbank_lsmix() diff --git a/egs/libricss/SURT/prepare.sh b/egs/libricss/SURT/prepare.sh index 192ccd6b9..b1d92f41b 100755 --- a/egs/libricss/SURT/prepare.sh +++ b/egs/libricss/SURT/prepare.sh @@ -4,7 +4,6 @@ set -eou pipefail stage=-1 stop_stage=100 -use_gss=true # Use GSS-based enhancement with MDM setting # We assume dl_dir (download dir) contains the following # directories and files. If not, they will be downloaded @@ -24,8 +23,10 @@ use_gss=true # Use GSS-based enhancement with MDM setting # - noise # - speech # +# - $dl_dir/rirs_noises +# This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/. +# dl_dir=$PWD/download -cmd="queue-freegpu.pl --config conf/gpu.conf --gpu 1 --mem 4G" . shared/parse_options.sh || exit 1 @@ -71,6 +72,15 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then if [ ! -d $dl_dir/musan ]; then lhotse download musan $dl_dir fi + + # If you have pre-downloaded it to /path/to/rirs_noises, + # you can create a symlink + # + # ln -sfv /path/to/rirs_noises $dl_dir/ + # + if [ ! -d $dl_dir/rirs_noises ]; then + lhotse download rirs_noises $dl_dir + fi fi if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then @@ -94,123 +104,101 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then fi if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then - log "Stage 3: Prepare musan manifest" + log "Stage 3: Prepare musan manifest and RIRs" # We assume that you have downloaded the musan corpus # to $dl_dir/musan mkdir -p data/manifests lhotse prepare musan $dl_dir/musan data/manifests + + # We assume that you have downloaded the RIRS_NOISES corpus + # to $dl_dir/rirs_noises + lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests fi if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Extract features for LibriSpeech, trim to alignments, and shuffle the cuts" - $cmd exp/extract_libri_fbank.log python local/compute_fbank_librispeech.py + python local/compute_fbank_librispeech.py lhotse combine data/manifests/librispeech_cuts_train* - |\ lhotse cut trim-to-alignments --type word --max-pause 0.2 - - |\ shuf | gzip -c > data/manifests/librispeech_cuts_train_trimmed.jsonl.gz - lhotse cut trim-to-alignments --type word --max-pause 0.2 data/manifests/librispeech_cuts_dev-clean.jsonl.gz - |\ - shuf | gzip -c > data/manifests/librispeech_cuts_dev_trimmed.jsonl.gz fi if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then log "Stage 5: Create simulated mixtures from LibriSpeech (train and dev). This may take a while." - # We create a 2-speaker set which will be used during the model warmup phase, and a - # full training set (2,3,4 speakers) that will be used for the subsequent training. - # We create anechoic and reverberant versions of both sets. For the full set, we compute - # silence and overlap distributions based on LibriCSS sessions (no 0L). - - sim_cmd="queue.pl --mem 16G -l 'num_proc=4,h_rt=600:00:00'" + # We create a high overlap set which will be used during the model warmup phase, and a + # full training set that will be used for the subsequent training. gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\ grep -v "0L" | grep -v "OV10" |\ gzip -c > data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz - # 2-speaker anechoic - # log "Generating 2-speaker anechoic training set" - # $sim_cmd exp/sim_train_2spk.log lhotse workflows simulate-meetings \ - # --method conversational \ - # --prob-diff-spk-overlap 1.0 \ - # --num-meetings 50000 \ - # --num-speakers-per-meeting 2 \ - # --max-duration-per-speaker 20.0 \ - # --max-utterances-per-speaker 1 \ - # --seed 1234 \ - # --num-jobs 4 \ - # data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \ - # data/manifests/libri-mix_cuts_train_2spk_norvb.jsonl.gz + gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\ + grep "OV40" |\ + gzip -c > data/manifests/libricss-sdm_supervisions_ov40.jsonl.gz - # 2-speaker reverberant - # log "Generating 2-speaker reverberant training set" - # lhotse workflows simulate-meetings \ - # --method conversational \ - # --prob-diff-spk-overlap 1.0 \ - # --num-meetings 50000 \ - # --num-speakers-per-meeting 2 \ - # --max-duration-per-speaker 20.0 \ - # --max-utterances-per-speaker 1 \ - # --seed 1234 \ - # --reverberate \ - # --num-jobs 4 \ - # data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \ - # data/manifests/libri-mix_cuts_train_2spk_rvb.jsonl.gz + # Warmup mixtures (100k) based on high overlap (OV40) + log "Generating 100k anechoic train mixtures for warmup" + lhotse workflows simulate-meetings \ + --method conversational \ + --fit-to-supervisions data/manifests/libricss-sdm_supervisions_ov40.jsonl.gz \ + --num-meetings 100000 \ + --num-speakers-per-meeting 2,3 \ + --max-duration-per-speaker 15.0 \ + --max-utterances-per-speaker 3 \ + --seed 1234 \ + --num-jobs 4 \ + data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \ + data/manifests/lsmix_cuts_train_clean_ov40.jsonl.gz # Full training set (2,3 speakers) anechoic - for part in dev train; do - if [ $part == "dev" ]; then - num_jobs=1 - else - num_jobs=4 - fi - log "Generating anechoic ${part} set (full)" - $sim_cmd exp/sim_${part}.log lhotse workflows simulate-meetings \ - --method conversational \ - --fit-to-supervisions data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz \ - --num-repeats 1 \ - --num-speakers-per-meeting 2,3 \ - --max-duration-per-speaker 15.0 \ - --max-utterances-per-speaker 3 \ - --seed 1234 \ - --num-jobs ${num_jobs} \ - data/manifests/librispeech_cuts_${part}_trimmed.jsonl.gz \ - data/manifests/libri-mix_cuts_${part}_norvb_v1.jsonl.gz - done - - # Full training set (2,3,4 speakers) reverberant - # for part in dev train; do - # log "Generating reverberant ${part} set (full)" `` - # lhotse workflows simulate-meetings \ - # --method conversational \ - # --num-repeats 1 \ - # --num-speakers-per-meeting 2,3,4 \ - # --max-duration-per-speaker 20.0 \ - # --max-utterances-per-speaker 5 \ - # --seed 1234 \ - # --reverberate \ - # data/manifests/librispeech_cuts_${part}_trimmed.jsonl.gz \ - # data/manifests/libri-mix_cuts_${part}_rvb.jsonl.gz - # done + log "Generating anechoic ${part} set (full)" + lhotse workflows simulate-meetings \ + --method conversational \ + --fit-to-supervisions data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz \ + --num-repeats 1 \ + --num-speakers-per-meeting 2,3 \ + --max-duration-per-speaker 15.0 \ + --max-utterances-per-speaker 3 \ + --seed 1234 \ + --num-jobs 4 \ + data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \ + data/manifests/lsmix_cuts_train_clean_full.jsonl.gz fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "Stage 6: Compute fbank features for musan" mkdir -p data/fbank - $cmd exp/feats_musan.log python local/compute_fbank_musan.py + python local/compute_fbank_musan.py fi if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then log "Stage 7: Compute fbank features for simulated Libri-mix" mkdir -p data/fbank - $cmd exp/feats_librimix_norvb_v1.log python local/compute_fbank_librimix.py + python local/compute_fbank_lsmix.py fi if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then - log "Stage 8: Compute fbank features for LibriCSS" - mkdir -p data/fbank - $cmd exp/feats_libricss.log python local/compute_fbank_libricss.py + log "Stage 8: Add source feats to mixtures (useful for auxiliary tasks)" + python local/add_source_feats.py + + log "Combining lsmix-clean and lsmix-rvb" + for type in full ov40; do + cat <(gunzip -c data/manifests/cuts_train_clean_${type}_sources.jsonl.gz) \ + <(gunzip -c data/manifests/cuts_train_rvb_${type}_sources.jsonl.gz) |\ + shuf | gzip -c > data/manifests/cuts_train_${type}_sources.jsonl.gz + done fi if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then - log "Stage 9: Download LibriSpeech BPE model from HuggingFace." - mkdir -p data/lang_bpe_500 && pushd data/lang_bpe_500 + log "Stage 9: Compute fbank features for LibriCSS" + mkdir -p data/fbank + python local/compute_fbank_libricss.py +fi + +if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then + log "Stage 10: Download LibriSpeech BPE model from HuggingFace." + mkdir -p data/lang_bpe_500 + pushd data/lang_bpe_500 wget https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/resolve/main/data/lang_bpe_500/bpe.model popd fi