diff --git a/egs/libricss/SURT/README.md b/egs/libricss/SURT/README.md
new file mode 100644
index 000000000..10a1aaad1
--- /dev/null
+++ b/egs/libricss/SURT/README.md
@@ -0,0 +1,249 @@
+# Introduction
+
+This is a multi-talker ASR recipe for the LibriCSS dataset. We train a Streaming
+Unmixing and Recognition Transducer (SURT) model for the task. In this README,
+we will describe the task, the model, and the training process. We will also
+provide links to pre-trained models and training logs.
+
+## Task
+
+LibriCSS is a multi-talker meeting corpus created by mixing LibriSpeech utterances
+and replaying the mixtures in a real meeting room. It consists of ten 1-hour sessions,
+each recorded with a 7-channel microphone array at a sampling rate of 16 kHz.
+For more information, refer to the paper:
+Z. Chen et al., "Continuous speech separation: dataset and analysis,"
+ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP),
+Barcelona, Spain, 2020.
+
+In this recipe, we perform the "continuous, streaming, multi-talker ASR" task on LibriCSS.
+
+* By "continuous", we mean that the model should be able to transcribe unsegmented audio
+without the need of an external VAD.
+* By "streaming", we mean that the model has limited right context. We use a right-context
+of at most 32 frames (320 ms).
+* By "multi-talker", we mean that the model should be able to transcribe overlapping speech
+from multiple speakers.
+
+For now, we do not perform speaker attribution, i.e., the transcription is speaker
+agnostic. The evaluation metric depends on the particular model type. In this case, we use
+the optimal reference combination WER (ORC-WER) metric, as implemented in the
+[meeteval](https://github.com/fgnt/meeteval) toolkit. A small usage sketch is shown below.
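+
+As a minimal, illustrative sketch (not part of the recipe), ORC-WER for a single recording
+can be computed with meeteval roughly as follows; the transcripts are made up, and the exact
+function name/signature may differ across meeteval versions:
+
+```python
+from meeteval.wer import orc_word_error_rate
+
+# One reference string per utterance (speaker-agnostic, in any order) and one
+# hypothesis string per output channel of the model.
+ref = ["hello how are you", "i am fine thank you"]
+hyp = ["hello how are you i am fine thank you", ""]
+
+result = orc_word_error_rate(ref, hyp)
+print(result.error_rate)  # 0.0 here, since the optimal assignment matches exactly
+```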
+
+## Model
+
+We use the Streaming Unmixing and Recognition Transducer (SURT) model for this task.
+The model is based on the papers:
+
+- Lu, Liang et al. “Streaming End-to-End Multi-Talker Speech Recognition.” IEEE Signal Processing Letters 28 (2020): 803-807.
+- Raj, Desh et al. “Continuous Streaming Multi-Talker ASR with Dual-Path Transducers.” ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (2021): 7317-7321.
+
+The model is a combination of a speech separation model and a speech recognition model,
+but trained end-to-end with a single loss function. The overall architecture is shown
+in the figure below. Note that this architecture is slightly different from the one
+in the above papers. A detailed description of the model can be found in the following
+paper: [SURT 2.0: Advances in Transducer-based Multi-talker Speech Recognition](https://arxiv.org/abs/2306.10559).
+
+
+*Figure: Streaming Unmixing and Recognition Transducer.*
+
+In the [dprnn_zipformer](./dprnn_zipformer) recipe, for example, we use a DPRNN-based masking network
+and a Zipformer-based recognition network, but other combinations are possible as well.
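+
+To make the unmixing-plus-recognition structure concrete, below is a toy, self-contained
+sketch of how the masking network output is split into one mask per output channel and how
+the masked copies are stacked along the batch axis before the shared recognition encoder
+(this mirrors `decode.py` in this recipe; the `Linear` layer is only a stand-in for the DPRNN
+masking network, and the shapes are illustrative):
+
+```python
+import torch
+
+B, T, F, num_channels = 2, 100, 80, 2
+feature = torch.randn(B, T, F)                        # fbank features
+mask_encoder = torch.nn.Linear(F, F * num_channels)   # stand-in for the DPRNN masking network
+
+processed = mask_encoder(feature)                     # (B, T, F * num_channels)
+masks = processed.view(B, T, F, num_channels).unbind(dim=-1)
+branches = [feature * m for m in masks]               # one masked copy per output channel
+h = torch.cat(branches, dim=0)                        # (num_channels * B, T, F)
+# h is then passed through the shared streaming encoder (e.g., Zipformer) and the
+# transducer decoder/joiner, so each output channel is decoded independently.
+print(h.shape)                                        # torch.Size([4, 100, 80])
+```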
+
+## Training objective
+
+We train the model using the pruned transducer loss, similar to other ASR recipes in
+icefall. However, an important consideration is how to assign references to the output
+channels (2 in this case). For this, we use the heuristic error assignment training (HEAT)
+strategy, which assigns each reference to the first available output channel in order of
+start time. An illustrative example is shown in the figure below, followed by a small code sketch:
+
+
+*Figure: Illustration of HEAT-based reference assignment.*
+
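+The sketch below illustrates HEAT-style reference assignment (it is not the actual training
+code of this recipe, only an illustration of the heuristic): references are sorted by start
+time, and each one is routed to the first output channel that is free when it starts.
+
+```python
+def heat_assign(utterances, num_channels=2):
+    """Assign utterances (dicts with "start"/"end" times in seconds) to output
+    channels with the HEAT heuristic. If no channel is free, the last one is used."""
+    channel_end_times = [0.0] * num_channels
+    assignment = []
+    for utt in sorted(utterances, key=lambda u: u["start"]):
+        # Pick the first channel that has finished before this utterance starts.
+        for c in range(num_channels):
+            if channel_end_times[c] <= utt["start"]:
+                break
+        assignment.append(c)
+        channel_end_times[c] = utt["end"]
+    return assignment
+
+
+utts = [
+    {"start": 0.0, "end": 3.0},  # -> channel 0
+    {"start": 1.5, "end": 4.0},  # overlaps with the first -> channel 1
+    {"start": 3.5, "end": 5.0},  # channel 0 is free again -> channel 0
+]
+print(heat_assign(utts))  # [0, 1, 0]
+```
+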
+## Description of the recipe
+
+### Pre-requisites
+
+The recipes in this directory need the following packages to be installed:
+
+- [meeteval](https://github.com/fgnt/meeteval)
+- [einops](https://github.com/arogozhnikov/einops)
+
+Additionally, we initialize the "recognition" transducer with a model pre-trained
+on LibriSpeech. For this, please run the following from within `egs/librispeech/ASR`:
+
+```bash
+./prepare.sh
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+python pruned_transducer_stateless7_streaming/train.py \
+ --use-fp16 True \
+ --exp-dir pruned_transducer_stateless7_streaming/exp \
+ --world-size 4 \
+ --max-duration 800 \
+ --num-epochs 10 \
+ --keep-last-k 1 \
+ --manifest-dir data/manifests \
+ --enable-musan true \
+ --master-port 54321 \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --num-encoder-layers 2,2,2,2,2 \
+ --feedforward-dims 768,768,768,768,768 \
+ --nhead 8,8,8,8,8 \
+ --encoder-dims 256,256,256,256,256 \
+ --attention-dims 192,192,192,192,192 \
+ --encoder-unmasked-dims 192,192,192,192,192 \
+ --zipformer-downsampling-factors 1,2,4,8,2 \
+ --cnn-module-kernels 31,31,31,31,31 \
+ --decoder-dim 512 \
+ --joiner-dim 512
+```
+
+The above is for SURT-base (~26M). For SURT-large (~38M), use `--num-encoder-layers 2,4,3,2,4`.
+
+Once the above model is trained for 10 epochs, copy it to `egs/libricss/SURT/exp`:
+
+```bash
+cp -r pruned_transducer_stateless7_streaming/exp/epoch-10.pt exp/zipformer_base.pt
+```
+
+**NOTE:** We also provide this pre-trained checkpoint (see the section below), so you can skip
+the above step if you want.
+
+### Training
+
+To train the model, run the following from within `egs/libricss/SURT`:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+python dprnn_zipformer/train.py \
+ --use-fp16 True \
+ --exp-dir dprnn_zipformer/exp/surt_base \
+ --world-size 4 \
+ --max-duration 500 \
+ --max-duration-valid 250 \
+ --max-cuts 200 \
+ --num-buckets 50 \
+ --num-epochs 30 \
+ --enable-spec-aug True \
+ --enable-musan False \
+ --ctc-loss-scale 0.2 \
+ --heat-loss-scale 0.2 \
+ --base-lr 0.004 \
+ --model-init-ckpt exp/zipformer_base.pt \
+ --chunk-width-randomization True \
+ --num-mask-encoder-layers 4 \
+ --num-encoder-layers 2,2,2,2,2
+```
+
+The above is for SURT-base (~26M). For SURT-large (~38M), use:
+
+```bash
+ --num-mask-encoder-layers 6 \
+ --num-encoder-layers 2,4,3,2,4 \
+ --model-init-ckpt exp/zipformer_large.pt \
+```
+
+**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM.
+
+### Adaptation
+
+The training step above only trains on simulated mixtures. For best results, we also
+adapt the final model on the LibriCSS dev set. For this, run the following from within
+`egs/libricss/SURT`:
+
+```bash
+export CUDA_VISIBLE_DEVICES="0"
+
+python dprnn_zipformer/train_adapt.py \
+ --use-fp16 True \
+ --exp-dir dprnn_zipformer/exp/surt_base_adapt \
+ --world-size 1 \
+ --max-duration 500 \
+ --max-duration-valid 250 \
+ --max-cuts 200 \
+ --num-buckets 50 \
+ --num-epochs 8 \
+ --lr-epochs 2 \
+ --enable-spec-aug True \
+ --enable-musan False \
+ --ctc-loss-scale 0.2 \
+ --base-lr 0.0004 \
+ --model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \
+ --chunk-width-randomization True \
+ --num-mask-encoder-layers 4 \
+ --num-encoder-layers 2,2,2,2,2
+```
+
+For SURT-large, use the following config:
+
+```bash
+ --num-mask-encoder-layers 6 \
+ --num-encoder-layers 2,4,3,2,4 \
+ --model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \
+ --num-epochs 15 \
+ --lr-epochs 4 \
+```
+
+
+### Decoding
+
+To decode the model, run the following from within `egs/libricss/SURT`:
+
+#### Greedy search
+
+```bash
+export CUDA_VISIBLE_DEVICES="0"
+
+python dprnn_zipformer/decode.py \
+ --epoch 8 --avg 1 --use-averaged-model False \
+ --exp-dir dprnn_zipformer/exp/surt_base_adapt \
+ --max-duration 250 \
+ --decoding-method greedy_search
+```
+
+#### Beam search
+
+```bash
+python dprnn_zipformer/decode.py \
+ --epoch 8 --avg 1 --use-averaged-model False \
+ --exp-dir dprnn_zipformer/exp/surt_base_adapt \
+ --max-duration 250 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+```
+
+## Results (using beam search)
+
+### IHM-Mix
+
+| Model | # params (M) | 0L | 0S | OV10 | OV20 | OV30 | OV40 | Avg. |
+|-------|:------------:|:--:|:--:|:----:|:----:|:----:|:----:|:----:|
+| dprnn_zipformer (base) | 26.7 | 5.1 | 4.2 | 13.7 | 18.7 | 20.5 | 20.6 | 13.8 |
+| dprnn_zipformer (large) | 37.9 | 4.6 | 3.8 | 12.7 | 14.3 | 16.7 | 21.2 | 12.2 |
+
+### SDM
+
+| Model | # params (M) | 0L | 0S | OV10 | OV20 | OV30 | OV40 | Avg. |
+|-------|:------------:|:--:|:--:|:----:|:----:|:----:|:----:|:----:|
+| dprnn_zipformer (base) | 26.7 | 6.8 | 7.2 | 21.4 | 24.5 | 28.6 | 31.2 | 20.0 |
+| dprnn_zipformer (large) | 37.9 | 6.4 | 6.9 | 17.9 | 19.7 | 25.2 | 25.5 | 16.9 |
+
+## Pre-trained models and logs
+
+* Pre-trained models:
+
+* Training logs:
+ - surt_base:
+ - surt_base_adapt:
+ - surt_large:
+ - surt_large_adapt:
diff --git a/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
new file mode 100644
index 000000000..51df91598
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/asr_datamodule.py
@@ -0,0 +1,372 @@
+# Copyright 2021 Piotr Żelasko
+# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
+# Copyright 2023 Johns Hopkins University (Author: Desh Raj)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import inspect
+import logging
+from functools import lru_cache
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Optional
+
+import torch
+from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
+from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
+ CutMix,
+ DynamicBucketingSampler,
+ K2SurtDataset,
+ PrecomputedFeatures,
+ SimpleCutSampler,
+ SpecAugment,
+)
+from lhotse.dataset.input_strategies import OnTheFlyFeatures
+from lhotse.utils import fix_random_seed
+from torch.utils.data import DataLoader
+
+from icefall.utils import str2bool
+
+
+class _SeedWorkers:
+ def __init__(self, seed: int):
+ self.seed = seed
+
+ def __call__(self, worker_id: int):
+ fix_random_seed(self.seed + worker_id)
+
+
+class LibriCssAsrDataModule:
+ """
+ DataModule for k2 ASR experiments.
+ It assumes there is always one train and valid dataloader,
+ but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
+ and test-other).
+
+ It contains all the common data pipeline modules used in ASR
+ experiments, e.g.:
+ - dynamic batch size,
+ - bucketing samplers,
+ - augmentation,
+ - on-the-fly feature extraction
+
+ This class should be derived for specific corpora used in ASR tasks.
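+
+    A minimal usage sketch (the argument values below are illustrative, not
+    prescribed by this recipe):
+
+        parser = argparse.ArgumentParser()
+        LibriCssAsrDataModule.add_arguments(parser)
+        args = parser.parse_args(["--manifest-dir", "data/manifests"])
+        data_module = LibriCssAsrDataModule(args)
+        train_cuts = data_module.lsmix_cuts(rvb_affix="clean", type_affix="full")
+        train_dl = data_module.train_dataloaders(train_cuts)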
+ """
+
+ def __init__(self, args: argparse.Namespace):
+ self.args = args
+
+ @classmethod
+ def add_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(
+ title="ASR data related options",
+ description="These options are used for the preparation of "
+ "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
+ "effective batch sizes, sampling strategies, applied data "
+ "augmentations, etc.",
+ )
+ group.add_argument(
+ "--manifest-dir",
+ type=Path,
+ default=Path("data/manifests"),
+ help="Path to directory with train/valid/test cuts.",
+ )
+ group.add_argument(
+ "--max-duration",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--max-duration-valid",
+ type=int,
+ default=200.0,
+ help="Maximum pooled recordings duration (seconds) in a "
+ "single batch. You can reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--max-cuts",
+ type=int,
+ default=100,
+ help="Maximum number of cuts in a single batch. You can "
+ "reduce it if it causes CUDA OOM.",
+ )
+ group.add_argument(
+ "--bucketing-sampler",
+ type=str2bool,
+ default=True,
+ help="When enabled, the batches will come from buckets of "
+ "similar duration (saves padding frames).",
+ )
+ group.add_argument(
+ "--num-buckets",
+ type=int,
+ default=30,
+ help="The number of buckets for the DynamicBucketingSampler"
+ "(you might want to increase it for larger datasets).",
+ )
+ group.add_argument(
+ "--on-the-fly-feats",
+ type=str2bool,
+ default=False,
+ help=(
+ "When enabled, use on-the-fly cut mixing and feature "
+ "extraction. Will drop existing precomputed feature manifests "
+ "if available."
+ ),
+ )
+ group.add_argument(
+ "--shuffle",
+ type=str2bool,
+ default=True,
+ help="When enabled (=default), the examples will be "
+ "shuffled for each epoch.",
+ )
+ group.add_argument(
+ "--drop-last",
+ type=str2bool,
+ default=True,
+ help="Whether to drop last batch. Used by sampler.",
+ )
+ group.add_argument(
+ "--return-cuts",
+ type=str2bool,
+ default=True,
+ help="When enabled, each batch will have the "
+ "field: batch['supervisions']['cut'] with the cuts that "
+ "were used to construct it.",
+ )
+
+ group.add_argument(
+ "--num-workers",
+ type=int,
+ default=2,
+ help="The number of training dataloader workers that "
+ "collect the batches.",
+ )
+
+ group.add_argument(
+ "--enable-spec-aug",
+ type=str2bool,
+ default=True,
+ help="When enabled, use SpecAugment for training dataset.",
+ )
+
+ group.add_argument(
+ "--spec-aug-time-warp-factor",
+ type=int,
+ default=80,
+ help="Used only when --enable-spec-aug is True. "
+ "It specifies the factor for time warping in SpecAugment. "
+ "Larger values mean more warping. "
+ "A value less than 1 means to disable time warp.",
+ )
+
+ group.add_argument(
+ "--enable-musan",
+ type=str2bool,
+ default=True,
+ help="When enabled, select noise from MUSAN and mix it"
+ "with training dataset. ",
+ )
+
+ def train_dataloaders(
+ self,
+ cuts_train: CutSet,
+ sampler_state_dict: Optional[Dict[str, Any]] = None,
+ return_sources: bool = True,
+ strict: bool = True,
+ ) -> DataLoader:
+ """
+ Args:
+ cuts_train:
+ CutSet for training.
+ sampler_state_dict:
+ The state dict for the training sampler.
+ """
+ transforms = []
+ if self.args.enable_musan:
+ logging.info("Enable MUSAN")
+ logging.info("About to get Musan cuts")
+ cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
+ transforms.append(
+ CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
+ )
+ else:
+ logging.info("Disable MUSAN")
+
+ input_transforms = []
+ if self.args.enable_spec_aug:
+ logging.info("Enable SpecAugment")
+ logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
+ # Set the value of num_frame_masks according to Lhotse's version.
+ # In different Lhotse's versions, the default of num_frame_masks is
+ # different.
+ num_frame_masks = 10
+ num_frame_masks_parameter = inspect.signature(
+ SpecAugment.__init__
+ ).parameters["num_frame_masks"]
+ if num_frame_masks_parameter.default == 1:
+ num_frame_masks = 2
+ logging.info(f"Num frame mask: {num_frame_masks}")
+ input_transforms.append(
+ SpecAugment(
+ time_warp_factor=self.args.spec_aug_time_warp_factor,
+ num_frame_masks=num_frame_masks,
+ features_mask_size=27,
+ num_feature_masks=2,
+ frames_mask_size=100,
+ )
+ )
+ else:
+ logging.info("Disable SpecAugment")
+
+ logging.info("About to create train dataset")
+ train = K2SurtDataset(
+ input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+ if self.args.on_the_fly_feats
+ else PrecomputedFeatures(),
+ cut_transforms=transforms,
+ input_transforms=input_transforms,
+ return_cuts=self.args.return_cuts,
+ return_sources=return_sources,
+ strict=strict,
+ )
+
+ if self.args.bucketing_sampler:
+ logging.info("Using DynamicBucketingSampler.")
+ train_sampler = DynamicBucketingSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ quadratic_duration=30.0,
+ max_cuts=self.args.max_cuts,
+ shuffle=self.args.shuffle,
+ num_buckets=self.args.num_buckets,
+ drop_last=self.args.drop_last,
+ )
+ else:
+ logging.info("Using SingleCutSampler.")
+ train_sampler = SimpleCutSampler(
+ cuts_train,
+ max_duration=self.args.max_duration,
+ max_cuts=self.args.max_cuts,
+ shuffle=self.args.shuffle,
+ )
+ logging.info("About to create train dataloader")
+
+ if sampler_state_dict is not None:
+ logging.info("Loading sampler state dict")
+ train_sampler.load_state_dict(sampler_state_dict)
+
+ # 'seed' is derived from the current random state, which will have
+ # previously been set in the main process.
+ seed = torch.randint(0, 100000, ()).item()
+ worker_init_fn = _SeedWorkers(seed)
+
+ train_dl = DataLoader(
+ train,
+ sampler=train_sampler,
+ batch_size=None,
+ num_workers=self.args.num_workers,
+ persistent_workers=False,
+ worker_init_fn=worker_init_fn,
+ )
+
+ return train_dl
+
+ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+ transforms = []
+
+ logging.info("About to create dev dataset")
+ validate = K2SurtDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else PrecomputedFeatures(),
+ cut_transforms=transforms,
+ return_cuts=self.args.return_cuts,
+ return_sources=False,
+ strict=False,
+ )
+ valid_sampler = DynamicBucketingSampler(
+ cuts_valid,
+ max_duration=self.args.max_duration_valid,
+ max_cuts=self.args.max_cuts,
+ shuffle=False,
+ )
+ logging.info("About to create dev dataloader")
+ valid_dl = DataLoader(
+ validate,
+ sampler=valid_sampler,
+ batch_size=None,
+ num_workers=2,
+ persistent_workers=False,
+ )
+
+ return valid_dl
+
+ def test_dataloaders(self, cuts: CutSet) -> DataLoader:
+ logging.debug("About to create test dataset")
+ test = K2SurtDataset(
+            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
+            if self.args.on_the_fly_feats
+            else PrecomputedFeatures(),
+ return_cuts=self.args.return_cuts,
+ return_sources=False,
+ strict=False,
+ )
+ sampler = DynamicBucketingSampler(
+ cuts,
+ max_duration=self.args.max_duration_valid,
+ max_cuts=self.args.max_cuts,
+ shuffle=False,
+ )
+ logging.debug("About to create test dataloader")
+ test_dl = DataLoader(
+ test,
+ batch_size=None,
+ sampler=sampler,
+ num_workers=self.args.num_workers,
+ )
+ return test_dl
+
+ @lru_cache()
+ def lsmix_cuts(
+ self,
+ rvb_affix: str = "clean",
+ type_affix: str = "full",
+ sources: bool = True,
+ ) -> CutSet:
+ logging.info("About to get train cuts")
+ source_affix = "_sources" if sources else ""
+ cs = load_manifest_lazy(
+ self.args.manifest_dir
+ / f"cuts_train_{rvb_affix}_{type_affix}{source_affix}.jsonl.gz"
+ )
+ cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
+ return cs
+
+ @lru_cache()
+ def libricss_cuts(self, split="dev", type="sdm") -> CutSet:
+ logging.info(f"About to get LibriCSS {split} {type} cuts")
+ cs = load_manifest_lazy(
+ self.args.manifest_dir / f"cuts_{split}_libricss-{type}.jsonl.gz"
+ )
+ return cs
diff --git a/egs/libricss/SURT/dprnn_zipformer/beam_search.py b/egs/libricss/SURT/dprnn_zipformer/beam_search.py
new file mode 100644
index 000000000..c8e4643d0
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/beam_search.py
@@ -0,0 +1,730 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
+# Xiaoyu Yang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import warnings
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Tuple, Union
+
+import k2
+import torch
+from model import SURT
+
+from icefall import NgramLmStateCost
+from icefall.utils import DecodingResults
+
+
+def greedy_search(
+ model: SURT,
+ encoder_out: torch.Tensor,
+ max_sym_per_frame: int,
+ return_timestamps: bool = False,
+) -> Union[List[int], DecodingResults]:
+ """Greedy search for a single utterance.
+ Args:
+ model:
+ An instance of `SURT`.
+ encoder_out:
+ A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
+ max_sym_per_frame:
+ Maximum number of symbols per frame. If it is set to 0, the WER
+ would be 100%.
+ return_timestamps:
+ Whether to return timestamps.
+ Returns:
+ If return_timestamps is False, return the decoded result.
+ Else, return a DecodingResults object containing
+ decoded result and corresponding timestamps.
+ """
+    assert encoder_out.ndim == 3, encoder_out.shape
+
+ # support only batch_size == 1 for now
+ assert encoder_out.size(0) == 1, encoder_out.size(0)
+
+ blank_id = model.decoder.blank_id
+ context_size = model.decoder.context_size
+ unk_id = getattr(model, "unk_id", blank_id)
+
+ device = next(model.parameters()).device
+
+ decoder_input = torch.tensor(
+ [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
+ ).reshape(1, context_size)
+
+ decoder_out = model.decoder(decoder_input, need_pad=False)
+ decoder_out = model.joiner.decoder_proj(decoder_out)
+
+ encoder_out = model.joiner.encoder_proj(encoder_out)
+
+ T = encoder_out.size(1)
+ t = 0
+ hyp = [blank_id] * context_size
+
+ # timestamp[i] is the frame index after subsampling
+ # on which hyp[i] is decoded
+ timestamp = []
+
+ # Maximum symbols per utterance.
+ max_sym_per_utt = 1000
+
+ # symbols per frame
+ sym_per_frame = 0
+
+ # symbols per utterance decoded so far
+ sym_per_utt = 0
+
+ while t < T and sym_per_utt < max_sym_per_utt:
+ if sym_per_frame >= max_sym_per_frame:
+ sym_per_frame = 0
+ t += 1
+ continue
+
+ # fmt: off
+ current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
+ # fmt: on
+ logits = model.joiner(
+ current_encoder_out, decoder_out.unsqueeze(1), project_input=False
+ )
+ # logits is (1, 1, 1, vocab_size)
+
+ y = logits.argmax().item()
+ if y not in (blank_id, unk_id):
+ hyp.append(y)
+ timestamp.append(t)
+ decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
+ 1, context_size
+ )
+
+ decoder_out = model.decoder(decoder_input, need_pad=False)
+ decoder_out = model.joiner.decoder_proj(decoder_out)
+
+ sym_per_utt += 1
+ sym_per_frame += 1
+ else:
+ sym_per_frame = 0
+ t += 1
+ hyp = hyp[context_size:] # remove blanks
+
+ if not return_timestamps:
+ return hyp
+ else:
+ return DecodingResults(
+ hyps=[hyp],
+ timestamps=[timestamp],
+ )
+
+
+def greedy_search_batch(
+ model: SURT,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
+ """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
+ Args:
+ model:
+ The SURT model.
+ encoder_out:
+ Output from the encoder. Its shape is (N, T, C), where N >= 1.
+ encoder_out_lens:
+ A 1-D tensor of shape (N,), containing number of valid frames in
+ encoder_out before padding.
+ return_timestamps:
+ Whether to return timestamps.
+ Returns:
+ If return_timestamps is False, return the decoded result.
+ Else, return a DecodingResults object containing
+ decoded result and corresponding timestamps.
+ """
+ assert encoder_out.ndim == 3
+ assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
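+    # pack_padded_sequence sorts the utterances by length (enforce_sorted=False
+    # records the permutation); at frame index t, only the first batch_sizes[t]
+    # utterances are still active, and unsorted_indices is used at the end to
+    # restore the original batch order.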
+ packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+ input=encoder_out,
+ lengths=encoder_out_lens.cpu(),
+ batch_first=True,
+ enforce_sorted=False,
+ )
+
+ device = next(model.parameters()).device
+
+ blank_id = model.decoder.blank_id
+ unk_id = getattr(model, "unk_id", blank_id)
+ context_size = model.decoder.context_size
+
+ batch_size_list = packed_encoder_out.batch_sizes.tolist()
+ N = encoder_out.size(0)
+ assert torch.all(encoder_out_lens > 0), encoder_out_lens
+ assert N == batch_size_list[0], (N, batch_size_list)
+
+ hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
+
+ # timestamp[n][i] is the frame index after subsampling
+ # on which hyp[n][i] is decoded
+ timestamps = [[] for _ in range(N)]
+
+ decoder_input = torch.tensor(
+ hyps,
+ device=device,
+ dtype=torch.int64,
+ ) # (N, context_size)
+
+ decoder_out = model.decoder(decoder_input, need_pad=False)
+ decoder_out = model.joiner.decoder_proj(decoder_out)
+ # decoder_out: (N, 1, decoder_out_dim)
+
+ encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
+
+ offset = 0
+ for (t, batch_size) in enumerate(batch_size_list):
+ start = offset
+ end = offset + batch_size
+ current_encoder_out = encoder_out.data[start:end]
+ current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
+ # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
+ offset = end
+
+ decoder_out = decoder_out[:batch_size]
+
+ logits = model.joiner(
+ current_encoder_out, decoder_out.unsqueeze(1), project_input=False
+ )
+        # logits' shape is (batch_size, 1, 1, vocab_size)
+
+ logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
+ assert logits.ndim == 2, logits.shape
+ y = logits.argmax(dim=1).tolist()
+ emitted = False
+ for i, v in enumerate(y):
+ if v not in (blank_id, unk_id):
+ hyps[i].append(v)
+ timestamps[i].append(t)
+ emitted = True
+ if emitted:
+ # update decoder output
+ decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
+ decoder_input = torch.tensor(
+ decoder_input,
+ device=device,
+ dtype=torch.int64,
+ )
+ decoder_out = model.decoder(decoder_input, need_pad=False)
+ decoder_out = model.joiner.decoder_proj(decoder_out)
+
+ sorted_ans = [h[context_size:] for h in hyps]
+ ans = []
+ ans_timestamps = []
+ unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+ for i in range(N):
+ ans.append(sorted_ans[unsorted_indices[i]])
+ ans_timestamps.append(timestamps[unsorted_indices[i]])
+
+ if not return_timestamps:
+ return ans
+ else:
+ return DecodingResults(
+ hyps=ans,
+ timestamps=ans_timestamps,
+ )
+
+
+def modified_beam_search(
+ model: SURT,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ beam: int = 4,
+ temperature: float = 1.0,
+ return_timestamps: bool = False,
+) -> Union[List[List[int]], DecodingResults]:
+ """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
+
+ Args:
+ model:
+ The SURT model.
+ encoder_out:
+ Output from the encoder. Its shape is (N, T, C).
+ encoder_out_lens:
+ A 1-D tensor of shape (N,), containing number of valid frames in
+ encoder_out before padding.
+ beam:
+ Number of active paths during the beam search.
+ temperature:
+ Softmax temperature.
+ return_timestamps:
+ Whether to return timestamps.
+ Returns:
+ If return_timestamps is False, return the decoded result.
+ Else, return a DecodingResults object containing
+ decoded result and corresponding timestamps.
+ """
+ assert encoder_out.ndim == 3, encoder_out.shape
+ assert encoder_out.size(0) >= 1, encoder_out.size(0)
+
+ packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
+ input=encoder_out,
+ lengths=encoder_out_lens.cpu(),
+ batch_first=True,
+ enforce_sorted=False,
+ )
+
+ blank_id = model.decoder.blank_id
+ unk_id = getattr(model, "unk_id", blank_id)
+ context_size = model.decoder.context_size
+ device = next(model.parameters()).device
+
+ batch_size_list = packed_encoder_out.batch_sizes.tolist()
+ N = encoder_out.size(0)
+ assert torch.all(encoder_out_lens > 0), encoder_out_lens
+ assert N == batch_size_list[0], (N, batch_size_list)
+
+ B = [HypothesisList() for _ in range(N)]
+ for i in range(N):
+ B[i].add(
+ Hypothesis(
+ ys=[blank_id] * context_size,
+ log_prob=torch.zeros(1, dtype=torch.float32, device=device),
+ timestamp=[],
+ )
+ )
+
+ encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
+
+ offset = 0
+ finalized_B = []
+ for (t, batch_size) in enumerate(batch_size_list):
+ start = offset
+ end = offset + batch_size
+ current_encoder_out = encoder_out.data[start:end]
+ current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
+ # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
+ offset = end
+
+ finalized_B = B[batch_size:] + finalized_B
+ B = B[:batch_size]
+
+ hyps_shape = get_hyps_shape(B).to(device)
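+        # hyps_shape has axes [utt][num_hyps]; it is used below to expand the
+        # per-utterance encoder frames to one copy per active hypothesis
+        # (via index_select on row_ids).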
+
+ A = [list(b) for b in B]
+ B = [HypothesisList() for _ in range(batch_size)]
+
+ ys_log_probs = torch.cat(
+ [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
+ ) # (num_hyps, 1)
+
+ decoder_input = torch.tensor(
+ [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
+ device=device,
+ dtype=torch.int64,
+ ) # (num_hyps, context_size)
+
+ decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
+ decoder_out = model.joiner.decoder_proj(decoder_out)
+ # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
+
+ # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
+ # as index, so we use `to(torch.int64)` below.
+ current_encoder_out = torch.index_select(
+ current_encoder_out,
+ dim=0,
+ index=hyps_shape.row_ids(1).to(torch.int64),
+ ) # (num_hyps, 1, 1, encoder_out_dim)
+
+ logits = model.joiner(
+ current_encoder_out,
+ decoder_out,
+ project_input=False,
+ ) # (num_hyps, 1, 1, vocab_size)
+
+ logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
+
+ log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
+
+ log_probs.add_(ys_log_probs)
+
+ vocab_size = log_probs.size(-1)
+
+ log_probs = log_probs.reshape(-1)
+
+ row_splits = hyps_shape.row_splits(1) * vocab_size
+ log_probs_shape = k2.ragged.create_ragged_shape2(
+ row_splits=row_splits, cached_tot_size=log_probs.numel()
+ )
+ ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
+
+ for i in range(batch_size):
+ topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
+
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
+ topk_token_indexes = (topk_indexes % vocab_size).tolist()
+
+ for k in range(len(topk_hyp_indexes)):
+ hyp_idx = topk_hyp_indexes[k]
+ hyp = A[i][hyp_idx]
+
+ new_ys = hyp.ys[:]
+ new_token = topk_token_indexes[k]
+ new_timestamp = hyp.timestamp[:]
+ if new_token not in (blank_id, unk_id):
+ new_ys.append(new_token)
+ new_timestamp.append(t)
+
+ new_log_prob = topk_log_probs[k]
+ new_hyp = Hypothesis(
+ ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
+ )
+ B[i].add(new_hyp)
+
+ B = B + finalized_B
+ best_hyps = [b.get_most_probable(length_norm=True) for b in B]
+
+ sorted_ans = [h.ys[context_size:] for h in best_hyps]
+ sorted_timestamps = [h.timestamp for h in best_hyps]
+ ans = []
+ ans_timestamps = []
+ unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
+ for i in range(N):
+ ans.append(sorted_ans[unsorted_indices[i]])
+ ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
+
+ if not return_timestamps:
+ return ans
+ else:
+ return DecodingResults(
+ hyps=ans,
+ timestamps=ans_timestamps,
+ )
+
+
+def beam_search(
+ model: SURT,
+ encoder_out: torch.Tensor,
+ beam: int = 4,
+ temperature: float = 1.0,
+ return_timestamps: bool = False,
+) -> Union[List[int], DecodingResults]:
+ """
+ It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
+
+    espnet/nets/beam_search_transducer.py#L247 is used as a reference.
+
+ Args:
+ model:
+ An instance of `SURT`.
+ encoder_out:
+ A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
+ beam:
+ Beam size.
+ temperature:
+ Softmax temperature.
+ return_timestamps:
+ Whether to return timestamps.
+
+ Returns:
+ If return_timestamps is False, return the decoded result.
+ Else, return a DecodingResults object containing
+ decoded result and corresponding timestamps.
+ """
+ assert encoder_out.ndim == 3
+
+ # support only batch_size == 1 for now
+ assert encoder_out.size(0) == 1, encoder_out.size(0)
+ blank_id = model.decoder.blank_id
+ unk_id = getattr(model, "unk_id", blank_id)
+ context_size = model.decoder.context_size
+
+ device = next(model.parameters()).device
+
+ decoder_input = torch.tensor(
+ [blank_id] * context_size,
+ device=device,
+ dtype=torch.int64,
+ ).reshape(1, context_size)
+
+ decoder_out = model.decoder(decoder_input, need_pad=False)
+ decoder_out = model.joiner.decoder_proj(decoder_out)
+
+ encoder_out = model.joiner.encoder_proj(encoder_out)
+
+ T = encoder_out.size(1)
+ t = 0
+
+ B = HypothesisList()
+ B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
+
+ max_sym_per_utt = 20000
+
+ sym_per_utt = 0
+
+ decoder_cache: Dict[str, torch.Tensor] = {}
+
+ while t < T and sym_per_utt < max_sym_per_utt:
+ # fmt: off
+ current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
+ # fmt: on
+ A = B
+ B = HypothesisList()
+
+ joint_cache: Dict[str, torch.Tensor] = {}
+
+ # TODO(fangjun): Implement prefix search to update the `log_prob`
+ # of hypotheses in A
+
+ while True:
+ y_star = A.get_most_probable()
+ A.remove(y_star)
+
+ cached_key = y_star.key
+
+ if cached_key not in decoder_cache:
+ decoder_input = torch.tensor(
+ [y_star.ys[-context_size:]],
+ device=device,
+ dtype=torch.int64,
+ ).reshape(1, context_size)
+
+ decoder_out = model.decoder(decoder_input, need_pad=False)
+ decoder_out = model.joiner.decoder_proj(decoder_out)
+ decoder_cache[cached_key] = decoder_out
+ else:
+ decoder_out = decoder_cache[cached_key]
+
+ cached_key += f"-t-{t}"
+ if cached_key not in joint_cache:
+ logits = model.joiner(
+ current_encoder_out,
+ decoder_out.unsqueeze(1),
+ project_input=False,
+ )
+
+ # TODO(fangjun): Scale the blank posterior
+ log_prob = (logits / temperature).log_softmax(dim=-1)
+ # log_prob is (1, 1, 1, vocab_size)
+ log_prob = log_prob.squeeze()
+ # Now log_prob is (vocab_size,)
+ joint_cache[cached_key] = log_prob
+ else:
+ log_prob = joint_cache[cached_key]
+
+ # First, process the blank symbol
+ skip_log_prob = log_prob[blank_id]
+ new_y_star_log_prob = y_star.log_prob + skip_log_prob
+
+ # ys[:] returns a copy of ys
+ B.add(
+ Hypothesis(
+ ys=y_star.ys[:],
+ log_prob=new_y_star_log_prob,
+ timestamp=y_star.timestamp[:],
+ )
+ )
+
+ # Second, process other non-blank labels
+ values, indices = log_prob.topk(beam + 1)
+ for i, v in zip(indices.tolist(), values.tolist()):
+ if i in (blank_id, unk_id):
+ continue
+ new_ys = y_star.ys + [i]
+ new_log_prob = y_star.log_prob + v
+ new_timestamp = y_star.timestamp + [t]
+ A.add(
+ Hypothesis(
+ ys=new_ys,
+ log_prob=new_log_prob,
+ timestamp=new_timestamp,
+ )
+ )
+
+ # Check whether B contains more than "beam" elements more probable
+ # than the most probable in A
+ A_most_probable = A.get_most_probable()
+
+ kept_B = B.filter(A_most_probable.log_prob)
+
+ if len(kept_B) >= beam:
+ B = kept_B.topk(beam)
+ break
+
+ t += 1
+
+ best_hyp = B.get_most_probable(length_norm=True)
+ ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
+
+ if not return_timestamps:
+ return ys
+ else:
+ return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])
+
+
+@dataclass
+class Hypothesis:
+ # The predicted tokens so far.
+ # Newly predicted tokens are appended to `ys`.
+ ys: List[int]
+
+ # The log prob of ys.
+ # It contains only one entry.
+ log_prob: torch.Tensor
+
+ # timestamp[i] is the frame index after subsampling
+ # on which ys[i] is decoded
+ timestamp: List[int] = field(default_factory=list)
+
+ # the lm score for next token given the current ys
+ lm_score: Optional[torch.Tensor] = None
+
+ # the RNNLM states (h and c in LSTM)
+ state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+
+ # N-gram LM state
+ state_cost: Optional[NgramLmStateCost] = None
+
+ @property
+ def key(self) -> str:
+ """Return a string representation of self.ys"""
+ return "_".join(map(str, self.ys))
+
+
+class HypothesisList(object):
+ def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
+ """
+ Args:
+ data:
+ A dict of Hypotheses. Its key is its `value.key`.
+ """
+ if data is None:
+ self._data = {}
+ else:
+ self._data = data
+
+ @property
+ def data(self) -> Dict[str, Hypothesis]:
+ return self._data
+
+ def add(self, hyp: Hypothesis) -> None:
+ """Add a Hypothesis to `self`.
+
+ If `hyp` already exists in `self`, its probability is updated using
+ `log-sum-exp` with the existed one.
+
+ Args:
+ hyp:
+ The hypothesis to be added.
+ """
+ key = hyp.key
+ if key in self:
+ old_hyp = self._data[key] # shallow copy
+ torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
+ else:
+ self._data[key] = hyp
+
+ def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
+ """Get the most probable hypothesis, i.e., the one with
+ the largest `log_prob`.
+
+ Args:
+ length_norm:
+ If True, the `log_prob` of a hypothesis is normalized by the
+ number of tokens in it.
+ Returns:
+ Return the hypothesis that has the largest `log_prob`.
+ """
+ if length_norm:
+ return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
+ else:
+ return max(self._data.values(), key=lambda hyp: hyp.log_prob)
+
+ def remove(self, hyp: Hypothesis) -> None:
+ """Remove a given hypothesis.
+
+ Caution:
+ `self` is modified **in-place**.
+
+ Args:
+ hyp:
+ The hypothesis to be removed from `self`.
+ Note: It must be contained in `self`. Otherwise,
+ an exception is raised.
+ """
+ key = hyp.key
+ assert key in self, f"{key} does not exist"
+ del self._data[key]
+
+ def filter(self, threshold: torch.Tensor) -> "HypothesisList":
+ """Remove all Hypotheses whose log_prob is less than threshold.
+
+ Caution:
+ `self` is not modified. Instead, a new HypothesisList is returned.
+
+ Returns:
+ Return a new HypothesisList containing all hypotheses from `self`
+ with `log_prob` being greater than the given `threshold`.
+ """
+ ans = HypothesisList()
+ for _, hyp in self._data.items():
+ if hyp.log_prob > threshold:
+ ans.add(hyp) # shallow copy
+ return ans
+
+ def topk(self, k: int) -> "HypothesisList":
+ """Return the top-k hypothesis."""
+ hyps = list(self._data.items())
+
+ hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
+
+ ans = HypothesisList(dict(hyps))
+ return ans
+
+ def __contains__(self, key: str):
+ return key in self._data
+
+ def __iter__(self):
+ return iter(self._data.values())
+
+ def __len__(self) -> int:
+ return len(self._data)
+
+ def __str__(self) -> str:
+ s = []
+ for key in self:
+ s.append(key)
+ return ", ".join(s)
+
+
+def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
+ """Return a ragged shape with axes [utt][num_hyps].
+
+ Args:
+ hyps:
+ len(hyps) == batch_size. It contains the current hypothesis for
+ each utterance in the batch.
+ Returns:
+ Return a ragged shape with 2 axes [utt][num_hyps]. Note that
+ the shape is on CPU.
+ """
+ num_hyps = [len(h) for h in hyps]
+
+ # torch.cumsum() is inclusive sum, so we put a 0 at the beginning
+ # to get exclusive sum later.
+ num_hyps.insert(0, 0)
+
+ num_hyps = torch.tensor(num_hyps)
+ row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
+ ans = k2.ragged.create_ragged_shape2(
+ row_splits=row_splits, cached_tot_size=row_splits[-1].item()
+ )
+ return ans
diff --git a/egs/libricss/SURT/dprnn_zipformer/decode.py b/egs/libricss/SURT/dprnn_zipformer/decode.py
new file mode 100755
index 000000000..6abbffe00
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/decode.py
@@ -0,0 +1,654 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+(1) greedy search
+./dprnn_zipformer/decode.py \
+ --epoch 30 \
+ --avg 9 \
+ --use-averaged-model true \
+ --exp-dir ./dprnn_zipformer/exp \
+ --max-duration 600 \
+ --decoding-method greedy_search
+
+(2) modified beam search
+./dprnn_zipformer/decode.py \
+ --epoch 30 \
+ --avg 9 \
+ --use-averaged-model true \
+ --exp-dir ./dprnn_zipformer/exp \
+ --max-duration 600 \
+ --decoding-method modified_beam_search \
+ --beam-size 4
+"""
+
+
+import argparse
+import logging
+from collections import defaultdict
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import k2
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from asr_datamodule import LibriCssAsrDataModule
+from beam_search import (
+ beam_search,
+ greedy_search,
+ greedy_search_batch,
+ modified_beam_search,
+)
+from lhotse.utils import EPSILON
+from train import add_model_arguments, get_params, get_surt_model
+
+from icefall import LmScorer, NgramLm
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.lexicon import Lexicon
+from icefall.utils import (
+ AttributeDict,
+ setup_logger,
+ store_transcripts,
+ str2bool,
+ write_surt_error_stats,
+)
+
+OVERLAP_RATIOS = ["0L", "0S", "OV10", "OV20", "OV30", "OV40"]
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="dprnn_zipformer/exp",
+ help="The experiment dir",
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--lang-dir",
+ type=Path,
+ default="data/lang_bpe_500",
+ help="The lang dir containing word table and LG graph",
+ )
+
+ parser.add_argument(
+ "--decoding-method",
+ type=str,
+ default="greedy_search",
+ help="""Possible values are:
+ - greedy_search
+ - beam_search
+ - modified_beam_search
+ """,
+ )
+
+ parser.add_argument(
+ "--beam-size",
+ type=int,
+ default=4,
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+ parser.add_argument(
+ "--max-sym-per-frame",
+ type=int,
+ default=1,
+ help="""Maximum number of symbols per frame.
+ Used only when --decoding_method is greedy_search""",
+ )
+
+ parser.add_argument(
+ "--save-masks",
+ type=str2bool,
+ default=False,
+ help="""If true, save masks generated by unmixing module.""",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def decode_one_batch(
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+) -> Dict[str, List[List[str]]]:
+ """Decode one batch and return the result in a dict. The dict has the
+ following format:
+
+ - key: It indicates the setting used for decoding. For example,
+ if greedy_search is used, it would be "greedy_search"
+ If beam search with a beam size of 7 is used, it would be
+ "beam_7"
+ - value: It contains the decoding result. `len(value)` equals to
+ batch size. `value[i]` is the decoding result for the i-th
+ utterance in the given batch.
+ Args:
+ params:
+ It's the return value of :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ batch:
+ It is the return value from iterating
+        `lhotse.dataset.K2SurtDataset`. See its documentation
+ for the format of the `batch`.
+ Returns:
+ Return the decoding result. See above description for the format of
+ the returned dict.
+ """
+ device = next(model.parameters()).device
+ feature = batch["inputs"]
+ assert feature.ndim == 3
+
+ feature = feature.to(device)
+ feature_lens = batch["input_lens"].to(device)
+
+ # Apply the mask encoder
+ B, T, F = feature.shape
+ processed = model.mask_encoder(feature) # B,T,F*num_channels
+ masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
+ x_masked = [feature * m for m in masks]
+
+ masks_dict = {}
+ if params.save_masks:
+ # To save the masks, we split them by batch and trim each mask to the length of
+ # the corresponding feature. We save them in a dict, where the key is the
+ # cut ID and the value is the mask.
+ for i in range(B):
+ mask = torch.cat(
+ [x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
+ dim=-1,
+ )
+ mask = mask.cpu().numpy()
+ masks_dict[batch["cuts"][i].id] = mask
+
+ # Recognition
+ # Concatenate the inputs along the batch axis
+ h = torch.cat(x_masked, dim=0)
+ h_lens = feature_lens.repeat(params.num_channels)
+ encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
+
+ if model.joint_encoder_layer is not None:
+ encoder_out = model.joint_encoder_layer(encoder_out)
+
+ def _group_channels(hyps: List[str]) -> List[List[str]]:
+ """
+ Currently we have a batch of size M*B, where M is the number of
+ channels and B is the batch size. We need to group the hypotheses
+ into B groups, each of which contains M hypotheses.
+
+ Example:
+ hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
+ _group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
+ """
+ assert len(hyps) == B * params.num_channels
+ out_hyps = []
+ for i in range(B):
+ out_hyps.append(hyps[i::B])
+ return out_hyps
+
+ hyps = []
+ if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
+ hyp_tokens = greedy_search_batch(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp)
+ elif params.decoding_method == "modified_beam_search":
+ hyp_tokens = modified_beam_search(
+ model=model,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam_size,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp)
+ else:
+ batch_size = encoder_out.size(0)
+
+ for i in range(batch_size):
+ # fmt: off
+ encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
+ # fmt: on
+ if params.decoding_method == "greedy_search":
+ hyp = greedy_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ max_sym_per_frame=params.max_sym_per_frame,
+ )
+ elif params.decoding_method == "beam_search":
+ hyp = beam_search(
+ model=model,
+ encoder_out=encoder_out_i,
+ beam=params.beam_size,
+ )
+ else:
+ raise ValueError(
+ f"Unsupported decoding method: {params.decoding_method}"
+ )
+ hyps.append(sp.decode(hyp))
+
+ if params.decoding_method == "greedy_search":
+ return {"greedy_search": _group_channels(hyps)}, masks_dict
+ else:
+ return {f"beam_size_{params.beam_size}": _group_channels(hyps)}, masks_dict
+
+
+def decode_dataset(
+ dl: torch.utils.data.DataLoader,
+ params: AttributeDict,
+ model: nn.Module,
+ sp: spm.SentencePieceProcessor,
+) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
+ """Decode dataset.
+
+ Args:
+ dl:
+ PyTorch's dataloader containing the dataset to decode.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ sp:
+ The BPE model.
+ Returns:
+ Return a dict, whose key may be "greedy_search" if greedy search
+ is used, or it may be "beam_7" if beam size of 7 is used.
+      Its value is a list of tuples. Each tuple contains three elements:
+      the cut ID, the reference transcript, and the predicted result.
+ """
+ num_cuts = 0
+
+ try:
+ num_batches = len(dl)
+ except TypeError:
+ num_batches = "?"
+
+ if params.decoding_method == "greedy_search":
+ log_interval = 50
+ else:
+ log_interval = 20
+
+ results = defaultdict(list)
+ masks = {}
+ for batch_idx, batch in enumerate(dl):
+ cut_ids = [cut.id for cut in batch["cuts"]]
+ cuts_batch = batch["cuts"]
+
+        hyps_dict, masks_dict = decode_one_batch(
+            params=params,
+            model=model,
+            sp=sp,
+            batch=batch,
+        )
+ masks.update(masks_dict)
+
+ for name, hyps in hyps_dict.items():
+ this_batch = []
+ for cut_id, hyp_words in zip(cut_ids, hyps):
+ # Reference is a list of supervision texts sorted by start time.
+ ref_words = [
+ s.text.strip()
+ for s in sorted(
+ cuts_batch[cut_id].supervisions, key=lambda s: s.start
+ )
+ ]
+ this_batch.append((cut_id, ref_words, hyp_words))
+
+ results[name].extend(this_batch)
+
+ num_cuts += len(cut_ids)
+
+ if batch_idx % log_interval == 0:
+ batch_str = f"{batch_idx}/{num_batches}"
+
+ logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+    return results, masks
+
+
+def save_results(
+ params: AttributeDict,
+ test_set_name: str,
+ results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
+):
+ test_set_wers = dict()
+ for key, results in results_dict.items():
+ recog_path = (
+ params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ results = sorted(results)
+ store_transcripts(filename=recog_path, texts=results)
+ logging.info(f"The transcripts are stored in {recog_path}")
+
+ # The following prints out WERs, per-word error statistics and aligned
+ # ref/hyp pairs.
+ errs_filename = (
+ params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_filename, "w") as f:
+ wer = write_surt_error_stats(
+ f,
+ f"{test_set_name}-{key}",
+ results,
+ enable_log=True,
+ num_channels=params.num_channels,
+ )
+ test_set_wers[key] = wer
+
+ logging.info("Wrote detailed error stats to {}".format(errs_filename))
+
+ test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
+ errs_info = (
+ params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
+ )
+ with open(errs_info, "w") as f:
+ print("settings\tWER", file=f)
+ for key, val in test_set_wers:
+ print("{}\t{}".format(key, val), file=f)
+
+ s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
+ note = "\tbest for {}".format(test_set_name)
+ for key, val in test_set_wers:
+ s += "{}\t{}{}\n".format(key, val, note)
+ note = ""
+ logging.info(s)
+
+
+def save_masks(
+ params: AttributeDict,
+ test_set_name: str,
+ masks: List[torch.Tensor],
+):
+ masks_path = params.res_dir / f"masks-{test_set_name}.txt"
+ torch.save(masks, masks_path)
+ logging.info(f"The masks are stored in {masks_path}")
+
+
+@torch.no_grad()
+def main():
+ parser = get_parser()
+ LmScorer.add_arguments(parser)
+ LibriCssAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+ args.lang_dir = Path(args.lang_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ assert params.decoding_method in (
+ "greedy_search",
+ "beam_search",
+ "modified_beam_search",
+ ), f"Decoding method {params.decoding_method} is not supported."
+ params.res_dir = params.exp_dir / params.decoding_method
+
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+
+ if "beam_search" in params.decoding_method:
+ params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ else:
+ params.suffix += f"-context-{params.context_size}"
+ params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
+
+ if params.use_averaged_model:
+ params.suffix += "-use-averaged-model"
+
+ setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
+ logging.info("Decoding started")
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> and <unk> are defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+    params.unk_id = sp.piece_to_id("<unk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_surt_model(params)
+ assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
+ model.encoder.decode_chunk_size,
+ params.decode_chunk_len,
+ )
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to(device)
+ model.eval()
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ # we need cut ids to display recognition results.
+ args.return_cuts = True
+ libricss = LibriCssAsrDataModule(args)
+
+ dev_cuts = libricss.libricss_cuts(split="dev", type="ihm-mix").to_eager()
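+    # LibriCSS cut IDs contain the overlap condition (0L, 0S, OV10, ..., OV40),
+    # so filtering on the ID groups the cuts by overlap ratio.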
+ dev_cuts_grouped = [dev_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS]
+ test_cuts = libricss.libricss_cuts(split="test", type="ihm-mix").to_eager()
+ test_cuts_grouped = [
+ test_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS
+ ]
+
+ for dev_set, ol in zip(dev_cuts_grouped, OVERLAP_RATIOS):
+ dev_dl = libricss.test_dataloaders(dev_set)
+ results_dict, masks = decode_dataset(
+ dl=dev_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=f"dev_{ol}",
+ results_dict=results_dict,
+ )
+
+ if params.save_masks:
+ save_masks(
+ params=params,
+ test_set_name=f"dev_{ol}",
+ masks=masks,
+ )
+
+ for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS):
+ test_dl = libricss.test_dataloaders(test_set)
+ results_dict, masks = decode_dataset(
+ dl=test_dl,
+ params=params,
+ model=model,
+ sp=sp,
+ )
+
+ save_results(
+ params=params,
+ test_set_name=f"test_{ol}",
+ results_dict=results_dict,
+ )
+
+ if params.save_masks:
+ save_masks(
+ params=params,
+ test_set_name=f"test_{ol}",
+ masks=masks,
+ )
+
+ logging.info("Done!")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/libricss/SURT/dprnn_zipformer/decoder.py b/egs/libricss/SURT/dprnn_zipformer/decoder.py
new file mode 120000
index 000000000..8283d8c5a
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/decoder.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py
\ No newline at end of file
diff --git a/egs/libricss/SURT/dprnn_zipformer/dprnn.py b/egs/libricss/SURT/dprnn_zipformer/dprnn.py
new file mode 100644
index 000000000..440dea885
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/dprnn.py
@@ -0,0 +1,305 @@
+import random
+from typing import Optional, Tuple
+
+import torch
+import torch.nn as nn
+from einops import rearrange
+from scaling import ActivationBalancer, BasicNorm, DoubleSwish, ScaledLinear, ScaledLSTM
+from torch.autograd import Variable
+
+EPS = torch.finfo(torch.get_default_dtype()).eps
+
+
+def _pad_segment(input, segment_size):
+ # Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L342
+ # input is the features: (B, N, T)
+ batch_size, dim, seq_len = input.shape
+ segment_stride = segment_size // 2
+
+ rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
+ if rest > 0:
+ pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type())
+ input = torch.cat([input, pad], 2)
+
+ pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(input.type())
+ input = torch.cat([pad_aux, input, pad_aux], 2)
+
+ return input, rest
+
+
+def split_feature(input, segment_size):
+ # Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L358
+ # split the feature into chunks of segment size
+ # input is the features: (B, N, T)
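+    # Illustrative example (hypothetical sizes): with input of shape (B, N, T)
+    # and segment_size K, the sequence is padded and cut into 50%-overlapping
+    # chunks, giving `segments` of shape (B, N, K, num_chunks); `rest` records
+    # the padding that merge_feature() strips off afterwards.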
+
+ input, rest = _pad_segment(input, segment_size)
+ batch_size, dim, seq_len = input.shape
+ segment_stride = segment_size // 2
+
+ segments1 = (
+ input[:, :, :-segment_stride]
+ .contiguous()
+ .view(batch_size, dim, -1, segment_size)
+ )
+ segments2 = (
+ input[:, :, segment_stride:]
+ .contiguous()
+ .view(batch_size, dim, -1, segment_size)
+ )
+ segments = (
+ torch.cat([segments1, segments2], 3)
+ .view(batch_size, dim, -1, segment_size)
+ .transpose(2, 3)
+ )
+
+ return segments.contiguous(), rest
+
+
+def merge_feature(input, rest):
+ # Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L385
+ # merge the splitted features into full utterance
+ # input is the features: (B, N, L, K)
+
+ batch_size, dim, segment_size, _ = input.shape
+ segment_stride = segment_size // 2
+ input = (
+ input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size * 2)
+ ) # B, N, K, L
+
+ input1 = (
+ input[:, :, :, :segment_size]
+ .contiguous()
+ .view(batch_size, dim, -1)[:, :, segment_stride:]
+ )
+ input2 = (
+ input[:, :, :, segment_size:]
+ .contiguous()
+ .view(batch_size, dim, -1)[:, :, :-segment_stride]
+ )
+
+ output = input1 + input2
+ if rest > 0:
+ output = output[:, :, :-rest]
+
+ return output.contiguous() # B, N, T
+
+
+class RNNEncoderLayer(nn.Module):
+    """
+    RNNEncoderLayer is a single (optionally bidirectional) LSTM layer with a
+    residual connection, activation balancing, and normalization.
+    Args:
+      input_size:
+        The number of expected features in the input (required).
+      hidden_size:
+        The hidden dimension of the LSTM layer.
+      dropout:
+        The dropout value (default=0.1).
+      bidirectional:
+        Whether the LSTM is bidirectional (default=False).
+    """
+
+ def __init__(
+ self,
+ input_size: int,
+ hidden_size: int,
+ dropout: float = 0.1,
+ bidirectional: bool = False,
+ ) -> None:
+ super(RNNEncoderLayer, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+
+ assert hidden_size >= input_size, (hidden_size, input_size)
+ self.lstm = ScaledLSTM(
+ input_size=input_size,
+ hidden_size=hidden_size // 2 if bidirectional else hidden_size,
+ proj_size=0,
+ num_layers=1,
+ dropout=0.0,
+ batch_first=True,
+ bidirectional=bidirectional,
+ )
+ self.norm_final = BasicNorm(input_size)
+
+ # try to ensure the output is close to zero-mean (or at least, zero-median). # noqa
+ self.balancer = ActivationBalancer(
+ num_channels=input_size,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ max_abs=6.0,
+ )
+ self.dropout = nn.Dropout(dropout)
+
+    def forward(
+        self,
+        src: torch.Tensor,
+        states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        warmup: float = 1.0,
+    ) -> torch.Tensor:
+        """
+        Pass the input through the encoder layer.
+        Args:
+          src:
+            The sequence to the encoder layer (required).
+            Its shape is (N, S, E), where N is the batch size, S is the
+            sequence length, and E is the feature number (the LSTM is created
+            with batch_first=True).
+          states:
+            A tuple of 2 tensors (optional), used for streaming inference:
+            states[0] holds the LSTM hidden states and states[1] the cell
+            states, each of shape (num_directions, N, H), where H is the
+            LSTM's hidden size.
+          warmup:
+            A value in [0, 1] that controls how much of this layer is bypassed
+            during model-level warmup (1.0 means the layer is fully used).
+        """
+ src_orig = src
+
+ # alpha = 1.0 means fully use this encoder layer, 0.0 would mean
+ # completely bypass it.
+ alpha = warmup if self.training else 1.0
+
+ # lstm module
+ src_lstm, new_states = self.lstm(src, states)
+ src = self.dropout(src_lstm) + src
+ src = self.norm_final(self.balancer(src))
+
+ if alpha != 1.0:
+ src = alpha * src + (1 - alpha) * src_orig
+
+ return src
+
+
+# dual-path RNN
+class DPRNN(nn.Module):
+    """Deep dual-path RNN.
+    Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py
+
+    Args:
+        feature_dim: int, dimension of the input features. The input should have
+            shape (batch, seq_len, feature_dim).
+        input_size: int, dimension of the embedded input (after the input_embed layer).
+        hidden_size: int, dimension of the hidden state.
+        output_size: int, dimension of the output size.
+        dropout: float, dropout ratio. Default is 0.1.
+        num_blocks: int, number of stacked dual-path blocks (each block contains one
+            intra-chunk and one inter-chunk RNN). Default is 1.
+        segment_size: int, chunk length used when splitting the sequence. Default is 50.
+        chunk_width_randomization: bool, if True, the chunk length is randomized
+            during training. Default is False.
+    """
+
+ def __init__(
+ self,
+ feature_dim,
+ input_size,
+ hidden_size,
+ output_size,
+ dropout=0.1,
+ num_blocks=1,
+ segment_size=50,
+ chunk_width_randomization=False,
+ ):
+ super().__init__()
+
+ self.input_size = input_size
+ self.output_size = output_size
+ self.hidden_size = hidden_size
+
+ self.segment_size = segment_size
+ self.chunk_width_randomization = chunk_width_randomization
+
+ self.input_embed = nn.Sequential(
+ ScaledLinear(feature_dim, input_size),
+ BasicNorm(input_size),
+ ActivationBalancer(
+ num_channels=input_size,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ ),
+ )
+
+ # dual-path RNN
+ self.row_rnn = nn.ModuleList([])
+ self.col_rnn = nn.ModuleList([])
+ for _ in range(num_blocks):
+ # intra-RNN is non-causal
+ self.row_rnn.append(
+ RNNEncoderLayer(
+ input_size, hidden_size, dropout=dropout, bidirectional=True
+ )
+ )
+ self.col_rnn.append(
+ RNNEncoderLayer(
+ input_size, hidden_size, dropout=dropout, bidirectional=False
+ )
+ )
+
+ # output layer
+ self.out_embed = nn.Sequential(
+ ScaledLinear(input_size, output_size),
+ BasicNorm(output_size),
+ ActivationBalancer(
+ num_channels=output_size,
+ channel_dim=-1,
+ min_positive=0.45,
+ max_positive=0.55,
+ ),
+ )
+
+ def forward(self, input):
+ # input shape: B, T, F
+ input = self.input_embed(input)
+ B, T, D = input.shape
+
+ if self.chunk_width_randomization and self.training:
+ segment_size = random.randint(self.segment_size // 2, self.segment_size)
+ else:
+ segment_size = self.segment_size
+ input, rest = split_feature(input.transpose(1, 2), segment_size)
+ # input shape: batch, N, dim1, dim2
+ # apply RNN on dim1 first and then dim2
+ # output shape: B, output_size, dim1, dim2
+ # input = input.to(device)
+ batch_size, _, dim1, dim2 = input.shape
+ output = input
+ for i in range(len(self.row_rnn)):
+ row_input = (
+ output.permute(0, 3, 2, 1)
+ .contiguous()
+ .view(batch_size * dim2, dim1, -1)
+ ) # B*dim2, dim1, N
+ output = self.row_rnn[i](row_input) # B*dim2, dim1, H
+ output = (
+ output.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
+ ) # B, N, dim1, dim2
+
+ col_input = (
+ output.permute(0, 2, 3, 1)
+ .contiguous()
+ .view(batch_size * dim1, dim2, -1)
+ ) # B*dim1, dim2, N
+ output = self.col_rnn[i](col_input) # B*dim1, dim2, H
+ output = (
+ output.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
+ ) # B, N, dim1, dim2
+
+ output = merge_feature(output, rest)
+ output = output.transpose(1, 2)
+ output = self.out_embed(output)
+
+ # Apply ReLU to the output
+ output = torch.relu(output)
+
+ return output
+
+
+if __name__ == "__main__":
+    # Quick smoke test: build a small DPRNN, print its parameter count, and
+    # run a random batch through it. The output should have shape
+    # (2, 1002, 160), i.e., (batch, time, output_size).
+    model = DPRNN(
+        80,
+        256,
+        256,
+        160,
+        dropout=0.1,
+        num_blocks=4,
+        segment_size=32,
+        chunk_width_randomization=True,
+    )
+    input = torch.randn(2, 1002, 80)
+    print(sum(p.numel() for p in model.parameters()))
+    print(model(input).shape)
diff --git a/egs/libricss/SURT/dprnn_zipformer/encoder_interface.py b/egs/libricss/SURT/dprnn_zipformer/encoder_interface.py
new file mode 120000
index 000000000..0c2673d46
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/encoder_interface.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py
\ No newline at end of file
diff --git a/egs/libricss/SURT/dprnn_zipformer/export.py b/egs/libricss/SURT/dprnn_zipformer/export.py
new file mode 100755
index 000000000..f51f2a7ab
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/export.py
@@ -0,0 +1,306 @@
+#!/usr/bin/env python3
+#
+# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script converts several saved checkpoints
+# to a single one using model averaging.
+"""
+
+Usage:
+
+(1) Export to torchscript model using torch.jit.script()
+
+./dprnn_zipformer/export.py \
+ --exp-dir ./dprnn_zipformer/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch 30 \
+ --avg 9 \
+ --jit 1
+
+It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
+load it by `torch.jit.load("cpu_jit.pt")`.
+
+Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
+are on CPU. You can use `to("cuda")` to move them to a CUDA device.
+
+Check
+https://github.com/k2-fsa/sherpa
+for how to use the exported models outside of icefall.
+
+(2) Export `model.state_dict()`
+
+./dprnn_zipformer/export.py \
+ --exp-dir ./dprnn_zipformer/exp \
+ --bpe-model data/lang_bpe_500/bpe.model \
+ --epoch 30 \
+ --avg 9
+
+It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
+load it by `icefall.checkpoint.load_checkpoint()`.
+
+To use the generated file with `dprnn_zipformer/decode.py`,
+you can do:
+
+ cd /path/to/exp_dir
+ ln -s pretrained.pt epoch-9999.pt
+
+    cd /path/to/egs/libricss/SURT
+ ./dprnn_zipformer/decode.py \
+ --exp-dir ./dprnn_zipformer/exp \
+ --epoch 9999 \
+ --avg 1 \
+ --max-duration 600 \
+ --decoding-method greedy_search \
+ --bpe-model data/lang_bpe_500/bpe.model
+"""
+
+import argparse
+import logging
+from pathlib import Path
+
+import sentencepiece as spm
+import torch
+import torch.nn as nn
+from scaling_converter import convert_scaled_to_non_scaled
+from train import add_model_arguments, get_params, get_surt_model
+
+from icefall.checkpoint import (
+ average_checkpoints,
+ average_checkpoints_with_averaged_model,
+ find_checkpoints,
+ load_checkpoint,
+)
+from icefall.utils import str2bool
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--epoch",
+ type=int,
+ default=30,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 1.
+ You can specify --avg to use more checkpoints for model averaging.""",
+ )
+
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
+
+ parser.add_argument(
+ "--avg",
+ type=int,
+ default=9,
+ help="Number of checkpoints to average. Automatically select "
+ "consecutive checkpoints before the checkpoint specified by "
+ "'--epoch' and '--iter'",
+ )
+
+ parser.add_argument(
+ "--use-averaged-model",
+ type=str2bool,
+ default=True,
+ help="Whether to load averaged model. Currently it only supports "
+ "using --epoch. If True, it would decode with the averaged model "
+ "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+ "Actually only the models with epoch number of `epoch-avg` and "
+ "`epoch` are loaded for averaging. ",
+ )
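+    # For example, with --epoch 30 --avg 9 and --use-averaged-model True, the
+    # averaged model is computed over the range epoch-21.pt (excluded) to
+    # epoch-30.pt; with --use-averaged-model False, epoch-22.pt through
+    # epoch-30.pt are averaged directly.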
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="dprnn_zipformer/exp",
+ help="""It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--jit",
+ type=str2bool,
+ default=False,
+ help="""True to save a model after applying torch.jit.script.
+ It will generate a file named cpu_jit.pt
+
+ Check ./jit_pretrained.py for how to use it.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+@torch.no_grad()
+def main():
+ args = get_parser().parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ params = get_params()
+ params.update(vars(args))
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", 0)
+
+ logging.info(f"device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_surt_model(params)
+
+ model.to(device)
+
+ if not params.use_averaged_model:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ else:
+ start = params.epoch - params.avg + 1
+ filenames = []
+ for i in range(start, params.epoch + 1):
+ if i >= 1:
+ filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ else:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg + 1
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg + 1:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ filename_start = filenames[-1]
+ filename_end = filenames[0]
+ logging.info(
+ "Calculating the averaged model over iteration checkpoints"
+ f" from {filename_start} (excluded) to {filename_end}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ else:
+ assert params.avg > 0, params.avg
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
+ )
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+
+ model.to("cpu")
+ model.eval()
+
+ if params.jit is True:
+ convert_scaled_to_non_scaled(model, inplace=True)
+ # We won't use the forward() method of the model in C++, so just ignore
+ # it here.
+ # Otherwise, one of its arguments is a ragged tensor and is not
+        # torch scriptable.
+ model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+ logging.info("Using torch.jit.script")
+ model = torch.jit.script(model)
+ filename = params.exp_dir / "cpu_jit.pt"
+ model.save(str(filename))
+ logging.info(f"Saved to {filename}")
+ else:
+ logging.info("Not using torchscript. Export model.state_dict()")
+ # Save it using a format so that it can be loaded
+ # by :func:`load_checkpoint`
+ filename = params.exp_dir / "pretrained.pt"
+ torch.save({"model": model.state_dict()}, str(filename))
+ logging.info(f"Saved to {filename}")
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ main()
diff --git a/egs/libricss/SURT/dprnn_zipformer/joiner.py b/egs/libricss/SURT/dprnn_zipformer/joiner.py
new file mode 120000
index 000000000..0f0c3c90a
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/joiner.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py
\ No newline at end of file
diff --git a/egs/libricss/SURT/dprnn_zipformer/model.py b/egs/libricss/SURT/dprnn_zipformer/model.py
new file mode 100644
index 000000000..688e1e78d
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/model.py
@@ -0,0 +1,316 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
+# Copyright 2023 Johns Hopkins University (author: Desh Raj)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List, Optional, Tuple
+
+import k2
+import torch
+import torch.nn as nn
+from encoder_interface import EncoderInterface
+
+from icefall.utils import add_sos
+
+
+class SURT(nn.Module):
+ """It implements Streaming Unmixing and Recognition Transducer (SURT).
+ https://arxiv.org/abs/2011.13148
+ """
+
+ def __init__(
+ self,
+ mask_encoder: nn.Module,
+ encoder: EncoderInterface,
+ joint_encoder_layer: Optional[nn.Module],
+ decoder: nn.Module,
+ joiner: nn.Module,
+ num_channels: int,
+ encoder_dim: int,
+ decoder_dim: int,
+ joiner_dim: int,
+ vocab_size: int,
+ ):
+ """
+ Args:
+ mask_encoder:
+ It is the masking network. It generates a mask for each channel of the
+ encoder. These masks are applied to the input features, and then passed
+ to the transcription network.
+          encoder:
+            It is the transcription network in the paper. It accepts
+            two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
+            It returns two tensors: `logits` of shape (N, T, encoder_dim) and
+            `logit_lens` of shape (N,).
+ decoder:
+ It is the prediction network in the paper. Its input shape
+ is (N, U) and its output shape is (N, U, decoder_dim).
+ It should contain one attribute: `blank_id`.
+ joiner:
+ It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
+ Its output shape is (N, T, U, vocab_size). Note that its output contains
+ unnormalized probs, i.e., not processed by log-softmax.
+ num_channels:
+ It is the number of channels that the input features will be split into.
+ In general, it should be equal to the maximum number of simultaneously
+ active speakers. For most real scenarios, using 2 channels is sufficient.
+ """
+ super().__init__()
+ assert isinstance(encoder, EncoderInterface), type(encoder)
+ assert hasattr(decoder, "blank_id")
+
+ self.mask_encoder = mask_encoder
+ self.encoder = encoder
+ self.joint_encoder_layer = joint_encoder_layer
+ self.decoder = decoder
+ self.joiner = joiner
+ self.num_channels = num_channels
+
+ self.simple_am_proj = nn.Linear(
+ encoder_dim,
+ vocab_size,
+ )
+ self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
+
+ self.ctc_output = nn.Sequential(
+ nn.Dropout(p=0.1),
+ nn.Linear(encoder_dim, vocab_size),
+ nn.LogSoftmax(dim=-1),
+ )
+
+ def forward_helper(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ y: k2.RaggedTensor,
+ prune_range: int = 5,
+ am_scale: float = 0.0,
+ lm_scale: float = 0.0,
+ reduction: str = "sum",
+ beam_size: int = 10,
+ use_double_scores: bool = False,
+ subsampling_factor: int = 1,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Compute transducer loss for one branch of the SURT model.
+ """
+ encoder_out, x_lens = self.encoder(x, x_lens)
+ assert torch.all(x_lens > 0)
+
+ if self.joint_encoder_layer is not None:
+ encoder_out = self.joint_encoder_layer(encoder_out)
+
+ # compute ctc log-probs
+ ctc_output = self.ctc_output(encoder_out)
+
+ # For the decoder, i.e., the prediction network
+ row_splits = y.shape.row_splits(1)
+ y_lens = row_splits[1:] - row_splits[:-1]
+
+ blank_id = self.decoder.blank_id
+ sos_y = add_sos(y, sos_id=blank_id)
+
+ # sos_y_padded: [B, S + 1], start with SOS.
+ sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
+
+ # decoder_out: [B, S + 1, decoder_dim]
+ decoder_out = self.decoder(sos_y_padded)
+
+ # Note: y does not start with SOS
+ # y_padded : [B, S]
+ y_padded = y.pad(mode="constant", padding_value=0)
+
+ y_padded = y_padded.to(torch.int64)
+ boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
+ boundary[:, 2] = y_lens
+ boundary[:, 3] = x_lens
+
+ lm = self.simple_lm_proj(decoder_out)
+ am = self.simple_am_proj(encoder_out)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
+ lm=lm.float(),
+ am=am.float(),
+ symbols=y_padded,
+ termination_symbol=blank_id,
+ lm_only_scale=lm_scale,
+ am_only_scale=am_scale,
+ boundary=boundary,
+ reduction=reduction,
+ return_grad=True,
+ )
+
+ # ranges : [B, T, prune_range]
+ ranges = k2.get_rnnt_prune_ranges(
+ px_grad=px_grad,
+ py_grad=py_grad,
+ boundary=boundary,
+ s_range=prune_range,
+ )
+
+ # am_pruned : [B, T, prune_range, encoder_dim]
+ # lm_pruned : [B, T, prune_range, decoder_dim]
+ am_pruned, lm_pruned = k2.do_rnnt_pruning(
+ am=self.joiner.encoder_proj(encoder_out),
+ lm=self.joiner.decoder_proj(decoder_out),
+ ranges=ranges,
+ )
+
+ # logits : [B, T, prune_range, vocab_size]
+
+ # project_input=False since we applied the decoder's input projections
+ # prior to do_rnnt_pruning (this is an optimization for speed).
+ logits = self.joiner(am_pruned, lm_pruned, project_input=False)
+
+ with torch.cuda.amp.autocast(enabled=False):
+ pruned_loss = k2.rnnt_loss_pruned(
+ logits=logits.float(),
+ symbols=y_padded,
+ ranges=ranges,
+ termination_symbol=blank_id,
+ boundary=boundary,
+ reduction=reduction,
+ )
+
+ # Compute ctc loss
+ supervision_segments = torch.stack(
+ (
+ torch.arange(len(x_lens), device="cpu"),
+ torch.zeros_like(x_lens, device="cpu"),
+ torch.clone(x_lens).detach().cpu(),
+ ),
+ dim=1,
+ ).to(torch.int32)
+ # We need to sort supervision_segments in decreasing order of num_frames
+ indices = torch.argsort(supervision_segments[:, 2], descending=True)
+ supervision_segments = supervision_segments[indices]
+
+ # Works with a BPE model
+ decoding_graph = k2.ctc_graph(y, modified=False, device=x.device)
+ dense_fsa_vec = k2.DenseFsaVec(
+ ctc_output,
+ supervision_segments,
+ allow_truncate=subsampling_factor - 1,
+ )
+ ctc_loss = k2.ctc_loss(
+ decoding_graph=decoding_graph,
+ dense_fsa_vec=dense_fsa_vec,
+ output_beam=beam_size,
+ reduction="none",
+ use_double_scores=use_double_scores,
+ )
+
+ return (simple_loss, pruned_loss, ctc_loss)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_lens: torch.Tensor,
+ y: k2.RaggedTensor,
+ prune_range: int = 5,
+ am_scale: float = 0.0,
+ lm_scale: float = 0.0,
+ reduction: str = "sum",
+ beam_size: int = 10,
+ use_double_scores: bool = False,
+ subsampling_factor: int = 1,
+ return_masks: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Args:
+ x:
+ A 3-D tensor of shape (N, T, C).
+ x_lens:
+ A 1-D tensor of shape (N,). It contains the number of frames in `x`
+ before padding.
+ y:
+ A ragged tensor of shape (N*num_channels, S). It contains the labels
+ of the N utterances. The labels are in the range [0, vocab_size). All
+ the channels are concatenated together one after another.
+ prune_range:
+ The prune range for rnnt loss, it means how many symbols(context)
+ we are considering for each frame to compute the loss.
+ am_scale:
+ The scale to smooth the loss with am (output of encoder network)
+ part
+ lm_scale:
+ The scale to smooth the loss with lm (output of predictor network)
+ part
+ reduction:
+ "sum" to sum the losses over all utterances in the batch.
+ "none" to return the loss in a 1-D tensor for each utterance
+ in the batch.
+ beam_size:
+ The beam size used in CTC decoding.
+ use_double_scores:
+ If True, use double precision for CTC decoding.
+ subsampling_factor:
+ The subsampling factor of the model. It is used to compute the
+ supervision segments for CTC loss.
+ return_masks:
+ If True, return the masks as well as masked features.
+        Returns:
+          Return a tuple (simple_loss, pruned_loss, ctc_loss, x_masked), where
+          each loss is stacked along a leading channel axis. If return_masks is
+          True, the estimated masks are also returned as a fifth element.
+
+ Note:
+ Regarding am_scale & lm_scale, it will make the loss-function one of
+ the form:
+ lm_scale * lm_probs + am_scale * am_probs +
+ (1-lm_scale-am_scale) * combined_probs
+ """
+ assert x.ndim == 3, x.shape
+ assert x_lens.ndim == 1, x_lens.shape
+ assert y.num_axes == 2, y.num_axes
+
+ assert x.size(0) == x_lens.size(0), (x.size(), x_lens.size())
+
+ # Apply the mask encoder
+ B, T, F = x.shape
+ processed = self.mask_encoder(x) # B,T,F*num_channels
+ masks = processed.view(B, T, F, self.num_channels).unbind(dim=-1)
+ x_masked = [x * m for m in masks]
+
+ # Recognition
+ # Stack the inputs along the batch axis
+ h = torch.cat(x_masked, dim=0)
+ h_lens = torch.cat([x_lens for _ in range(self.num_channels)], dim=0)
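+        # With num_channels=2, for example, `h` has shape (2*B, T, F): rows
+        # [0, B) hold channel 0 and rows [B, 2*B) hold channel 1, so `y` must
+        # list all channel-0 transcripts first, then all channel-1 transcripts
+        # (see the text flattening in compute_loss() in train.py).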
+
+ simple_loss, pruned_loss, ctc_loss = self.forward_helper(
+ h,
+ h_lens,
+ y,
+ prune_range,
+ am_scale,
+ lm_scale,
+ reduction=reduction,
+ beam_size=beam_size,
+ use_double_scores=use_double_scores,
+ subsampling_factor=subsampling_factor,
+ )
+
+        # Chunk the outputs into num_channels parts along the batch axis and
+        # stack them along a new leading axis.
+ simple_loss = torch.stack(
+ torch.chunk(simple_loss, self.num_channels, dim=0), dim=0
+ )
+ pruned_loss = torch.stack(
+ torch.chunk(pruned_loss, self.num_channels, dim=0), dim=0
+ )
+ ctc_loss = torch.stack(torch.chunk(ctc_loss, self.num_channels, dim=0), dim=0)
+
+ if return_masks:
+ return (simple_loss, pruned_loss, ctc_loss, x_masked, masks)
+ else:
+ return (simple_loss, pruned_loss, ctc_loss, x_masked)
diff --git a/egs/libricss/SURT/dprnn_zipformer/optim.py b/egs/libricss/SURT/dprnn_zipformer/optim.py
new file mode 120000
index 000000000..8a05abb5f
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/optim.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/optim.py
\ No newline at end of file
diff --git a/egs/libricss/SURT/dprnn_zipformer/scaling.py b/egs/libricss/SURT/dprnn_zipformer/scaling.py
new file mode 120000
index 000000000..5f9be9fe0
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/scaling.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py
\ No newline at end of file
diff --git a/egs/libricss/SURT/dprnn_zipformer/scaling_converter.py b/egs/libricss/SURT/dprnn_zipformer/scaling_converter.py
new file mode 120000
index 000000000..f9960e5c6
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/scaling_converter.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py
\ No newline at end of file
diff --git a/egs/libricss/SURT/dprnn_zipformer/train.py b/egs/libricss/SURT/dprnn_zipformer/train.py
new file mode 100755
index 000000000..6598f8b5d
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/train.py
@@ -0,0 +1,1452 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+#                                                  Mingshuang Luo,
+# Zengwei Yao)
+# 2023 Johns Hopkins University (author: Desh Raj)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES="0,1,2,3"
+
+cd egs/libricss/SURT
+./prepare.sh
+
+./dprnn_zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --exp-dir dprnn_zipformer/exp \
+ --max-duration 300
+
+# For mixed precision training:
+
+./dprnn_zipformer/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir dprnn_zipformer/exp \
+ --max-duration 550
+"""
+
+import argparse
+import copy
+import logging
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriCssAsrDataModule
+from decoder import Decoder
+from dprnn import DPRNN
+from einops.layers.torch import Rearrange
+from graph_pit.loss.optimized import optimized_graph_pit_mse_loss as gpit_mse
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import LOG_EPSILON, fix_random_seed
+from model import SURT
+from optim import Eden, ScaledAdam
+from scaling import ScaledLSTM
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-mask-encoder-layers",
+ type=int,
+ default=4,
+ help="Number of layers in the DPRNN based mask encoder.",
+ )
+
+ parser.add_argument(
+ "--mask-encoder-dim",
+ type=int,
+ default=256,
+ help="Hidden dimension of the LSTM blocks in DPRNN.",
+ )
+
+ parser.add_argument(
+ "--mask-encoder-segment-size",
+ type=int,
+ default=32,
+ help="Segment size of the SegLSTM in DPRNN. Ideally, this should be equal to the "
+ "decode-chunk-length of the zipformer encoder.",
+ )
+
+ parser.add_argument(
+ "--chunk-width-randomization",
+        type=str2bool,
+ default=False,
+ help="Whether to randomize the chunk width in DPRNN.",
+ )
+
+ # Zipformer config is based on:
+ # https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,2,2,2",
+ help="Number of zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dims",
+ type=str,
+ default="768,768,768,768,768",
+ help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=str,
+ default="8,8,8,8,8",
+ help="Number of attention heads in the zipformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--encoder-dims",
+ type=str,
+ default="256,256,256,256,256",
+        help="Embedding dimension in each stack of zipformer encoder layers, comma separated",
+ )
+
+ parser.add_argument(
+ "--attention-dims",
+ type=str,
+ default="192,192,192,192,192",
+        help="""Attention dimension in each stack of zipformer encoder layers, comma separated;
+        not the same as embedding dimension.""",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
+ " worse.",
+ )
+
+ parser.add_argument(
+ "--zipformer-downsampling-factors",
+ type=str,
+ default="1,2,4,8,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernels",
+ type=str,
+ default="31,31,31,31,31",
+ help="Sizes of kernels in convolution modules",
+ )
+
+ parser.add_argument(
+ "--use-joint-encoder-layer",
+ type=str,
+ default="lstm",
+ choices=["linear", "lstm", "none"],
+        help="Type of joint layer used to combine all branches ('none' disables it).",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--short-chunk-size",
+ type=int,
+ default=50,
+ help="""Chunk length of dynamic training, the chunk size would be either
+ max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
+ """,
+ )
+
+ parser.add_argument(
+ "--num-left-chunks",
+ type=int,
+ default=4,
+        help="How many left-context chunks can be seen when calculating attention.",
+ )
+
+ parser.add_argument(
+ "--decode-chunk-len",
+ type=int,
+ default=32,
+ help="The chunk size for decoding (in frames before subsampling)",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=30,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+        default="dprnn_zipformer/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--model-init-ckpt",
+ type=str,
+ default=None,
+ help="""The model checkpoint to initialize the model (either full or part).
+ If not specified, the model is randomly initialized.
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.004, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+        help="The prune range for rnnt loss, it means how many symbols (context) "
+        "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+        help="To get pruning ranges, we will calculate a simple version of the "
+        "loss (the joiner is just an addition), and this simple loss is also "
+        "used for training (as a regularization term). We will scale the simple "
+        "loss with this parameter before adding it to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--heat-loss-scale",
+ type=float,
+ default=0.0,
+ help="Scale for HEAT loss on separated sources.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=2000,
+        help="""Save checkpoint after processing this number of batches
+        periodically. We save checkpoint to exp-dir/ whenever
+        params.batch_idx_train % save_every_n == 0. The checkpoint filename
+        has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+        Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+        end of each epoch where `xxx` is the epoch number counting from 1.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=1,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=100,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+ contains number of batches trained so far across
+ epochs.
+
+    - log_interval:  Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+    - num_channels: The number of output branches of the SURT model.
+
+    - model_warm_step: The number of warmup batches; it is used to schedule
+      the loss scales and the warmup dataloader, not the learning rate.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 2000,
+ # parameters for SURT
+ "num_channels": 2,
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed
+            # parameters for model warmup
+ "model_warm_step": 5000, # arg given to model, not for lrate
+ # parameters for ctc loss
+ "beam_size": 10,
+ "use_double_scores": True,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_mask_encoder_model(params: AttributeDict) -> nn.Module:
+ mask_encoder = DPRNN(
+ feature_dim=params.feature_dim,
+ input_size=params.mask_encoder_dim,
+ hidden_size=params.mask_encoder_dim,
+ output_size=params.feature_dim * params.num_channels,
+ segment_size=params.mask_encoder_segment_size,
+ num_blocks=params.num_mask_encoder_layers,
+ chunk_width_randomization=params.chunk_width_randomization,
+ )
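+    # The DPRNN maps (B, T, feature_dim) features to
+    # (B, T, feature_dim * num_channels); SURT.forward() splits the last
+    # dimension into one mask per output channel.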
+ return mask_encoder
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Zipformer and Transformer
+ def to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+ encoder = Zipformer(
+ num_features=params.feature_dim,
+ output_downsampling_factor=2,
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
+ encoder_dims=to_int_tuple(params.encoder_dims),
+ attention_dim=to_int_tuple(params.attention_dims),
+ encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+ nhead=to_int_tuple(params.nhead),
+ feedforward_dim=to_int_tuple(params.feedforward_dims),
+ cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+ num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+ num_left_chunks=params.num_left_chunks,
+ short_chunk_size=params.short_chunk_size,
+ decode_chunk_size=params.decode_chunk_len // 2,
+ )
+ return encoder
+
+
+def get_joint_encoder_layer(params: AttributeDict) -> nn.Module:
+    class TakeFirst(nn.Module):
+        # ScaledLSTM returns (output, (h, c)); keep only the output sequence.
+        def forward(self, x):
+            return x[0]
+
+ if params.use_joint_encoder_layer == "linear":
+ encoder_dim = int(params.encoder_dims.split(",")[-1])
+ joint_layer = nn.Sequential(
+ Rearrange("(c b) t d -> b t (c d)", c=params.num_channels),
+ nn.Linear(
+ params.num_channels * encoder_dim, params.num_channels * encoder_dim
+ ),
+ nn.ReLU(),
+ Rearrange("b t (c d) -> (c b) t d", c=params.num_channels),
+ )
+ elif params.use_joint_encoder_layer == "lstm":
+ encoder_dim = int(params.encoder_dims.split(",")[-1])
+ joint_layer = nn.Sequential(
+ Rearrange("(c b) t d -> b t (c d)", c=params.num_channels),
+ ScaledLSTM(
+ input_size=params.num_channels * encoder_dim,
+ hidden_size=params.num_channels * encoder_dim,
+ num_layers=1,
+ bias=True,
+ batch_first=True,
+ dropout=0.0,
+ bidirectional=False,
+ ),
+ TakeFirst(),
+ nn.ReLU(),
+ Rearrange("b t (c d) -> (c b) t d", c=params.num_channels),
+ )
+ elif params.use_joint_encoder_layer == "none":
+ joint_layer = None
+ else:
+ raise ValueError(
+ f"Unknown joint encoder layer type: {params.use_joint_encoder_layer}"
+ )
+ return joint_layer
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_surt_model(
+ params: AttributeDict,
+) -> nn.Module:
+ mask_encoder = get_mask_encoder_model(params)
+ encoder = get_encoder_model(params)
+ joint_layer = get_joint_encoder_layer(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = SURT(
+ mask_encoder=mask_encoder,
+ encoder=encoder,
+ joint_encoder_layer=joint_layer,
+ decoder=decoder,
+ joiner=joiner,
+ num_channels=params.num_channels,
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mix precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_heat_loss(x_masked, batch, num_channels=2) -> Tensor:
+ """
+ Compute HEAT loss for separated sources using the output of mask encoder.
+ Args:
+ x_masked:
+        The output of the mask encoder: a list of num_channels tensors, each of
+        shape (B, T, F), containing the masked features for one output branch.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SurtDatasetWithSources()`
+ for the content in it.
+ num_channels:
+ The number of output branches in the SURT model.
+ """
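+    # Illustrative example (hypothetical boundaries): with 2 channels and
+    # sources spanning frames (0, 50), (20, 80), and (60, 100), the first
+    # source goes to channel 0; the second starts before channel 0 is free
+    # (20 < 50), so it goes to channel 1; the third starts at 60 >= 50, so it
+    # is assigned back to channel 0, the first available channel.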
+ B, T, D = x_masked[0].shape
+ device = x_masked[0].device
+
+ # Create training targets for each channel.
+ targets = []
+ for i in range(num_channels):
+ target = torch.ones_like(x_masked[i]) * LOG_EPSILON
+ targets.append(target)
+
+ source_feats = batch["source_feats"]
+ source_boundaries = batch["source_boundaries"]
+ input_lens = batch["input_lens"].to(device)
+ # Assign sources to channels based on the HEAT criteria
+ for b in range(B):
+ cut_source_feats = source_feats[b]
+ cut_source_boundaries = source_boundaries[b]
+ last_seg_end = [0 for _ in range(num_channels)]
+ for source_feat, (start, end) in zip(cut_source_feats, cut_source_boundaries):
+ assigned = False
+ for i in range(num_channels):
+ if start >= last_seg_end[i]:
+ targets[i][b, start:end, :] += source_feat.to(device)
+ last_seg_end[i] = max(end, last_seg_end[i])
+ assigned = True
+ break
+ if not assigned:
+ min_end_channel = last_seg_end.index(min(last_seg_end))
+ targets[min_end_channel][b, start:end, :] += source_feat
+ last_seg_end[min_end_channel] = max(end, last_seg_end[min_end_channel])
+
+ # Get padding mask based on input lengths
+ pad_mask = torch.arange(T, device=device).expand(B, T) > input_lens.unsqueeze(1)
+ pad_mask = pad_mask.unsqueeze(-1)
+
+ # Compute masked loss for each channel
+ losses = torch.zeros((num_channels, B, T, D), device=device)
+ for i in range(num_channels):
+ loss = nn.functional.mse_loss(x_masked[i], targets[i], reduction="none")
+ # Apply padding mask to loss
+ loss.masked_fill_(pad_mask, 0)
+ losses[i] = loss
+
+ # loss: C x B x T x D. pad_mask: B x T x 1
+ # We want to compute loss for each item in the batch. Each item has loss given
+ # by the sum over C, and average over T and D. For T, we need to use the padding.
+ loss = losses.sum(0).mean(-1).sum(-1) / batch["input_lens"].to(device)
+ return loss
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute RNN-T loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+        The model for training. It is an instance of SURT in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"].to(device)
+ feature_lens = batch["input_lens"].to(device)
+
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+
+    # The dataloader returns text as a list of cuts, each of which is a list of
+    # channel texts. We flatten this to a list where all channels are together,
+    # i.e., it looks like [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ..., uttN_ch2].
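+    # For example (illustrative identifiers), with 2 cuts and 2 channels,
+    # batch["text"] == [[a_ch1, a_ch2], [b_ch1, b_ch2]] is flattened to
+    # [a_ch1, b_ch1, a_ch2, b_ch2], matching the channel-major ordering that
+    # SURT.forward() expects.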
+ text = [val for tup in zip(*batch["text"]) for val in tup]
+ assert len(text) == len(feature) * params.num_channels
+
+ # Convert all channel texts to token IDs and create a ragged tensor.
+ y = sp.encode(text, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.model_warm_step
+
+ with torch.set_grad_enabled(is_training):
+ (simple_loss, pruned_loss, ctc_loss, x_masked) = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ reduction="none",
+ subsampling_factor=params.subsampling_factor,
+ )
+ simple_loss_is_finite = torch.isfinite(simple_loss)
+ pruned_loss_is_finite = torch.isfinite(pruned_loss)
+ ctc_loss_is_finite = torch.isfinite(ctc_loss)
+
+ # Compute HEAT loss
+ if is_training and params.heat_loss_scale > 0.0:
+ heat_loss = compute_heat_loss(
+ x_masked, batch, num_channels=params.num_channels
+ )
+ else:
+ heat_loss = torch.tensor(0.0, device=device)
+
+ heat_loss_is_finite = torch.isfinite(heat_loss)
+ is_finite = (
+ simple_loss_is_finite
+ & pruned_loss_is_finite
+ & ctc_loss_is_finite
+ & heat_loss_is_finite
+ )
+ if not torch.all(is_finite):
+ logging.info(
+ "Not all losses are finite!\n"
+ f"simple_losses: {simple_loss}\n"
+ f"pruned_losses: {pruned_loss}\n"
+ f"ctc_losses: {ctc_loss}\n"
+ f"heat_losses: {heat_loss}\n"
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ simple_loss = simple_loss[simple_loss_is_finite]
+ pruned_loss = pruned_loss[pruned_loss_is_finite]
+ ctc_loss = ctc_loss[ctc_loss_is_finite]
+ heat_loss = heat_loss[heat_loss_is_finite]
+
+ # If either all simple_loss or pruned_loss is inf or nan,
+ # we stop the training process by raising an exception
+ if (
+ torch.all(~simple_loss_is_finite)
+ or torch.all(~pruned_loss_is_finite)
+ or torch.all(~ctc_loss_is_finite)
+ or torch.all(~heat_loss_is_finite)
+ ):
+ raise ValueError(
+ "There are too many utterances in this batch "
+ "leading to inf or nan losses."
+ )
+
+ simple_loss_sum = simple_loss.sum()
+ pruned_loss_sum = pruned_loss.sum()
+ ctc_loss_sum = ctc_loss.sum()
+ heat_loss_sum = heat_loss.sum()
+
+ s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
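+        # For example, with the default warm_step=5000 and simple_loss_scale=0.5,
+        # at batch_idx_train=2500 the simple loss weight is 0.75 and the pruned
+        # loss weight is 0.55; after warm_step they settle at 0.5 and 1.0.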
+ loss = (
+ simple_loss_scale * simple_loss_sum
+ + pruned_loss_scale * pruned_loss_sum
+ + params.ctc_loss_scale * ctc_loss_sum
+ + params.heat_loss_scale * heat_loss_sum
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ # info["frames"] is an approximate number for two reasons:
+        # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+ # (2) If some utterances in the batch lead to inf/nan loss, they
+ # are filtered out.
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
+ info["utterances"] = feature.size(0)
+ # averaged input duration in frames over utterances
+ info["utt_duration"] = feature_lens.sum().item()
+ # averaged padding proportion over utterances
+ info["utt_pad_proportion"] = (
+ ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+ )
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss_sum.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss_sum.detach().cpu().item()
+ if params.ctc_loss_scale > 0.0:
+ info["ctc_loss"] = ctc_loss_sum.detach().cpu().item()
+ if params.heat_loss_scale > 0.0:
+ info["heat_loss"] = heat_loss_sum.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ train_dl_warmup: Optional[torch.utils.data.DataLoader],
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+    The mean training loss over all frames is saved in `params.train_loss`.
+    The validation process is run every `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ train_dl_warmup:
+ Dataloader for the training dataset with 2 speakers. This is used during the
+ warmup stage.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ torch.cuda.empty_cache()
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ iter_train = iter(train_dl)
+ iter_train_warmup = iter(train_dl_warmup) if train_dl_warmup is not None else None
+
+ batch_idx = 0
+
+ while True:
+ # We first sample a batch from the main dataset. This is because we want to
+ # make sure all epochs have the same number of batches.
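+        # Note: during the first `model_warm_step` batches, the batch drawn above
+        # is replaced below by one from the warmup dataloader, so the model trains
+        # only on warmup batches at first while the epoch length is still
+        # determined by the main dataloader.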
+ try:
+ batch = next(iter_train)
+ except StopIteration:
+ break
+
+ # If we are in warmup stage, get the batch from the warmup dataset.
+ if (
+ params.batch_idx_train <= params.model_warm_step
+ and iter_train_warmup is not None
+ ):
+ try:
+ batch = next(iter_train_warmup)
+ except StopIteration:
+ iter_train_warmup = iter(train_dl_warmup)
+ batch = next(iter_train_warmup)
+
+ batch_idx += 1
+
+ params.batch_idx_train += 1
+ batch_size = batch["inputs"].shape[0]
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
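+            # `tot_loss` is an exponentially decaying running sum: e.g. with
+            # reset_interval = 200, the previous total is scaled by 0.995
+            # before the current batch's stats are added, so it roughly
+            # reflects the last few hundred batches.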
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: The loss is summed over the utterances in the batch
+            # (i.e. reduction is effectively "sum"); no normalization is applied so far.
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
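+            # For example, a scale that has collapsed to 0.25 would be doubled
+            # every 100 batches until it reaches 1.0, and then doubled every
+            # 400 batches while it stays below 8.0.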
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_surt_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+
+ if checkpoints is None and params.model_init_ckpt is not None:
+ logging.info(
+ f"Initializing model with checkpoint from {params.model_init_ckpt}"
+ )
+ init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
+ model.load_state_dict(init_ckpt["model"], strict=False)
+
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+ optimizer = ScaledAdam(
+ model.parameters(),
+ lr=params.base_lr,
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ diagnostic = diagnostics.attach_diagnostics(model)
+
+ libricss = LibriCssAsrDataModule(args)
+
+ train_cuts = libricss.lsmix_cuts(rvb_affix="comb", type_affix="full", sources=True)
+ train_cuts_ov40 = libricss.lsmix_cuts(
+ rvb_affix="comb", type_affix="ov40", sources=True
+ )
+ dev_cuts = libricss.libricss_cuts(split="dev", type="sdm")
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We load the sampler's state dict only when resuming from a checkpoint
+        # that was saved in the middle of an epoch.
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = libricss.train_dataloaders(
+ train_cuts,
+ sampler_state_dict=sampler_state_dict,
+ )
+ train_dl_ov40 = libricss.train_dataloaders(train_cuts_ov40)
+ valid_dl = libricss.valid_dataloaders(dev_cuts)
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ train_dl_warmup=train_dl_ov40,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = [sp.encode(text_ch) for text_ch in batch["text"]]
+ num_tokens = [sum(len(yi) for yi in y_ch) for y_ch in y]
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def main():
+ parser = get_parser()
+ LibriCssAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/libricss/SURT/dprnn_zipformer/train_adapt.py b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py
new file mode 100755
index 000000000..1c1b0c28c
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/train_adapt.py
@@ -0,0 +1,1343 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
+# Wei Kang,
+# Mingshuang Luo,)
+# Zengwei Yao)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+export CUDA_VISIBLE_DEVICES=0
+
+./dprnn_zipformer/train_adapt.py \
+ --world-size 1 \
+ --num-epochs 15 \
+ --start-epoch 1 \
+ --exp-dir dprnn_zipformer/exp \
+ --max-duration 300
+
+# For mixed precision training:
+
+./dprnn_zipformer/train_adapt.py \
+ --world-size 1 \
+ --num-epochs 30 \
+ --start-epoch 1 \
+ --use-fp16 1 \
+ --exp-dir dprnn_zipformer/exp \
+ --max-duration 550
+"""
+
+import argparse
+import copy
+import logging
+import warnings
+from itertools import chain
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, Optional, Tuple, Union
+
+import k2
+import optim
+import sentencepiece as spm
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+from asr_datamodule import LibriCssAsrDataModule
+from decoder import Decoder
+from dprnn import DPRNN
+from einops.layers.torch import Rearrange
+from joiner import Joiner
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import LOG_EPSILON, fix_random_seed
+from model import SURT
+from optim import Eden, ScaledAdam
+from scaling import ScaledLinear, ScaledLSTM
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from zipformer import Zipformer
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import (
+ save_checkpoint_with_global_batch_idx,
+ update_averaged_model,
+)
+from icefall.dist import cleanup_dist, setup_dist
+from icefall.env import get_env_info
+from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def add_model_arguments(parser: argparse.ArgumentParser):
+ parser.add_argument(
+ "--num-mask-encoder-layers",
+ type=int,
+ default=4,
+ help="Number of layers in the DPRNN based mask encoder.",
+ )
+
+ parser.add_argument(
+ "--mask-encoder-dim",
+ type=int,
+ default=256,
+ help="Hidden dimension of the LSTM blocks in DPRNN.",
+ )
+
+ parser.add_argument(
+ "--mask-encoder-segment-size",
+ type=int,
+ default=32,
+ help="Segment size of the SegLSTM in DPRNN. Ideally, this should be equal to the "
+ "decode-chunk-length of the zipformer encoder.",
+ )
+
+ parser.add_argument(
+ "--chunk-width-randomization",
+ type=bool,
+ default=False,
+ help="Whether to randomize the chunk width in DPRNN.",
+ )
+
+ # Zipformer config is based on:
+ # https://github.com/k2-fsa/icefall/pull/745#issuecomment-1405282740
+ parser.add_argument(
+ "--num-encoder-layers",
+ type=str,
+ default="2,2,2,2,2",
+ help="Number of zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--feedforward-dims",
+ type=str,
+ default="768,768,768,768,768",
+ help="Feedforward dimension of the zipformer encoder layers, comma separated.",
+ )
+
+ parser.add_argument(
+ "--nhead",
+ type=str,
+ default="8,8,8,8,8",
+ help="Number of attention heads in the zipformer encoder layers.",
+ )
+
+ parser.add_argument(
+ "--encoder-dims",
+ type=str,
+ default="256,256,256,256,256",
+ help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated",
+ )
+
+ parser.add_argument(
+ "--attention-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated;
+ not the same as embedding dimension.""",
+ )
+
+ parser.add_argument(
+ "--encoder-unmasked-dims",
+ type=str,
+ default="192,192,192,192,192",
+ help="Unmasked dimensions in the encoders, relates to augmentation during training. "
+ "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance "
+ " worse.",
+ )
+
+ parser.add_argument(
+ "--zipformer-downsampling-factors",
+ type=str,
+ default="1,2,4,8,2",
+ help="Downsampling factor for each stack of encoder layers.",
+ )
+
+ parser.add_argument(
+ "--cnn-module-kernels",
+ type=str,
+ default="31,31,31,31,31",
+ help="Sizes of kernels in convolution modules",
+ )
+
+ parser.add_argument(
+ "--use-joint-encoder-layer",
+ type=str,
+ default="lstm",
+ choices=["linear", "lstm", "none"],
+ help="Whether to use a joint layer to combine all branches.",
+ )
+
+ parser.add_argument(
+ "--decoder-dim",
+ type=int,
+ default=512,
+ help="Embedding dimension in the decoder model.",
+ )
+
+ parser.add_argument(
+ "--joiner-dim",
+ type=int,
+ default=512,
+ help="""Dimension used in the joiner model.
+ Outputs from the encoder and decoder model are projected
+ to this dimension before adding.
+ """,
+ )
+
+ parser.add_argument(
+ "--short-chunk-size",
+ type=int,
+ default=50,
+ help="""Chunk length of dynamic training, the chunk size would be either
+ max sequence length of current batch or uniformly sampled from (1, short_chunk_size).
+ """,
+ )
+
+ parser.add_argument(
+ "--num-left-chunks",
+ type=int,
+ default=4,
+ help="How many left context can be seen in chunks when calculating attention.",
+ )
+
+ parser.add_argument(
+ "--decode-chunk-len",
+ type=int,
+ default=32,
+ help="The chunk size for decoding (in frames before subsampling)",
+ )
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--world-size",
+ type=int,
+ default=1,
+ help="Number of GPUs for DDP training.",
+ )
+
+ parser.add_argument(
+ "--master-port",
+ type=int,
+ default=12354,
+ help="Master port to use for DDP training.",
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=15,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="conv_lstm_transducer_stateless_ctc/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--model-init-ckpt",
+ type=str,
+ default=None,
+ help="""The model checkpoint to initialize the model (either full or part).
+ If not specified, the model is randomly initialized.
+ """,
+ )
+
+ parser.add_argument(
+ "--bpe-model",
+ type=str,
+ default="data/lang_bpe_500/bpe.model",
+ help="Path to the BPE model",
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=0.0004, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=1000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=2,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--context-size",
+ type=int,
+ default=2,
+ help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
+ )
+
+ parser.add_argument(
+ "--prune-range",
+ type=int,
+ default=5,
+ help="The prune range for rnnt loss, it means how many symbols(context)"
+ "we are using to compute the loss",
+ )
+
+ parser.add_argument(
+ "--lm-scale",
+ type=float,
+ default=0.25,
+ help="The scale to smooth the loss with lm "
+ "(output of prediction network) part.",
+ )
+
+ parser.add_argument(
+ "--am-scale",
+ type=float,
+ default=0.0,
+ help="The scale to smooth the loss with am (output of encoder network) part.",
+ )
+
+ parser.add_argument(
+ "--simple-loss-scale",
+ type=float,
+ default=0.5,
+ help="To get pruning ranges, we will calculate a simple version"
+ "loss(joiner is just addition), this simple loss also uses for"
+ "training (as a regularization item). We will scale the simple loss"
+ "with this parameter before adding to the final loss.",
+ )
+
+ parser.add_argument(
+ "--ctc-loss-scale",
+ type=float,
+ default=0.2,
+ help="Scale for CTC loss.",
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--save-every-n",
+ type=int,
+ default=1000,
+ help="""Save checkpoint after processing this number of batches"
+ periodically. We save checkpoint to exp-dir/ whenever
+ params.batch_idx_train % save_every_n == 0. The checkpoint filename
+ has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
+ Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
+ end of each epoch where `xxx` is the epoch number counting from 0.
+ """,
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=5,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=100,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=False,
+ help="Whether to use half precision training.",
+ )
+
+ add_model_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - best_train_loss: Best training loss so far. It is used to select
+ the model that has the lowest training loss. It is
+ updated during the training.
+
+ - best_valid_loss: Best validation loss so far. It is used to select
+ the model that has the lowest validation loss. It is
+ updated during the training.
+
+ - best_train_epoch: It is the epoch that has the best training loss.
+
+ - best_valid_epoch: It is the epoch that has the best validation loss.
+
+    - batch_idx_train: Used for writing statistics to tensorboard. It
+                       contains the number of batches trained so far across
+                       epochs.
+
+    - log_interval: Print training loss if batch_idx % log_interval is 0
+
+ - reset_interval: Reset statistics if batch_idx % reset_interval is 0
+
+ - valid_interval: Run validation if batch_idx % valid_interval is 0
+
+ - feature_dim: The model input dim. It has to match the one used
+ in computing features.
+
+ - subsampling_factor: The subsampling factor for the model.
+
+    - num_decoder_layers: Number of decoder layers of the transformer decoder.
+
+ - warm_step: The warm_step for Noam optimizer.
+ """
+ params = AttributeDict(
+ {
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 10,
+ "reset_interval": 200,
+ "valid_interval": 100,
+ # parameters for SURT
+ "num_channels": 2,
+ "feature_dim": 80,
+ "subsampling_factor": 4, # not passed in, this is fixed
+ # parameters for Noam
+ "model_warm_step": 5000, # arg given to model, not for lrate
+ # parameters for ctc loss
+ "beam_size": 10,
+ "use_double_scores": True,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def get_mask_encoder_model(params: AttributeDict) -> nn.Module:
+ mask_encoder = DPRNN(
+ feature_dim=params.feature_dim,
+ input_size=params.mask_encoder_dim,
+ hidden_size=params.mask_encoder_dim,
+ output_size=params.feature_dim * params.num_channels,
+ segment_size=params.mask_encoder_segment_size,
+ num_blocks=params.num_mask_encoder_layers,
+ chunk_width_randomization=params.chunk_width_randomization,
+ )
+ return mask_encoder
+
+
+def get_encoder_model(params: AttributeDict) -> nn.Module:
+ # TODO: We can add an option to switch between Zipformer and Transformer
+ def to_int_tuple(s: str):
+ return tuple(map(int, s.split(",")))
+
+ encoder = Zipformer(
+ num_features=params.feature_dim,
+ output_downsampling_factor=2,
+ zipformer_downsampling_factors=to_int_tuple(
+ params.zipformer_downsampling_factors
+ ),
+ encoder_dims=to_int_tuple(params.encoder_dims),
+ attention_dim=to_int_tuple(params.attention_dims),
+ encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims),
+ nhead=to_int_tuple(params.nhead),
+ feedforward_dim=to_int_tuple(params.feedforward_dims),
+ cnn_module_kernels=to_int_tuple(params.cnn_module_kernels),
+ num_encoder_layers=to_int_tuple(params.num_encoder_layers),
+ num_left_chunks=params.num_left_chunks,
+ short_chunk_size=params.short_chunk_size,
+ decode_chunk_size=params.decode_chunk_len // 2,
+ )
+ return encoder
+
+
+def get_joint_encoder_layer(params: AttributeDict) -> nn.Module:
+ class TakeFirst(nn.Module):
+ def forward(self, x):
+ return x[0]
+
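+    # The Rearrange pattern "(c b) t d -> b t (c d)" treats the leading axis as
+    # c = num_channels blocks of batch size b and concatenates the per-channel
+    # features along the last dim, e.g. (2B, T, 256) -> (B, T, 512) assuming
+    # the default 2 channels and 256-dim encoder output. The second Rearrange
+    # restores the original (c*b, t, d) layout after the joint layer.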
+ if params.use_joint_encoder_layer == "linear":
+ encoder_dim = int(params.encoder_dims.split(",")[-1])
+ joint_layer = nn.Sequential(
+ Rearrange("(c b) t d -> b t (c d)", c=params.num_channels),
+ nn.Linear(
+ params.num_channels * encoder_dim, params.num_channels * encoder_dim
+ ),
+ nn.ReLU(),
+ Rearrange("b t (c d) -> (c b) t d", c=params.num_channels),
+ )
+ elif params.use_joint_encoder_layer == "lstm":
+ encoder_dim = int(params.encoder_dims.split(",")[-1])
+ joint_layer = nn.Sequential(
+ Rearrange("(c b) t d -> b t (c d)", c=params.num_channels),
+ ScaledLSTM(
+ input_size=params.num_channels * encoder_dim,
+ hidden_size=params.num_channels * encoder_dim,
+ num_layers=1,
+ bias=True,
+ batch_first=True,
+ dropout=0.0,
+ bidirectional=False,
+ ),
+ TakeFirst(),
+ nn.ReLU(),
+ Rearrange("b t (c d) -> (c b) t d", c=params.num_channels),
+ )
+ elif params.use_joint_encoder_layer == "none":
+ joint_layer = None
+ else:
+ raise ValueError(
+ f"Unknown joint encoder layer type: {params.use_joint_encoder_layer}"
+ )
+ return joint_layer
+
+
+def get_decoder_model(params: AttributeDict) -> nn.Module:
+ decoder = Decoder(
+ vocab_size=params.vocab_size,
+ decoder_dim=params.decoder_dim,
+ blank_id=params.blank_id,
+ context_size=params.context_size,
+ )
+ return decoder
+
+
+def get_joiner_model(params: AttributeDict) -> nn.Module:
+ joiner = Joiner(
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return joiner
+
+
+def get_surt_model(
+ params: AttributeDict,
+) -> nn.Module:
+ mask_encoder = get_mask_encoder_model(params)
+ encoder = get_encoder_model(params)
+ joint_layer = get_joint_encoder_layer(params)
+ decoder = get_decoder_model(params)
+ joiner = get_joiner_model(params)
+
+ model = SURT(
+ mask_encoder=mask_encoder,
+ encoder=encoder,
+ joint_encoder_layer=joint_layer,
+ decoder=decoder,
+ joiner=joiner,
+ num_channels=params.num_channels,
+ encoder_dim=int(params.encoder_dims.split(",")[-1]),
+ decoder_dim=params.decoder_dim,
+ joiner_dim=params.joiner_dim,
+ vocab_size=params.vocab_size,
+ )
+ return model
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute RNN-T loss given the model and its inputs.
+
+ Args:
+ params:
+ Parameters for training. See :func:`get_params`.
+ model:
+        The model for training. It is an instance of SURT in our case.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ True for training. False for validation. When it is True, this
+ function enables autograd during computation; when it is False, it
+ disables autograd.
+ """
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"].to(device)
+ feature_lens = batch["input_lens"].to(device)
+
+ # at entry, feature is (N, T, C)
+ assert feature.ndim == 3
+
+    # The dataloader returns text as a list of cuts, each of which is a list of channel
+    # texts. We flatten this to a list where all channels are together, i.e., it looks like
+    # [utt1_ch1, utt2_ch1, ..., uttN_ch1, utt1_ch2, ..., uttN_ch2].
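+    # Hypothetical example with N=2 cuts and 2 channels:
+    #   batch["text"] = [["a b", "c"], ["d", "e f"]]
+    #   zip(*batch["text"]) yields ("a b", "d") and ("c", "e f"),
+    #   so text = ["a b", "d", "c", "e f"].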
+ text = [val for tup in zip(*batch["text"]) for val in tup]
+ assert len(text) == len(feature) * params.num_channels
+
+ # Convert all channel texts to token IDs and create a ragged tensor.
+ y = sp.encode(text, out_type=int)
+ y = k2.RaggedTensor(y).to(device)
+
+ batch_idx_train = params.batch_idx_train
+ warm_step = params.model_warm_step
+
+ with torch.set_grad_enabled(is_training):
+ (simple_loss, pruned_loss, ctc_loss, x_masked) = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ reduction="none",
+ subsampling_factor=params.subsampling_factor,
+ )
+ simple_loss_is_finite = torch.isfinite(simple_loss)
+ pruned_loss_is_finite = torch.isfinite(pruned_loss)
+ ctc_loss_is_finite = torch.isfinite(ctc_loss)
+
+ is_finite = simple_loss_is_finite & pruned_loss_is_finite & ctc_loss_is_finite
+ if not torch.all(is_finite):
+ logging.info(
+ "Not all losses are finite!\n"
+ f"simple_losses: {simple_loss}\n"
+ f"pruned_losses: {pruned_loss}\n"
+ f"ctc_losses: {ctc_loss}\n"
+ )
+ display_and_save_batch(batch, params=params, sp=sp)
+ simple_loss = simple_loss[simple_loss_is_finite]
+ pruned_loss = pruned_loss[pruned_loss_is_finite]
+ ctc_loss = ctc_loss[ctc_loss_is_finite]
+
+            # If every value of any of the losses (simple, pruned, or CTC)
+            # is inf or nan, we stop the training process by raising an exception.
+ if (
+ torch.all(~simple_loss_is_finite)
+ or torch.all(~pruned_loss_is_finite)
+ or torch.all(~ctc_loss_is_finite)
+ ):
+ raise ValueError(
+ "There are too many utterances in this batch "
+ "leading to inf or nan losses."
+ )
+
+ simple_loss_sum = simple_loss.sum()
+ pruned_loss_sum = pruned_loss.sum()
+ ctc_loss_sum = ctc_loss.sum()
+
+ s = params.simple_loss_scale
+        # take down the scale on the simple loss from 1.0 at the start
+        # to params.simple_loss_scale by warm_step.
+ simple_loss_scale = (
+ s
+ if batch_idx_train >= warm_step
+ else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
+ )
+ pruned_loss_scale = (
+ 1.0
+ if batch_idx_train >= warm_step
+ else 0.1 + 0.9 * (batch_idx_train / warm_step)
+ )
+ loss = (
+ simple_loss_scale * simple_loss_sum
+ + pruned_loss_scale * pruned_loss_sum
+ + params.ctc_loss_scale * ctc_loss_sum
+ )
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ # info["frames"] is an approximate number for two reasons:
+        # (1) The actual subsampling factor is ((lens - 1) // 2 - 1) // 2
+ # (2) If some utterances in the batch lead to inf/nan loss, they
+ # are filtered out.
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa
+ info["utterances"] = feature.size(0)
+ # averaged input duration in frames over utterances
+ info["utt_duration"] = feature_lens.sum().item()
+ # averaged padding proportion over utterances
+ info["utt_pad_proportion"] = (
+ ((feature.size(1) - feature_lens) / feature.size(1)).sum().item()
+ )
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss_sum.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss_sum.detach().cpu().item()
+ if params.ctc_loss_scale > 0.0:
+ info["ctc_loss"] = ctc_loss_sum.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ sp: spm.SentencePieceProcessor,
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ sp: spm.SentencePieceProcessor,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+    The mean training loss over all frames is saved in `params.train_loss`.
+    The validation process is run every `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+        The scaler used for mixed precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ torch.cuda.empty_cache()
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ cur_batch_idx = params.get("cur_batch_idx", 0)
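+    # `cur_batch_idx` is present in `params` only when resuming from a checkpoint
+    # that was saved in the middle of an epoch (see
+    # save_checkpoint_with_global_batch_idx below); in that case we skip the
+    # batches that were already processed in this epoch.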
+
+ for batch_idx, batch in enumerate(train_dl):
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
+
+ params.batch_idx_train += 1
+ batch_size = batch["inputs"].shape[0]
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+            # NOTE: The loss is summed over the utterances in the batch
+            # (i.e. reduction is effectively "sum"); no normalization is applied so far.
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params, sp=sp)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+
+ if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale", cur_grad_scale, params.batch_idx_train
+ )
+
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+ if world_size > 1:
+ setup_dist(rank, world_size, params.master_port)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info("Training started")
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ device = torch.device("cpu")
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ logging.info(f"Device: {device}")
+
+ sp = spm.SentencePieceProcessor()
+ sp.load(params.bpe_model)
+
+    # <blk> is defined in local/train_bpe_model.py
+    params.blank_id = sp.piece_to_id("<blk>")
+ params.vocab_size = sp.get_piece_size()
+
+ logging.info(params)
+
+ logging.info("About to create model")
+ model = get_surt_model(params)
+
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ assert params.save_every_n >= params.average_period
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ model.to(device)
+
+ if checkpoints is None and params.model_init_ckpt is not None:
+ logging.info(
+ f"Initializing model with checkpoint from {params.model_init_ckpt}"
+ )
+ init_ckpt = torch.load(params.model_init_ckpt, map_location=device)
+ model.load_state_dict(init_ckpt["model"], strict=True)
+
+ if world_size > 1:
+ logging.info("Using DDP")
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ parameters_names = []
+ parameters_names.append(
+ [name_param_pair[0] for name_param_pair in model.named_parameters()]
+ )
+ optimizer = ScaledAdam(
+ model.parameters(),
+ lr=params.base_lr,
+ clipping_scale=2.0,
+ parameters_names=parameters_names,
+ )
+
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if params.print_diagnostics:
+ diagnostic = diagnostics.attach_diagnostics(model)
+
+ libricss = LibriCssAsrDataModule(args)
+
+ train_cuts_ihm = libricss.libricss_cuts(split="dev", type="ihm-mix")
+ train_cuts_sdm = libricss.libricss_cuts(split="dev", type="sdm")
+ train_cuts = train_cuts_ihm + train_cuts_sdm
+
+ # This will create 2 copies of the sessions with different segmentation
+ train_cuts = train_cuts.trim_to_supervision_groups(
+ max_pause=0.1
+ ) + train_cuts.trim_to_supervision_groups(max_pause=0.5)
+ dev_cuts = libricss.libricss_cuts(split="dev", type="sdm")
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+        # We load the sampler's state dict only when resuming from a checkpoint
+        # that was saved in the middle of an epoch.
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = libricss.train_dataloaders(
+ train_cuts,
+ sampler_state_dict=sampler_state_dict,
+ return_sources=False,
+ strict=False,
+ )
+ valid_dl = libricss.valid_dataloaders(dev_cuts)
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sp=sp,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+ sp: spm.SentencePieceProcessor,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ sp:
+ The BPE model.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+ y = [sp.encode(text_ch) for text_ch in batch["text"]]
+ num_tokens = [sum(len(yi) for yi in y_ch) for y_ch in y]
+ logging.info(f"num tokens: {num_tokens}")
+
+
+def main():
+ parser = get_parser()
+ LibriCssAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = args.world_size
+ assert world_size >= 1
+ if world_size > 1:
+ mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
+ else:
+ run(rank=0, world_size=1, args=args)
+
+
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+if __name__ == "__main__":
+ main()
diff --git a/egs/libricss/SURT/dprnn_zipformer/zipformer.py b/egs/libricss/SURT/dprnn_zipformer/zipformer.py
new file mode 120000
index 000000000..ec183baa7
--- /dev/null
+++ b/egs/libricss/SURT/dprnn_zipformer/zipformer.py
@@ -0,0 +1 @@
+../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py
\ No newline at end of file
diff --git a/egs/libricss/SURT/heat.png b/egs/libricss/SURT/heat.png
new file mode 100644
index 000000000..ac7ecfff4
Binary files /dev/null and b/egs/libricss/SURT/heat.png differ
diff --git a/egs/libricss/SURT/local/add_source_feats.py b/egs/libricss/SURT/local/add_source_feats.py
new file mode 100755
index 000000000..c9775561f
--- /dev/null
+++ b/egs/libricss/SURT/local/add_source_feats.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python3
+# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file adds source features as temporal arrays to the mixture manifests.
+It looks for manifests in the directory data/manifests.
+"""
+import logging
+from pathlib import Path
+
+import numpy as np
+from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy
+from tqdm import tqdm
+
+
+def add_source_feats(num_jobs=1):
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+
+ for type_affix in ["full", "ov40"]:
+ logging.info(f"Adding source features for {type_affix}")
+ mixed_name_clean = f"train_clean_{type_affix}"
+ mixed_name_rvb = f"train_rvb_{type_affix}"
+
+ logging.info("Reading mixed cuts")
+ mixed_cuts_clean = load_manifest_lazy(
+ src_dir / f"cuts_{mixed_name_clean}.jsonl.gz"
+ )
+ mixed_cuts_rvb = load_manifest_lazy(src_dir / f"cuts_{mixed_name_rvb}.jsonl.gz")
+
+ logging.info("Reading source cuts")
+ source_cuts = load_manifest(src_dir / "librispeech_cuts_train_trimmed.jsonl.gz")
+
+ logging.info("Adding source features to the mixed cuts")
+ with tqdm() as pbar, CutSet.open_writer(
+ src_dir / f"cuts_{mixed_name_clean}_sources.jsonl.gz"
+ ) as cut_writer_clean, CutSet.open_writer(
+ src_dir / f"cuts_{mixed_name_rvb}_sources.jsonl.gz"
+ ) as cut_writer_rvb, LilcomChunkyWriter(
+ output_dir / f"feats_train_{type_affix}_sources"
+ ) as source_feat_writer:
+ for cut_clean, cut_rvb in zip(mixed_cuts_clean, mixed_cuts_rvb):
+ assert cut_rvb.id == cut_clean.id + "_rvb"
+ # Create source_feats and source_feat_offsets
+                # (See `lhotse.dataset.K2SurtDataset` for details)
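+                # Hypothetical example: if the sorted source cuts have 120, 80
+                # and 200 frames, source_feat_offsets becomes [0, 120, 200] and
+                # the concatenated source_feats array has 400 frames.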
+ source_feats = []
+ source_feat_offsets = []
+ cur_offset = 0
+ for sup in sorted(
+ cut_clean.supervisions, key=lambda s: (s.start, s.speaker)
+ ):
+ source_cut = source_cuts[sup.id]
+ source_feats.append(source_cut.load_features())
+ source_feat_offsets.append(cur_offset)
+ cur_offset += source_cut.num_frames
+ cut_clean.source_feats = source_feat_writer.store_array(
+ cut_clean.id, np.concatenate(source_feats, axis=0)
+ )
+ cut_clean.source_feat_offsets = source_feat_offsets
+ cut_writer_clean.write(cut_clean)
+ cut_rvb.source_feats = cut_clean.source_feats
+ cut_rvb.source_feat_offsets = cut_clean.source_feat_offsets
+ cut_writer_rvb.write(cut_rvb)
+ pbar.update(1)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ add_source_feats()
diff --git a/egs/libricss/SURT/local/compute_fbank_libricss.py b/egs/libricss/SURT/local/compute_fbank_libricss.py
new file mode 100755
index 000000000..afd66899c
--- /dev/null
+++ b/egs/libricss/SURT/local/compute_fbank_libricss.py
@@ -0,0 +1,105 @@
+#!/usr/bin/env python3
+# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the LibriCSS dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+import logging
+from pathlib import Path
+
+import pyloudnorm as pyln
+import torch
+import torch.multiprocessing
+from lhotse import LilcomChunkyWriter, load_manifest_lazy
+from lhotse.features.kaldifeat import (
+ KaldifeatFbank,
+ KaldifeatFbankConfig,
+ KaldifeatFrameOptions,
+ KaldifeatMelOptions,
+)
+
+# Torch's multithreaded behavior needs to be disabled or
+# it wastes a lot of CPU and slows things down.
+# Do this outside of main() in case it needs to take effect
+# even when we are not invoking the main (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+
+def compute_fbank_libricss():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+
+ sampling_rate = 16000
+ num_mel_bins = 80
+
+ extractor = KaldifeatFbank(
+ KaldifeatFbankConfig(
+ frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
+ mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+ device="cuda",
+ )
+ )
+
+ logging.info("Reading manifests")
+ cuts_ihm_mix = load_manifest_lazy(
+ src_dir / "libricss-ihm-mix_segments_all.jsonl.gz"
+ )
+ cuts_sdm = load_manifest_lazy(src_dir / "libricss-sdm_segments_all.jsonl.gz")
+
+ for name, cuts in [("ihm-mix", cuts_ihm_mix), ("sdm", cuts_sdm)]:
+ dev_cuts = cuts.filter(lambda c: "session0" in c.id)
+ test_cuts = cuts.filter(lambda c: "session0" not in c.id)
+
+ # If SDM cuts, apply loudness normalization
+ if name == "sdm":
+ dev_cuts = dev_cuts.normalize_loudness(target=-23.0)
+ test_cuts = test_cuts.normalize_loudness(target=-23.0)
+
+ logging.info(f"Extracting fbank features for {name} dev cuts")
+ _ = dev_cuts.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=output_dir / f"libricss-{name}_feats_dev",
+ manifest_path=src_dir / f"cuts_dev_libricss-{name}.jsonl.gz",
+ batch_duration=500,
+ num_workers=2,
+ storage_type=LilcomChunkyWriter,
+ overwrite=True,
+ )
+
+ logging.info(f"Extracting fbank features for {name} test cuts")
+ _ = test_cuts.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=output_dir / f"libricss-{name}_feats_test",
+ manifest_path=src_dir / f"cuts_test_libricss-{name}.jsonl.gz",
+ batch_duration=2000,
+ num_workers=4,
+ storage_type=LilcomChunkyWriter,
+ overwrite=True,
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ compute_fbank_libricss()
diff --git a/egs/libricss/SURT/local/compute_fbank_librispeech.py b/egs/libricss/SURT/local/compute_fbank_librispeech.py
new file mode 100755
index 000000000..5c8aece9c
--- /dev/null
+++ b/egs/libricss/SURT/local/compute_fbank_librispeech.py
@@ -0,0 +1,111 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the LibriSpeech dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import logging
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, LilcomChunkyWriter
+from lhotse.features.kaldifeat import (
+ KaldifeatFbank,
+ KaldifeatFbankConfig,
+ KaldifeatFrameOptions,
+ KaldifeatMelOptions,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# Torch's multithreaded behavior needs to be disabled, or it wastes a lot of CPU
+# and slows things down. Do this outside of main() so that it also takes effect
+# when main() is not invoked (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+
+def compute_fbank_librispeech():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+ num_mel_bins = 80
+
+ dataset_parts = (
+ "train-clean-100",
+ "train-clean-360",
+ "train-other-500",
+ )
+ prefix = "librispeech"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+
+ extractor = KaldifeatFbank(
+ KaldifeatFbankConfig(
+ frame_opts=KaldifeatFrameOptions(sampling_rate=16000),
+ mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+ device="cuda",
+ )
+ )
+
+ for partition, m in manifests.items():
+ cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
+ if (output_dir / cuts_filename).is_file():
+ logging.info(f"{partition} already exists - skipping.")
+ continue
+ logging.info(f"Processing {partition}")
+ cut_set = CutSet.from_manifests(
+ recordings=m["recordings"],
+ supervisions=m["supervisions"],
+ )
+
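+        # Apply 3-fold speed perturbation (0.9x, 1.0x, 1.1x), as in the standard
+        # LibriSpeech ASR recipes in icefall.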
+ cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
+
+ cut_set = cut_set.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=f"{output_dir}/{prefix}_feats_{partition}",
+ manifest_path=f"{src_dir}/{cuts_filename}",
+ batch_duration=4000,
+ num_workers=2,
+ storage_type=LilcomChunkyWriter,
+ overwrite=True,
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ compute_fbank_librispeech()
diff --git a/egs/libricss/SURT/local/compute_fbank_lsmix.py b/egs/libricss/SURT/local/compute_fbank_lsmix.py
new file mode 100755
index 000000000..da42f8ba1
--- /dev/null
+++ b/egs/libricss/SURT/local/compute_fbank_lsmix.py
@@ -0,0 +1,188 @@
+#!/usr/bin/env python3
+# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the synthetically mixed LibriSpeech
+train and dev sets.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+import logging
+import random
+import warnings
+from pathlib import Path
+
+import torch
+import torch.multiprocessing
+from lhotse import LilcomChunkyWriter, load_manifest
+from lhotse.cut import MixedCut, MixTrack, MultiCut
+from lhotse.features.kaldifeat import (
+ KaldifeatFbank,
+ KaldifeatFbankConfig,
+ KaldifeatFrameOptions,
+ KaldifeatMelOptions,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+from lhotse.utils import fix_random_seed, uuid4
+
+# Torch's multithreaded behavior needs to be disabled, or it wastes a lot of CPU
+# and slows things down. Do this outside of main() so that it also takes effect
+# when main() is not invoked (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+torch.multiprocessing.set_sharing_strategy("file_system")
+
+
+def compute_fbank_lsmix():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+
+ sampling_rate = 16000
+ num_mel_bins = 80
+
+ extractor = KaldifeatFbank(
+ KaldifeatFbankConfig(
+ frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
+ mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+ device="cuda",
+ )
+ )
+
+ logging.info("Reading manifests")
+ manifests = read_manifests_if_cached(
+ dataset_parts=["train_clean_full", "train_clean_ov40"],
+ types=["cuts"],
+ output_dir=src_dir,
+ prefix="lsmix",
+ suffix="jsonl.gz",
+ lazy=True,
+ )
+
+ cs = {}
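+    # "full" is the main simulated training set, while "ov40" is the high-overlap
+    # subset used for model warmup (see stage 5 of prepare.sh).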
+ cs["clean_full"] = manifests["train_clean_full"]["cuts"]
+ cs["clean_ov40"] = manifests["train_clean_ov40"]["cuts"]
+
+    # We only use the RIRs and noises from the REVERB challenge (RVB2014) subset.
+ real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter(
+ lambda r: "RVB2014" in r.id
+ )
+ noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter(
+ lambda r: "RVB2014" in r.id
+ )
+
+ # Apply perturbation to the training cuts
+ logging.info("Applying perturbation to the training cuts")
+ cs["rvb_full"] = cs["clean_full"].map(
+ lambda c: augment(
+ c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
+ )
+ )
+ cs["rvb_ov40"] = cs["clean_ov40"].map(
+ lambda c: augment(
+ c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
+ )
+ )
+
+ for type_affix in ["full", "ov40"]:
+ for rvb_affix in ["clean", "rvb"]:
+ logging.info(
+ f"Extracting fbank features for {type_affix} {rvb_affix} training cuts"
+ )
+ cuts = cs[f"{rvb_affix}_{type_affix}"]
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ _ = cuts.compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=output_dir
+ / f"lsmix_feats_train_{rvb_affix}_{type_affix}",
+ manifest_path=src_dir
+ / f"cuts_train_{rvb_affix}_{type_affix}.jsonl.gz",
+ batch_duration=5000,
+ num_workers=4,
+ storage_type=LilcomChunkyWriter,
+ overwrite=True,
+ )
+
+
+def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False):
+ """
+ Given a mixed cut, this function optionally applies the following augmentations:
+ - Perturbing the SNRs of the tracks (in range [-5, 5] dB)
+ - Reverberation using a randomly selected RIR
+ - Adding noise
+    - Perturbing the loudness (to a target in the range [-25, -20] dB)
+ """
+ out_cut = cut.drop_features()
+
+ # Perturb the SNRs (optional)
+ if perturb_snr:
+ snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))]
+ for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)):
+ if i == 0:
+ # Skip the first track since it is the reference
+ continue
+ track.snr = snr
+
+ # Reverberate the cut (optional)
+ if rirs is not None:
+ # Select an RIR at random
+ rir = random.choice(rirs)
+ # Select a channel at random
+ rir_channel = random.choice(list(range(rir.num_channels)))
+ # Reverberate the cut
+ out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel])
+
+ # Add noise (optional)
+ if noises is not None:
+ # Select a noise recording at random
+ noise = random.choice(noises).to_cut()
+ if isinstance(noise, MultiCut):
+ noise = noise.to_mono()[0]
+ # Select an SNR at random
+ snr = random.uniform(10, 30)
+ # Repeat the noise to match the duration of the cut
+ noise = repeat_cut(noise, out_cut.duration)
+ out_cut = MixedCut(
+ id=out_cut.id,
+ tracks=[
+ MixTrack(cut=out_cut, type="MixedCut"),
+ MixTrack(cut=noise, type="DataCut", snr=snr),
+ ],
+ )
+
+ # Perturb the loudness (optional)
+ if perturb_loudness:
+ target_loudness = random.uniform(-20, -25)
+ out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True)
+ return out_cut
+
+
+def repeat_cut(cut, duration):
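+    """Extend the noise cut to at least ``duration`` seconds by repeatedly mixing it
+    with a copy of itself offset by its current length, then truncate to ``duration``."""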
+ while cut.duration < duration:
+ cut = cut.mix(cut, offset_other_by=cut.duration)
+ return cut.truncate(duration=duration)
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+ logging.basicConfig(format=formatter, level=logging.INFO)
+
+ fix_random_seed(42)
+ compute_fbank_lsmix()
diff --git a/egs/libricss/SURT/local/compute_fbank_musan.py b/egs/libricss/SURT/local/compute_fbank_musan.py
new file mode 100755
index 000000000..1fcf951f9
--- /dev/null
+++ b/egs/libricss/SURT/local/compute_fbank_musan.py
@@ -0,0 +1,114 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This file computes fbank features of the musan dataset.
+It looks for manifests in the directory data/manifests.
+
+The generated fbank features are saved in data/fbank.
+"""
+
+import logging
+from pathlib import Path
+
+import torch
+from lhotse import CutSet, LilcomChunkyWriter, combine
+from lhotse.features.kaldifeat import (
+ KaldifeatFbank,
+ KaldifeatFbankConfig,
+ KaldifeatFrameOptions,
+ KaldifeatMelOptions,
+)
+from lhotse.recipes.utils import read_manifests_if_cached
+
+# Torch's multithreaded behavior needs to be disabled, or it wastes a lot of CPU
+# and slows things down. Do this outside of main() so that it also takes effect
+# when main() is not invoked (e.g. when spawning subprocesses).
+torch.set_num_threads(1)
+torch.set_num_interop_threads(1)
+
+
+def compute_fbank_musan():
+ src_dir = Path("data/manifests")
+ output_dir = Path("data/fbank")
+
+ sampling_rate = 16000
+ num_mel_bins = 80
+
+ dataset_parts = (
+ "music",
+ "speech",
+ "noise",
+ )
+ prefix = "musan"
+ suffix = "jsonl.gz"
+ manifests = read_manifests_if_cached(
+ dataset_parts=dataset_parts,
+ output_dir=src_dir,
+ prefix=prefix,
+ suffix=suffix,
+ )
+ assert manifests is not None
+
+ assert len(manifests) == len(dataset_parts), (
+ len(manifests),
+ len(dataset_parts),
+ list(manifests.keys()),
+ dataset_parts,
+ )
+
+ musan_cuts_path = src_dir / "musan_cuts.jsonl.gz"
+
+ if musan_cuts_path.is_file():
+ logging.info(f"{musan_cuts_path} already exists - skipping")
+ return
+
+ logging.info("Extracting features for Musan")
+
+ extractor = KaldifeatFbank(
+ KaldifeatFbankConfig(
+ frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
+ mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
+ device="cuda",
+ )
+ )
+
+    # Create chunks of MUSAN by cutting the recordings into 10-second windows and
+    # keeping only chunks longer than 5 seconds.
+ _ = (
+ CutSet.from_manifests(
+ recordings=combine(part["recordings"] for part in manifests.values())
+ )
+ .cut_into_windows(10.0)
+ .filter(lambda c: c.duration > 5)
+ .compute_and_store_features_batch(
+ extractor=extractor,
+ storage_path=output_dir / "musan_feats",
+ manifest_path=musan_cuts_path,
+ batch_duration=500,
+ num_workers=4,
+ storage_type=LilcomChunkyWriter,
+ )
+ )
+
+
+if __name__ == "__main__":
+ formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
+ logging.basicConfig(format=formatter, level=logging.INFO)
+ compute_fbank_musan()
diff --git a/egs/libricss/SURT/prepare.sh b/egs/libricss/SURT/prepare.sh
new file mode 100755
index 000000000..028240e44
--- /dev/null
+++ b/egs/libricss/SURT/prepare.sh
@@ -0,0 +1,204 @@
+#!/usr/bin/env bash
+
+set -eou pipefail
+
+stage=-1
+stop_stage=100
+
+# We assume dl_dir (download dir) contains the following
+# directories and files. If not, they will be downloaded
+# by this script automatically.
+#
+# - $dl_dir/librispeech
+# You can find audio and transcripts for LibriSpeech in this path.
+#
+# - $dl_dir/libricss
+# You can find audio and transcripts for LibriCSS in this path.
+#
+# - $dl_dir/musan
+# This directory contains the following directories downloaded from
+# http://www.openslr.org/17/
+#
+# - music
+# - noise
+# - speech
+#
+# - $dl_dir/rirs_noises
+# This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/.
+#
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+# All files generated by this script are saved in "data".
+# You can safely remove "data" and rerun this script to regenerate it.
+mkdir -p data
+vocab_size=500
+
+log() {
+ # This function is from espnet
+ local fname=${BASH_SOURCE[1]##*/}
+ echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
+}
+
+log "dl_dir: $dl_dir"
+
+if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
+ log "Stage 0: Download data"
+
+ # If you have pre-downloaded it to /path/to/librispeech,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/librispeech $dl_dir/librispeech
+ #
+ if [ ! -d $dl_dir/librispeech ]; then
+ lhotse download librispeech $dl_dir/librispeech
+ fi
+
+ # If you have pre-downloaded it to /path/to/libricss,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/libricss $dl_dir/libricss
+ #
+ if [ ! -d $dl_dir/libricss ]; then
+ lhotse download libricss $dl_dir/libricss
+ fi
+
+ # If you have pre-downloaded it to /path/to/musan,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/musan $dl_dir/
+ #
+ if [ ! -d $dl_dir/musan ]; then
+ lhotse download musan $dl_dir
+ fi
+
+ # If you have pre-downloaded it to /path/to/rirs_noises,
+ # you can create a symlink
+ #
+ # ln -sfv /path/to/rirs_noises $dl_dir/
+ #
+ if [ ! -d $dl_dir/rirs_noises ]; then
+ lhotse download rirs_noises $dl_dir
+ fi
+fi
+
+if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
+ log "Stage 1: Prepare LibriSpeech manifests"
+ # We assume that you have downloaded the LibriSpeech corpus
+ # to $dl_dir/librispeech. We perform text normalization for the transcripts.
+ # NOTE: Alignments are required for this recipe.
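+  # The word alignments are not downloaded in stage 0; obtain them separately (e.g. the
+  # publicly available LibriSpeech forced alignments) and place them under
+  # $dl_dir/libri_alignments/LibriSpeech before running this stage.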
+ mkdir -p data/manifests
+ lhotse prepare librispeech -p train-clean-100 -p train-clean-360 -p train-other-500 -p dev-clean \
+ -j 4 --alignments-dir $dl_dir/libri_alignments/LibriSpeech $dl_dir/librispeech data/manifests/
+fi
+
+if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
+ log "Stage 2: Prepare LibriCSS manifests"
+ # We assume that you have downloaded the LibriCSS corpus
+ # to $dl_dir/libricss. We perform text normalization for the transcripts.
+ mkdir -p data/manifests
+ for mic in sdm ihm-mix; do
+ lhotse prepare libricss --type $mic --segmented $dl_dir/libricss data/manifests/
+ done
+fi
+
+if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
+ log "Stage 3: Prepare musan manifest and RIRs"
+ # We assume that you have downloaded the musan corpus
+ # to $dl_dir/musan
+ mkdir -p data/manifests
+ lhotse prepare musan $dl_dir/musan data/manifests
+
+ # We assume that you have downloaded the RIRS_NOISES corpus
+ # to $dl_dir/rirs_noises
+ lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
+fi
+
+if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
+ log "Stage 4: Extract features for LibriSpeech, trim to alignments, and shuffle the cuts"
+ python local/compute_fbank_librispeech.py
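+  # `lhotse cut trim-to-alignments` splits each cut into utterance-like segments based
+  # on the word alignments, keeping words separated by pauses of at most 0.2s in the
+  # same segment; the resulting short cuts are used for mixture simulation in stage 5.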
+ lhotse combine data/manifests/librispeech_cuts_train* - |\
+ lhotse cut trim-to-alignments --type word --max-pause 0.2 - - |\
+ shuf | gzip -c > data/manifests/librispeech_cuts_train_trimmed.jsonl.gz
+fi
+
+if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
+  log "Stage 5: Create simulated training mixtures from LibriSpeech. This may take a while."
+  # We create a high-overlap set (OV40) that is used during the model warmup phase, and a
+  # full training set that is used for the subsequent training.
+
+ gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\
+ grep -v "0L" | grep -v "OV10" |\
+ gzip -c > data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz
+
+ gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\
+ grep "OV40" |\
+ gzip -c > data/manifests/libricss-sdm_supervisions_ov40.jsonl.gz
+
+ # Warmup mixtures (100k) based on high overlap (OV40)
+ log "Generating 100k anechoic train mixtures for warmup"
+ lhotse workflows simulate-meetings \
+ --method conversational \
+ --fit-to-supervisions data/manifests/libricss-sdm_supervisions_ov40.jsonl.gz \
+ --num-meetings 100000 \
+ --num-speakers-per-meeting 2,3 \
+ --max-duration-per-speaker 15.0 \
+ --max-utterances-per-speaker 3 \
+ --seed 1234 \
+ --num-jobs 4 \
+ data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \
+ data/manifests/lsmix_cuts_train_clean_ov40.jsonl.gz
+
+ # Full training set (2,3 speakers) anechoic
+  log "Generating anechoic train mixtures (full)"
+ lhotse workflows simulate-meetings \
+ --method conversational \
+ --fit-to-supervisions data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz \
+ --num-repeats 1 \
+ --num-speakers-per-meeting 2,3 \
+ --max-duration-per-speaker 15.0 \
+ --max-utterances-per-speaker 3 \
+ --seed 1234 \
+ --num-jobs 4 \
+ data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \
+ data/manifests/lsmix_cuts_train_clean_full.jsonl.gz
+fi
+
+if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
+ log "Stage 6: Compute fbank features for musan"
+ mkdir -p data/fbank
+ python local/compute_fbank_musan.py
+fi
+
+if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
+ log "Stage 7: Compute fbank features for simulated Libri-mix"
+ mkdir -p data/fbank
+ python local/compute_fbank_lsmix.py
+fi
+
+if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
+ log "Stage 8: Add source feats to mixtures (useful for auxiliary tasks)"
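+  # add_source_feats.py attaches the fbank features of the individual clean source
+  # utterances to each mixture cut; these can serve as targets for auxiliary objectives
+  # (e.g. the masking loss) during SURT training.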
+ python local/add_source_feats.py
+
+ log "Combining lsmix-clean and lsmix-rvb"
+ for type in full ov40; do
+ cat <(gunzip -c data/manifests/cuts_train_clean_${type}_sources.jsonl.gz) \
+ <(gunzip -c data/manifests/cuts_train_rvb_${type}_sources.jsonl.gz) |\
+ shuf | gzip -c > data/manifests/cuts_train_comb_${type}_sources.jsonl.gz
+ done
+fi
+
+if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
+ log "Stage 9: Compute fbank features for LibriCSS"
+ mkdir -p data/fbank
+ python local/compute_fbank_libricss.py
+fi
+
+if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
+ log "Stage 10: Download LibriSpeech BPE model from HuggingFace."
+ mkdir -p data/lang_bpe_500
+ pushd data/lang_bpe_500
+ wget https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/resolve/main/data/lang_bpe_500/bpe.model
+ popd
+fi
diff --git a/egs/libricss/SURT/shared b/egs/libricss/SURT/shared
new file mode 120000
index 000000000..4cbd91a7e
--- /dev/null
+++ b/egs/libricss/SURT/shared
@@ -0,0 +1 @@
+../../../icefall/shared
\ No newline at end of file
diff --git a/egs/libricss/SURT/surt.png b/egs/libricss/SURT/surt.png
new file mode 100644
index 000000000..fcc8119d4
Binary files /dev/null and b/egs/libricss/SURT/surt.png differ
diff --git a/icefall/utils.py b/icefall/utils.py
index dfe9a7b42..0feff9dc8 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -429,6 +429,8 @@ def store_transcripts(
texts:
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
+ If it is a multi-talker ASR system, the ref and hyp may also be lists of
+ strings.
Returns:
Return None.
"""
@@ -886,8 +888,167 @@ def write_error_stats_with_timestamps(
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
-    return tot_err_rate, mean_delay, var_delay
+    return float(tot_err_rate), float(mean_delay), float(var_delay)
+
+def write_surt_error_stats(
+ f: TextIO,
+ test_set_name: str,
+    results: List[Tuple[str, List[str], List[str]]],
+ enable_log: bool = True,
+ num_channels: int = 2,
+) -> float:
+    """Write statistics based on hypotheses and reference transcripts for SURT
+    multi-talker ASR systems. Unlike ``write_error_stats``, this function computes
+    the speaker-agnostic WER under the optimal assignment of references to output
+    channels (ORC-WER), using the ``meeteval`` toolkit.
+
+ Args:
+ f: File to write the statistics to.
+ test_set_name: Name of the test set.
+        results: List of tuples containing the cut ID, the list of reference
+            transcripts (one per utterance), and the list of hypothesis
+            transcripts (one per output channel).
+ enable_log: Whether to enable logging.
+ num_channels: Number of output channels/branches. Defaults to 2.
+ Returns:
+      Return the total error rate (ORC-WER, in percent) as a float.
+ """
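+    # A (hypothetical) example of one entry in `results`:
+    #   ("session0_CH0_0L_seg0",                    # cut ID
+    #    ["HELLO HOW ARE YOU", "I AM FINE"],        # reference utterances
+    #    ["HELLO HOW ARE YOU I AM", "FINE"])        # hypotheses, one per output channel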
+ from meeteval.wer import wer
+
+ subs: Dict[Tuple[str, str], int] = defaultdict(int)
+ ins: Dict[str, int] = defaultdict(int)
+ dels: Dict[str, int] = defaultdict(int)
+ ref_lens: List[int] = []
+
+ print(
+ "Search below for sections starting with PER-UTT DETAILS:, "
+ "SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
+ file=f,
+ )
+
+ print("", file=f)
+ print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
+
+ # `words` stores counts per word, as follows:
+ # corr, ref_sub, hyp_sub, ins, dels
+ words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
+ num_corr = 0
+ ERR = "*"
+ for cut_id, ref, hyp in results:
+ # First compute the optimal assignment of references to output channels
+ orc_wer = wer.orc_word_error_rate(ref, hyp)
+ assignment = orc_wer.assignment
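+        # `assignment` holds one output-channel index per reference utterance, i.e. the
+        # optimal reference combination (ORC) found by meeteval.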
+ refs = [[] for _ in range(num_channels)]
+ # Assign references to channels
+ for i, ref_text in zip(assignment, ref):
+ refs[i] += ref_text.split()
+ hyps = [hyp_text.split() for hyp_text in hyp]
+ # Now compute the WER for each channel
+ for ref_c, hyp_c in zip(refs, hyps):
+ ref_lens.append(len(ref_c))
+ ali = kaldialign.align(ref_c, hyp_c, ERR)
+ for ref_word, hyp_word in ali:
+ if ref_word == ERR:
+ ins[hyp_word] += 1
+ words[hyp_word][3] += 1
+ elif hyp_word == ERR:
+ dels[ref_word] += 1
+ words[ref_word][4] += 1
+ elif hyp_word != ref_word:
+ subs[(ref_word, hyp_word)] += 1
+ words[ref_word][1] += 1
+ words[hyp_word][2] += 1
+ else:
+ words[ref_word][0] += 1
+ num_corr += 1
+ combine_successive_errors = True
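+            # Merge runs of adjacent errors into single (ref phrase -> hyp phrase)
+            # pairs so that the PER-UTT output below is easier to read.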
+ if combine_successive_errors:
+ ali = [[[x], [y]] for x, y in ali]
+ for i in range(len(ali) - 1):
+ if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
+ ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
+ ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
+ ali[i] = [[], []]
+ ali = [
+ [
+ list(filter(lambda a: a != ERR, x)),
+ list(filter(lambda a: a != ERR, y)),
+ ]
+ for x, y in ali
+ ]
+ ali = list(filter(lambda x: x != [[], []], ali))
+ ali = [
+ [
+ ERR if x == [] else " ".join(x),
+ ERR if y == [] else " ".join(y),
+ ]
+ for x, y in ali
+ ]
+
+ print(
+ f"{cut_id}:\t"
+ + " ".join(
+ (
+ ref_word
+ if ref_word == hyp_word
+ else f"({ref_word}->{hyp_word})"
+ for ref_word, hyp_word in ali
+ )
+ ),
+ file=f,
+ )
+ ref_len = sum(ref_lens)
+ sub_errs = sum(subs.values())
+ ins_errs = sum(ins.values())
+ del_errs = sum(dels.values())
+ tot_errs = sub_errs + ins_errs + del_errs
+ tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
+
+ if enable_log:
+ logging.info(
+ f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
+ f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
+ f"{del_errs} del, {sub_errs} sub ]"
+ )
+
+ print(f"%WER = {tot_err_rate}", file=f)
+ print(
+ f"Errors: {ins_errs} insertions, {del_errs} deletions, "
+ f"{sub_errs} substitutions, over {ref_len} reference "
+ f"words ({num_corr} correct)",
+ file=f,
+ )
+
+ print("", file=f)
+ print("SUBSTITUTIONS: count ref -> hyp", file=f)
+
+ for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
+ print(f"{count} {ref} -> {hyp}", file=f)
+
+ print("", file=f)
+ print("DELETIONS: count ref", file=f)
+ for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
+ print(f"{count} {ref}", file=f)
+
+ print("", file=f)
+ print("INSERTIONS: count hyp", file=f)
+ for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
+ print(f"{count} {hyp}", file=f)
+
+ print("", file=f)
+ print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
+ for _, word, counts in sorted(
+ [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
+ ):
+ (corr, ref_sub, hyp_sub, ins, dels) = counts
+ tot_errs = ref_sub + hyp_sub + ins + dels
+ ref_count = corr + ref_sub + dels
+ hyp_count = corr + hyp_sub + ins
+
+ print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
+
+ print(f"%WER = {tot_err_rate}", file=f)
+ return float(tot_err_rate)
class MetricsTracker(collections.defaultdict):