SURT multi-talker ASR recipe (#1126)

* merge upstream

* add SURT model and training

* add libricss decoding

* add chunk width randomization

* decode SURT with libricss

* initial commit for zipformer_ctc

* remove unwanted changes

* remove changes to other recipe

* fix zipformer softlink

* fix for JIT export

* add missing file

* fix symbolic links

* update results

* clean commit for SURT recipe

* training libricss surt model

* remove unwanted files

* remove unwanted changes

* remove changes in librispeech

* change some files to symlinks

* remove unwanted changes in utils

* add export script

* add README

* minor fix in README

* add assets for README

* replace some files with symlinks

* remove unused decoding methods

* fix symlink

* address comments from @csukuangfj
This commit is contained in:
Desh Raj 2023-07-04 13:25:58 +02:00 committed by GitHub
parent 856c0f2a60
commit a4402b88e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 6704 additions and 1 deletions

249
egs/libricss/SURT/README.md Normal file
View File

@ -0,0 +1,249 @@
# Introduction
This is a multi-talker ASR recipe for the LibriCSS dataset. We train a Streaming
Unmixing and Recognition Transducer (SURT) model for the task. In this README,
we will describe the task, the model, and the training process. We will also
provide links to pre-trained models and training logs.
## Task
LibriCSS is a multi-talker meeting corpus formed from mixing together LibriSpeech utterances
and replaying in a real meeting room. It consists of 10 1-hour sessions of audio, each
recorded on a 7-channel microphone. The sessions are recorded at a sampling rate of 16 kHz.
For more information, refer to the paper:
Z. Chen et al., "Continuous speech separation: dataset and analysis,"
ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP),
Barcelona, Spain, 2020
In this recipe, we perform the "continuous, streaming, multi-talker ASR" task on LibriCSS.
* By "continuous", we mean that the model should be able to transcribe unsegmented audio
without the need of an external VAD.
* By "streaming", we mean that the model has limited right context. We use a right-context
of at most 32 frames (320 ms).
* By "multi-talker", we mean that the model should be able to transcribe overlapping speech
from multiple speakers.
For now, we do not care about speaker attribution, i.e., the transcription is speaker
agnostic. The evaluation depends on the particular model type. In this case, we use
the optimal reference combination WER (ORC-WER) metric as implemented in the
[meeteval](https://github.com/fgnt/meeteval) toolkit.
## Model
We use the Streaming Unmixing and Recognition Transducer (SURT) model for this task.
The model is based on the papers:
- Lu, Liang et al. “Streaming End-to-End Multi-Talker Speech Recognition.” IEEE Signal Processing Letters 28 (2020): 803-807.
- Raj, Desh et al. “Continuous Streaming Multi-Talker ASR with Dual-Path Transducers.” ICASSP 2022 - 2022 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) (2021): 7317-7321.
The model is a combination of a speech separation model and a speech recognition model,
but trained end-to-end with a single loss function. The overall architecture is shown
in the figure below. Note that this architecture is slightly different from the one
in the above papers. A detailed description of the model can be found in the following
paper: [SURT 2.0: Advanced in transducer-based multi-talker ASR](https://arxiv.org/abs/2306.10559).
<p align="center">
<img src="surt.png">
Streaming Unmixing and Recognition Transducer
</p>
In the [dprnn_zipformer](./dprnn_zipformer) recipe, for example, we use a DPRNN-based masking network
and a Zipfomer-based recognition network. But other combinations are possible as well.
## Training objective
We train the model using the pruned transducer loss, similar to other ASR recipes in
icefall. However, an important consideration is how to assign references to the output
channels (2 in this case). For this, we use the heuristic error assignment training (HEAT)
strategy, which assigns references to the first available channel based on their start
times. An illustrative example is shown in the figure below:
<p align="center">
<img src="heat.png">
Illustration of HEAT-based reference assignment.
</p>
## Description of the recipe
### Pre-requisites
The recipes in this directory need the following packages to be installed:
- [meeteval](https://github.com/fgnt/meeteval)
- [einops](https://github.com/arogozhnikov/einops)
Additionally, we initialize the "recognition" transducer with a pre-trained model,
trained on LibriSpeech. For this, please run the following from within `egs/librispeech/ASR`:
```bash
./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1,2,3"
python pruned_transducer_stateless7_streaming/train.py \
--use-fp16 True \
--exp-dir pruned_transducer_stateless7_streaming/exp \
--world-size 4 \
--max-duration 800 \
--num-epochs 10 \
--keep-last-k 1 \
--manifest-dir data/manifests \
--enable-musan true \
--master-port 54321 \
--bpe-model data/lang_bpe_500/bpe.model \
--num-encoder-layers 2,2,2,2,2 \
--feedforward-dims 768,768,768,768,768 \
--nhead 8,8,8,8,8 \
--encoder-dims 256,256,256,256,256 \
--attention-dims 192,192,192,192,192 \
--encoder-unmasked-dims 192,192,192,192,192 \
--zipformer-downsampling-factors 1,2,4,8,2 \
--cnn-module-kernels 31,31,31,31,31 \
--decoder-dim 512 \
--joiner-dim 512
```
The above is for SURT-base (~26M). For SURT-large (~38M), use `--num-encoder-layers 2,4,3,2,4`.
Once the above model is trained for 10 epochs, copy it to `egs/libricss/SURT/exp`:
```bash
cp -r pruned_transducer_stateless7_streaming/exp/epoch-10.pt exp/zipformer_base.pt
```
**NOTE:** We also provide this pre-trained checkpoint (see the section below), so you can skip
the above step if you want.
### Training
To train the model, run the following from within `egs/libricss/SURT`:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
python dprnn_zipformer/train.py \
--use-fp16 True \
--exp-dir dprnn_zipformer/exp/surt_base \
--world-size 4 \
--max-duration 500 \
--max-duration-valid 250 \
--max-cuts 200 \
--num-buckets 50 \
--num-epochs 30 \
--enable-spec-aug True \
--enable-musan False \
--ctc-loss-scale 0.2 \
--heat-loss-scale 0.2 \
--base-lr 0.004 \
--model-init-ckpt exp/zipformer_base.pt \
--chunk-width-randomization True \
--num-mask-encoder-layers 4 \
--num-encoder-layers 2,2,2,2,2
```
The above is for SURT-base (~26M). For SURT-large (~38M), use:
```bash
--num-mask-encoder-layers 6 \
--num-encoder-layers 2,4,3,2,4 \
--model-init-ckpt exp/zipformer_large.pt \
```
**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM.
### Adaptation
The training step above only trains on simulated mixtures. For best results, we also
adapt the final model on the LibriCSS dev set. For this, run the following from within
`egs/libricss/SURT`:
```bash
export CUDA_VISIBLE_DEVICES="0"
python dprnn_zipformer/train_adapt.py \
--use-fp16 True \
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
--world-size 1 \
--max-duration 500 \
--max-duration-valid 250 \
--max-cuts 200 \
--num-buckets 50 \
--num-epochs 8 \
--lr-epochs 2 \
--enable-spec-aug True \
--enable-musan False \
--ctc-loss-scale 0.2 \
--base-lr 0.0004 \
--model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \
--chunk-width-randomization True \
--num-mask-encoder-layers 4 \
--num-encoder-layers 2,2,2,2,2
```
For SURT-large, use the following config:
```bash
--num-mask-encoder-layers 6 \
--num-encoder-layers 2,4,3,2,4 \
--model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \
--num-epochs 15 \
--lr-epochs 4 \
```
### Decoding
To decode the model, run the following from within `egs/libricss/SURT`:
#### Greedy search
```bash
export CUDA_VISIBLE_DEVICES="0"
python dprnn_zipformer/decode.py \
--epoch 8 --avg 1 --use-averaged-model False \
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
--max-duration 250 \
--decoding-method greedy_search
```
#### Beam search
```bash
python dprnn_zipformer/decode.py \
--epoch 8 --avg 1 --use-averaged-model False \
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
--max-duration 250 \
--decoding-method modified_beam_search \
--beam-size 4
```
## Results (using beam search)
#### IHM-Mix
| Model | # params | 0L | 0S | OV10 | OV20 | OV30 | OV40 | Avg. |
|------------|:-------:|:----:|:---:|----:|:----:|:----:|:----:|:----:|
| dprnn_zipformer (base) | 26.7 | 5.1 | 4.2 | 13.7 | 18.7 | 20.5 | 20.6 | 13.8 |
| dprnn_zipformer (large) | 37.9 | 4.6 | 3.8 | 12.7 | 14.3 | 16.7 | 21.2 | 12.2 |
#### SDM
| Model | # params | 0L | 0S | OV10 | OV20 | OV30 | OV40 | Avg. |
|------------|:-------:|:----:|:---:|----:|:----:|:----:|:----:|:----:|
| dprnn_zipformer (base) | 26.7 | 6.8 | 7.2 | 21.4 | 24.5 | 28.6 | 31.2 | 20.0 |
| dprnn_zipformer (large) | 37.9 | 6.4 | 6.9 | 17.9 | 19.7 | 25.2 | 25.5 | 16.9 |
## Pre-trained models and logs
* Pre-trained models: <https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer>
* Training logs:
- surt_base: <https://tensorboard.dev/experiment/YLGJTkBETb2aqDQ61jbxvQ/>
- surt_base_adapt: <https://tensorboard.dev/experiment/pjXMFVL9RMej85rMHyd0EQ/>
- surt_large: <https://tensorboard.dev/experiment/82HvYqfrSOKZ4w8Jod2QMw/>
- surt_large_adapt: <https://tensorboard.dev/experiment/5oIdEgRqS9Wb6yVuxaExEw/>

View File

@ -0,0 +1,372 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
# Copyright 2023 Johns Hopkins Univrtsity (Author: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutMix,
DynamicBucketingSampler,
K2SurtDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class _SeedWorkers:
def __init__(self, seed: int):
self.seed = seed
def __call__(self, worker_id: int):
fix_random_seed(self.seed + worker_id)
class LibriCssAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/manifests"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--max-duration-valid",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--max-cuts",
type=int,
default=100,
help="Maximum number of cuts in a single batch. You can "
"reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help=(
"When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available."
),
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=True,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
self,
cuts_train: CutSet,
sampler_state_dict: Optional[Dict[str, Any]] = None,
return_sources: bool = True,
strict: bool = True,
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
if self.args.enable_musan:
logging.info("Enable MUSAN")
logging.info("About to get Musan cuts")
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
transforms.append(
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
)
else:
logging.info("Disable MUSAN")
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SurtDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
return_sources=return_sources,
strict=strict,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
quadratic_duration=30.0,
max_cuts=self.args.max_cuts,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SingleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
max_cuts=self.args.max_cuts,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
# 'seed' is derived from the current random state, which will have
# previously been set in the main process.
seed = torch.randint(0, 100000, ()).item()
worker_init_fn = _SeedWorkers(seed)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
worker_init_fn=worker_init_fn,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
logging.info("About to create dev dataset")
validate = K2SurtDataset(
input_strategy=OnTheFlyFeatures(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
)
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
return_sources=False,
strict=False,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration_valid,
max_cuts=self.args.max_cuts,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = K2SurtDataset(
input_strategy=OnTheFlyFeatures(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
)
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
return_sources=False,
strict=False,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration_valid,
max_cuts=self.args.max_cuts,
shuffle=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def lsmix_cuts(
self,
rvb_affix: str = "clean",
type_affix: str = "full",
sources: bool = True,
) -> CutSet:
logging.info("About to get train cuts")
source_affix = "_sources" if sources else ""
cs = load_manifest_lazy(
self.args.manifest_dir
/ f"cuts_train_{rvb_affix}_{type_affix}{source_affix}.jsonl.gz"
)
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
return cs
@lru_cache()
def libricss_cuts(self, split="dev", type="sdm") -> CutSet:
logging.info(f"About to get LibriCSS {split} {type} cuts")
cs = load_manifest_lazy(
self.args.manifest_dir / f"cuts_{split}_libricss-{type}.jsonl.gz"
)
return cs

View File

@ -0,0 +1,730 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Union
import k2
import torch
from model import SURT
from icefall import NgramLmStateCost
from icefall.utils import DecodingResults
def greedy_search(
model: SURT,
encoder_out: torch.Tensor,
max_sym_per_frame: int,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""Greedy search for a single utterance.
Args:
model:
An instance of `SURT`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
max_sym_per_frame:
Maximum number of symbols per frame. If it is set to 0, the WER
would be 100%.
return_timestamps:
Whether to return timestamps.
Returns:
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
"""
assert encoder_out.ndim == 4
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
unk_id = getattr(model, "unk_id", blank_id)
device = next(model.parameters()).device
decoder_input = torch.tensor(
[-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
encoder_out = model.joiner.encoder_proj(encoder_out)
T = encoder_out.size(1)
t = 0
hyp = [blank_id] * context_size
# timestamp[i] is the frame index after subsampling
# on which hyp[i] is decoded
timestamp = []
# Maximum symbols per utterance.
max_sym_per_utt = 1000
# symbols per frame
sym_per_frame = 0
# symbols per utterance decoded so far
sym_per_utt = 0
while t < T and sym_per_utt < max_sym_per_utt:
if sym_per_frame >= max_sym_per_frame:
sym_per_frame = 0
t += 1
continue
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
if y not in (blank_id, unk_id):
hyp.append(y)
timestamp.append(t)
decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape(
1, context_size
)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
sym_per_utt += 1
sym_per_frame += 1
else:
sym_per_frame = 0
t += 1
hyp = hyp[context_size:] # remove blanks
if not return_timestamps:
return hyp
else:
return DecodingResults(
hyps=[hyp],
timestamps=[timestamp],
)
def greedy_search_batch(
model: SURT,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The SURT model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
return_timestamps:
Whether to return timestamps.
Returns:
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = next(model.parameters()).device
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
# timestamp[n][i] is the frame index after subsampling
# on which hyp[n][i] is decoded
timestamps = [[] for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out: (N, 1, decoder_out_dim)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
# logits'shape (batch_size, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
assert logits.ndim == 2, logits.shape
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v not in (blank_id, unk_id):
hyps[i].append(v)
timestamps[i].append(t)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
hyps=ans,
timestamps=ans_timestamps,
)
def modified_beam_search(
model: SURT,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The SURT model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
beam:
Number of active paths during the beam search.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[],
)
)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
hyps=ans,
timestamps=ans_timestamps,
)
def beam_search(
model: SURT,
encoder_out: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
return_timestamps: bool = False,
) -> Union[List[int], DecodingResults]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
espnet/nets/beam_search_SURT.py#L247 is used as a reference.
Args:
model:
An instance of `SURT`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size,
device=device,
dtype=torch.int64,
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
encoder_out = model.joiner.encoder_proj(encoder_out)
T = encoder_out.size(1)
t = 0
B = HypothesisList()
B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0, timestamp=[]))
max_sym_per_utt = 20000
sym_per_utt = 0
decoder_cache: Dict[str, torch.Tensor] = {}
while t < T and sym_per_utt < max_sym_per_utt:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2)
# fmt: on
A = B
B = HypothesisList()
joint_cache: Dict[str, torch.Tensor] = {}
# TODO(fangjun): Implement prefix search to update the `log_prob`
# of hypotheses in A
while True:
y_star = A.get_most_probable()
A.remove(y_star)
cached_key = y_star.key
if cached_key not in decoder_cache:
decoder_input = torch.tensor(
[y_star.ys[-context_size:]],
device=device,
dtype=torch.int64,
).reshape(1, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
decoder_cache[cached_key] = decoder_out
else:
decoder_out = decoder_cache[cached_key]
cached_key += f"-t-{t}"
if cached_key not in joint_cache:
logits = model.joiner(
current_encoder_out,
decoder_out.unsqueeze(1),
project_input=False,
)
# TODO(fangjun): Scale the blank posterior
log_prob = (logits / temperature).log_softmax(dim=-1)
# log_prob is (1, 1, 1, vocab_size)
log_prob = log_prob.squeeze()
# Now log_prob is (vocab_size,)
joint_cache[cached_key] = log_prob
else:
log_prob = joint_cache[cached_key]
# First, process the blank symbol
skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob
# ys[:] returns a copy of ys
B.add(
Hypothesis(
ys=y_star.ys[:],
log_prob=new_y_star_log_prob,
timestamp=y_star.timestamp[:],
)
)
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
for i, v in zip(indices.tolist(), values.tolist()):
if i in (blank_id, unk_id):
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
new_timestamp = y_star.timestamp + [t]
A.add(
Hypothesis(
ys=new_ys,
log_prob=new_log_prob,
timestamp=new_timestamp,
)
)
# Check whether B contains more than "beam" elements more probable
# than the most probable in A
A_most_probable = A.get_most_probable()
kept_B = B.filter(A_most_probable.log_prob)
if len(kept_B) >= beam:
B = kept_B.topk(beam)
break
t += 1
best_hyp = B.get_most_probable(length_norm=True)
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
if not return_timestamps:
return ys
else:
return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp])
@dataclass
class Hypothesis:
# The predicted tokens so far.
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# The log prob of ys.
# It contains only one entry.
log_prob: torch.Tensor
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
timestamp: List[int] = field(default_factory=list)
# the lm score for next token given the current ys
lm_score: Optional[torch.Tensor] = None
# the RNNLM states (h and c in LSTM)
state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
# N-gram LM state
state_cost: Optional[NgramLmStateCost] = None
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
return "_".join(map(str, self.ys))
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
"""
Args:
data:
A dict of Hypotheses. Its key is its `value.key`.
"""
if data is None:
self._data = {}
else:
self._data = data
@property
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
`log-sum-exp` with the existed one.
Args:
hyp:
The hypothesis to be added.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
else:
self._data[key] = hyp
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `log_prob`.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
Returns:
Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
Note: It must be contained in `self`. Otherwise,
an exception is raised.
"""
key = hyp.key
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
with `log_prob` being greater than the given `threshold`.
"""
ans = HypothesisList()
for _, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
return ans
def topk(self, k: int) -> "HypothesisList":
"""Return the top-k hypothesis."""
hyps = list(self._data.items())
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
ans = HypothesisList(dict(hyps))
return ans
def __contains__(self, key: str):
return key in self._data
def __iter__(self):
return iter(self._data.values())
def __len__(self) -> int:
return len(self._data)
def __str__(self) -> str:
s = []
for key in self:
s.append(key)
return ", ".join(s)
def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
"""Return a ragged shape with axes [utt][num_hyps].
Args:
hyps:
len(hyps) == batch_size. It contains the current hypothesis for
each utterance in the batch.
Returns:
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
the shape is on CPU.
"""
num_hyps = [len(h) for h in hyps]
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
# to get exclusive sum later.
num_hyps.insert(0, 0)
num_hyps = torch.tensor(num_hyps)
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
ans = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
)
return ans

View File

@ -0,0 +1,654 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./dprnn_zipformer/decode.py \
--epoch 30 \
--avg 9 \
--use-averaged-model true \
--exp-dir ./dprnn_zipformer/exp \
--max-duration 600 \
--decoding-method greedy_search
(2) modified beam search
./dprnn_zipformer/decode.py \
--epoch 30 \
--avg 9 \
--use-averaged-model true \
--exp-dir ./dprnn_zipformer/exp \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
"""
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriCssAsrDataModule
from beam_search import (
beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from lhotse.utils import EPSILON
from train import add_model_arguments, get_params, get_surt_model
from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_surt_error_stats,
)
OVERLAP_RATIOS = ["0L", "0S", "OV10", "OV20", "OV30", "OV40"]
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="dprnn_zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe_500",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
""",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame.
Used only when --decoding_method is greedy_search""",
)
parser.add_argument(
"--save-masks",
type=str2bool,
default=False,
help="""If true, save masks generated by unmixing module.""",
)
add_model_arguments(parser)
return parser
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
feature_lens = batch["input_lens"].to(device)
# Apply the mask encoder
B, T, F = feature.shape
processed = model.mask_encoder(feature) # B,T,F*num_channels
masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
x_masked = [feature * m for m in masks]
masks_dict = {}
if params.save_masks:
# To save the masks, we split them by batch and trim each mask to the length of
# the corresponding feature. We save them in a dict, where the key is the
# cut ID and the value is the mask.
for i in range(B):
mask = torch.cat(
[x_masked[j][i, : feature_lens[i]] for j in range(params.num_channels)],
dim=-1,
)
mask = mask.cpu().numpy()
masks_dict[batch["cuts"][i].id] = mask
# Recognition
# Concatenate the inputs along the batch axis
h = torch.cat(x_masked, dim=0)
h_lens = feature_lens.repeat(params.num_channels)
encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
if model.joint_encoder_layer is not None:
encoder_out = model.joint_encoder_layer(encoder_out)
def _group_channels(hyps: List[str]) -> List[List[str]]:
"""
Currently we have a batch of size M*B, where M is the number of
channels and B is the batch size. We need to group the hypotheses
into B groups, each of which contains M hypotheses.
Example:
hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
_group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
"""
assert len(hyps) == B * params.num_channels
out_hyps = []
for i in range(B):
out_hyps.append(hyps[i::B])
return out_hyps
hyps = []
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp)
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp)
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append(sp.decode(hyp))
if params.decoding_method == "greedy_search":
return {"greedy_search": _group_channels(hyps)}, masks_dict
else:
return {f"beam_size_{params.beam_size}": _group_channels(hyps)}, masks_dict
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
if params.decoding_method == "greedy_search":
log_interval = 50
else:
log_interval = 20
results = defaultdict(list)
masks = {}
for batch_idx, batch in enumerate(dl):
cut_ids = [cut.id for cut in batch["cuts"]]
cuts_batch = batch["cuts"]
hyps_dict, masks_dict = decode_one_batch(
params=params,
model=model,
sp=sp,
)
masks.update(masks_dict)
for name, hyps in hyps_dict.items():
this_batch = []
for cut_id, hyp_words in zip(cut_ids, hyps):
# Reference is a list of supervision texts sorted by start time.
ref_words = [
s.text.strip()
for s in sorted(
cuts_batch[cut_id].supervisions, key=lambda s: s.start
)
]
this_batch.append((cut_id, ref_words, hyp_words))
results[name].extend(this_batch)
num_cuts += len(cut_ids)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results, masks_dict
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_surt_error_stats(
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
num_channels=params.num_channels,
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
def save_masks(
params: AttributeDict,
test_set_name: str,
masks: List[torch.Tensor],
):
masks_path = params.res_dir / f"masks-{test_set_name}.txt"
torch.save(masks, masks_path)
logging.info(f"The masks are stored in {masks_path}")
@torch.no_grad()
def main():
parser = get_parser()
LmScorer.add_arguments(parser)
LibriCssAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
params.update(vars(args))
assert params.decoding_method in (
"greedy_search",
"beam_search",
"modified_beam_search",
), f"Decoding method {params.decoding_method} is not supported."
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"Device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> and <unk> are defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_surt_model(params)
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
model.encoder.decode_chunk_size,
params.decode_chunk_len,
)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
libricss = LibriCssAsrDataModule(args)
dev_cuts = libricss.libricss_cuts(split="dev", type="ihm-mix").to_eager()
dev_cuts_grouped = [dev_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS]
test_cuts = libricss.libricss_cuts(split="test", type="ihm-mix").to_eager()
test_cuts_grouped = [
test_cuts.filter(lambda x: ol in x.id) for ol in OVERLAP_RATIOS
]
for dev_set, ol in zip(dev_cuts_grouped, OVERLAP_RATIOS):
dev_dl = libricss.test_dataloaders(dev_set)
results_dict, masks = decode_dataset(
dl=dev_dl,
params=params,
model=model,
sp=sp,
)
save_results(
params=params,
test_set_name=f"dev_{ol}",
results_dict=results_dict,
)
if params.save_masks:
save_masks(
params=params,
test_set_name=f"dev_{ol}",
masks=masks,
)
for test_set, ol in zip(test_cuts_grouped, OVERLAP_RATIOS):
test_dl = libricss.test_dataloaders(test_set)
results_dict, masks = decode_dataset(
dl=test_dl,
params=params,
model=model,
sp=sp,
)
save_results(
params=params,
test_set_name=f"test_{ol}",
results_dict=results_dict,
)
if params.save_masks:
save_masks(
params=params,
test_set_name=f"test_{ol}",
masks=masks,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/decoder.py

View File

@ -0,0 +1,305 @@
import random
from typing import Optional, Tuple
import torch
import torch.nn as nn
from einops import rearrange
from scaling import ActivationBalancer, BasicNorm, DoubleSwish, ScaledLinear, ScaledLSTM
from torch.autograd import Variable
EPS = torch.finfo(torch.get_default_dtype()).eps
def _pad_segment(input, segment_size):
# Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L342
# input is the features: (B, N, T)
batch_size, dim, seq_len = input.shape
segment_stride = segment_size // 2
rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
if rest > 0:
pad = Variable(torch.zeros(batch_size, dim, rest)).type(input.type())
input = torch.cat([input, pad], 2)
pad_aux = Variable(torch.zeros(batch_size, dim, segment_stride)).type(input.type())
input = torch.cat([pad_aux, input, pad_aux], 2)
return input, rest
def split_feature(input, segment_size):
# Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L358
# split the feature into chunks of segment size
# input is the features: (B, N, T)
input, rest = _pad_segment(input, segment_size)
batch_size, dim, seq_len = input.shape
segment_stride = segment_size // 2
segments1 = (
input[:, :, :-segment_stride]
.contiguous()
.view(batch_size, dim, -1, segment_size)
)
segments2 = (
input[:, :, segment_stride:]
.contiguous()
.view(batch_size, dim, -1, segment_size)
)
segments = (
torch.cat([segments1, segments2], 3)
.view(batch_size, dim, -1, segment_size)
.transpose(2, 3)
)
return segments.contiguous(), rest
def merge_feature(input, rest):
# Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py#L385
# merge the splitted features into full utterance
# input is the features: (B, N, L, K)
batch_size, dim, segment_size, _ = input.shape
segment_stride = segment_size // 2
input = (
input.transpose(2, 3).contiguous().view(batch_size, dim, -1, segment_size * 2)
) # B, N, K, L
input1 = (
input[:, :, :, :segment_size]
.contiguous()
.view(batch_size, dim, -1)[:, :, segment_stride:]
)
input2 = (
input[:, :, :, segment_size:]
.contiguous()
.view(batch_size, dim, -1)[:, :, :-segment_stride]
)
output = input1 + input2
if rest > 0:
output = output[:, :, :-rest]
return output.contiguous() # B, N, T
class RNNEncoderLayer(nn.Module):
"""
RNNEncoderLayer is made up of lstm and feedforward networks.
Args:
input_size:
The number of expected features in the input (required).
hidden_size:
The hidden dimension of rnn layer.
dropout:
The dropout value (default=0.1).
layer_dropout:
The dropout value for model-level warmup (default=0.075).
"""
def __init__(
self,
input_size: int,
hidden_size: int,
dropout: float = 0.1,
bidirectional: bool = False,
) -> None:
super(RNNEncoderLayer, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
assert hidden_size >= input_size, (hidden_size, input_size)
self.lstm = ScaledLSTM(
input_size=input_size,
hidden_size=hidden_size // 2 if bidirectional else hidden_size,
proj_size=0,
num_layers=1,
dropout=0.0,
batch_first=True,
bidirectional=bidirectional,
)
self.norm_final = BasicNorm(input_size)
# try to ensure the output is close to zero-mean (or at least, zero-median). # noqa
self.balancer = ActivationBalancer(
num_channels=input_size,
channel_dim=-1,
min_positive=0.45,
max_positive=0.55,
max_abs=6.0,
)
self.dropout = nn.Dropout(dropout)
def forward(
self,
src: torch.Tensor,
states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
warmup: float = 1.0,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Pass the input through the encoder layer.
Args:
src:
The sequence to the encoder layer (required).
Its shape is (S, N, E), where S is the sequence length,
N is the batch size, and E is the feature number.
states:
A tuple of 2 tensors (optional). It is for streaming inference.
states[0] is the hidden states of all layers,
with shape of (1, N, input_size);
states[1] is the cell states of all layers,
with shape of (1, N, hidden_size).
"""
src_orig = src
# alpha = 1.0 means fully use this encoder layer, 0.0 would mean
# completely bypass it.
alpha = warmup if self.training else 1.0
# lstm module
src_lstm, new_states = self.lstm(src, states)
src = self.dropout(src_lstm) + src
src = self.norm_final(self.balancer(src))
if alpha != 1.0:
src = alpha * src + (1 - alpha) * src_orig
return src
# dual-path RNN
class DPRNN(nn.Module):
"""Deep dual-path RNN.
Source: https://github.com/espnet/espnet/blob/master/espnet2/enh/layers/dprnn.py
args:
input_size: int, dimension of the input feature. The input should have shape
(batch, seq_len, input_size).
hidden_size: int, dimension of the hidden state.
output_size: int, dimension of the output size.
dropout: float, dropout ratio. Default is 0.
num_blocks: int, number of stacked RNN layers. Default is 1.
"""
def __init__(
self,
feature_dim,
input_size,
hidden_size,
output_size,
dropout=0.1,
num_blocks=1,
segment_size=50,
chunk_width_randomization=False,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.hidden_size = hidden_size
self.segment_size = segment_size
self.chunk_width_randomization = chunk_width_randomization
self.input_embed = nn.Sequential(
ScaledLinear(feature_dim, input_size),
BasicNorm(input_size),
ActivationBalancer(
num_channels=input_size,
channel_dim=-1,
min_positive=0.45,
max_positive=0.55,
),
)
# dual-path RNN
self.row_rnn = nn.ModuleList([])
self.col_rnn = nn.ModuleList([])
for _ in range(num_blocks):
# intra-RNN is non-causal
self.row_rnn.append(
RNNEncoderLayer(
input_size, hidden_size, dropout=dropout, bidirectional=True
)
)
self.col_rnn.append(
RNNEncoderLayer(
input_size, hidden_size, dropout=dropout, bidirectional=False
)
)
# output layer
self.out_embed = nn.Sequential(
ScaledLinear(input_size, output_size),
BasicNorm(output_size),
ActivationBalancer(
num_channels=output_size,
channel_dim=-1,
min_positive=0.45,
max_positive=0.55,
),
)
def forward(self, input):
# input shape: B, T, F
input = self.input_embed(input)
B, T, D = input.shape
if self.chunk_width_randomization and self.training:
segment_size = random.randint(self.segment_size // 2, self.segment_size)
else:
segment_size = self.segment_size
input, rest = split_feature(input.transpose(1, 2), segment_size)
# input shape: batch, N, dim1, dim2
# apply RNN on dim1 first and then dim2
# output shape: B, output_size, dim1, dim2
# input = input.to(device)
batch_size, _, dim1, dim2 = input.shape
output = input
for i in range(len(self.row_rnn)):
row_input = (
output.permute(0, 3, 2, 1)
.contiguous()
.view(batch_size * dim2, dim1, -1)
) # B*dim2, dim1, N
output = self.row_rnn[i](row_input) # B*dim2, dim1, H
output = (
output.view(batch_size, dim2, dim1, -1).permute(0, 3, 2, 1).contiguous()
) # B, N, dim1, dim2
col_input = (
output.permute(0, 2, 3, 1)
.contiguous()
.view(batch_size * dim1, dim2, -1)
) # B*dim1, dim2, N
output = self.col_rnn[i](col_input) # B*dim1, dim2, H
output = (
output.view(batch_size, dim1, dim2, -1).permute(0, 3, 1, 2).contiguous()
) # B, N, dim1, dim2
output = merge_feature(output, rest)
output = output.transpose(1, 2)
output = self.out_embed(output)
# Apply ReLU to the output
output = torch.relu(output)
return output
if __name__ == "__main__":
model = DPRNN(
80,
256,
256,
160,
dropout=0.1,
num_blocks=4,
segment_size=32,
chunk_width_randomization=True,
)
input = torch.randn(2, 1002, 80)
print(sum(p.numel() for p in model.parameters()))
print(model(input).shape)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/encoder_interface.py

View File

@ -0,0 +1,306 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This script converts several saved checkpoints
# to a single one using model averaging.
"""
Usage:
(1) Export to torchscript model using torch.jit.script()
./dprnn_zipformer/export.py \
--exp-dir ./dprnn_zipformer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9 \
--jit 1
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
load it by `torch.jit.load("cpu_jit.pt")`.
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
Check
https://github.com/k2-fsa/sherpa
for how to use the exported models outside of icefall.
(2) Export `model.state_dict()`
./dprnn_zipformer/export.py \
--exp-dir ./dprnn_zipformer/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 30 \
--avg 9
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
load it by `icefall.checkpoint.load_checkpoint()`.
To use the generated file with `dprnn_zipformer/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/librispeech/ASR
./dprnn_zipformer/decode.py \
--exp-dir ./dprnn_zipformer/exp \
--epoch 9999 \
--avg 1 \
--max-duration 600 \
--decoding-method greedy_search \
--bpe-model data/lang_bpe_500/bpe.model
"""
import argparse
import logging
from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from scaling_converter import convert_scaled_to_non_scaled
from train import add_model_arguments, get_params, get_surt_model
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import str2bool
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=9,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="dprnn_zipformer/exp",
help="""It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--jit",
type=str2bool,
default=False,
help="""True to save a model after applying torch.jit.script.
It will generate a file named cpu_jit.pt
Check ./jit_pretrained.py for how to use it.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
add_model_arguments(parser)
return parser
@torch.no_grad()
def main():
args = get_parser().parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_surt_model(params)
model.to(device)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to("cpu")
model.eval()
if params.jit is True:
convert_scaled_to_non_scaled(model, inplace=True)
# We won't use the forward() method of the model in C++, so just ignore
# it here.
# Otherwise, one of its arguments is a ragged tensor and is not
# torch scriptabe.
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
logging.info("Using torch.jit.script")
model = torch.jit.script(model)
filename = params.exp_dir / "cpu_jit.pt"
model.save(str(filename))
logging.info(f"Saved to {filename}")
else:
logging.info("Not using torchscript. Export model.state_dict()")
# Save it using a format so that it can be loaded
# by :func:`load_checkpoint`
filename = params.exp_dir / "pretrained.pt"
torch.save({"model": model.state_dict()}, str(filename))
logging.info(f"Saved to {filename}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/joiner.py

View File

@ -0,0 +1,316 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang)
# Copyright 2023 Johns Hopkins University (author: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
class SURT(nn.Module):
"""It implements Streaming Unmixing and Recognition Transducer (SURT).
https://arxiv.org/abs/2011.13148
"""
def __init__(
self,
mask_encoder: nn.Module,
encoder: EncoderInterface,
joint_encoder_layer: Optional[nn.Module],
decoder: nn.Module,
joiner: nn.Module,
num_channels: int,
encoder_dim: int,
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
):
"""
Args:
mask_encoder:
It is the masking network. It generates a mask for each channel of the
encoder. These masks are applied to the input features, and then passed
to the transcription network.
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, decoder_dim).
It should contain one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
num_channels:
It is the number of channels that the input features will be split into.
In general, it should be equal to the maximum number of simultaneously
active speakers. For most real scenarios, using 2 channels is sufficient.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
assert hasattr(decoder, "blank_id")
self.mask_encoder = mask_encoder
self.encoder = encoder
self.joint_encoder_layer = joint_encoder_layer
self.decoder = decoder
self.joiner = joiner
self.num_channels = num_channels
self.simple_am_proj = nn.Linear(
encoder_dim,
vocab_size,
)
self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size)
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
nn.Linear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
def forward_helper(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
reduction: str = "sum",
beam_size: int = 10,
use_double_scores: bool = False,
subsampling_factor: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute transducer loss for one branch of the SURT model.
"""
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
if self.joint_encoder_layer is not None:
encoder_out = self.joint_encoder_layer(encoder_out)
# compute ctc log-probs
ctc_output = self.ctc_output(encoder_out)
# For the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
# sos_y_padded: [B, S + 1], start with SOS.
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
# decoder_out: [B, S + 1, decoder_dim]
decoder_out = self.decoder(sos_y_padded)
# Note: y does not start with SOS
# y_padded : [B, S]
y_padded = y.pad(mode="constant", padding_value=0)
y_padded = y_padded.to(torch.int64)
boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
boundary[:, 2] = y_lens
boundary[:, 3] = x_lens
lm = self.simple_lm_proj(decoder_out)
am = self.simple_am_proj(encoder_out)
with torch.cuda.amp.autocast(enabled=False):
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
lm=lm.float(),
am=am.float(),
symbols=y_padded,
termination_symbol=blank_id,
lm_only_scale=lm_scale,
am_only_scale=am_scale,
boundary=boundary,
reduction=reduction,
return_grad=True,
)
# ranges : [B, T, prune_range]
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=prune_range,
)
# am_pruned : [B, T, prune_range, encoder_dim]
# lm_pruned : [B, T, prune_range, decoder_dim]
am_pruned, lm_pruned = k2.do_rnnt_pruning(
am=self.joiner.encoder_proj(encoder_out),
lm=self.joiner.decoder_proj(decoder_out),
ranges=ranges,
)
# logits : [B, T, prune_range, vocab_size]
# project_input=False since we applied the decoder's input projections
# prior to do_rnnt_pruning (this is an optimization for speed).
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
with torch.cuda.amp.autocast(enabled=False):
pruned_loss = k2.rnnt_loss_pruned(
logits=logits.float(),
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
reduction=reduction,
)
# Compute ctc loss
supervision_segments = torch.stack(
(
torch.arange(len(x_lens), device="cpu"),
torch.zeros_like(x_lens, device="cpu"),
torch.clone(x_lens).detach().cpu(),
),
dim=1,
).to(torch.int32)
# We need to sort supervision_segments in decreasing order of num_frames
indices = torch.argsort(supervision_segments[:, 2], descending=True)
supervision_segments = supervision_segments[indices]
# Works with a BPE model
decoding_graph = k2.ctc_graph(y, modified=False, device=x.device)
dense_fsa_vec = k2.DenseFsaVec(
ctc_output,
supervision_segments,
allow_truncate=subsampling_factor - 1,
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=beam_size,
reduction="none",
use_double_scores=use_double_scores,
)
return (simple_loss, pruned_loss, ctc_loss)
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
prune_range: int = 5,
am_scale: float = 0.0,
lm_scale: float = 0.0,
reduction: str = "sum",
beam_size: int = 10,
use_double_scores: bool = False,
subsampling_factor: int = 1,
return_masks: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor of shape (N*num_channels, S). It contains the labels
of the N utterances. The labels are in the range [0, vocab_size). All
the channels are concatenated together one after another.
prune_range:
The prune range for rnnt loss, it means how many symbols(context)
we are considering for each frame to compute the loss.
am_scale:
The scale to smooth the loss with am (output of encoder network)
part
lm_scale:
The scale to smooth the loss with lm (output of predictor network)
part
reduction:
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
beam_size:
The beam size used in CTC decoding.
use_double_scores:
If True, use double precision for CTC decoding.
subsampling_factor:
The subsampling factor of the model. It is used to compute the
supervision segments for CTC loss.
return_masks:
If True, return the masks as well as masked features.
Returns:
Return the transducer loss.
Note:
Regarding am_scale & lm_scale, it will make the loss-function one of
the form:
lm_scale * lm_probs + am_scale * am_probs +
(1-lm_scale-am_scale) * combined_probs
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0), (x.size(), x_lens.size())
# Apply the mask encoder
B, T, F = x.shape
processed = self.mask_encoder(x) # B,T,F*num_channels
masks = processed.view(B, T, F, self.num_channels).unbind(dim=-1)
x_masked = [x * m for m in masks]
# Recognition
# Stack the inputs along the batch axis
h = torch.cat(x_masked, dim=0)
h_lens = torch.cat([x_lens for _ in range(self.num_channels)], dim=0)
simple_loss, pruned_loss, ctc_loss = self.forward_helper(
h,
h_lens,
y,
prune_range,
am_scale,
lm_scale,
reduction=reduction,
beam_size=beam_size,
use_double_scores=use_double_scores,
subsampling_factor=subsampling_factor,
)
# Chunks the outputs into 2 parts along batch axis and then stack them along a new axis.
simple_loss = torch.stack(
torch.chunk(simple_loss, self.num_channels, dim=0), dim=0
)
pruned_loss = torch.stack(
torch.chunk(pruned_loss, self.num_channels, dim=0), dim=0
)
ctc_loss = torch.stack(torch.chunk(ctc_loss, self.num_channels, dim=0), dim=0)
if return_masks:
return (simple_loss, pruned_loss, ctc_loss, x_masked, masks)
else:
return (simple_loss, pruned_loss, ctc_loss, x_masked)

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/optim.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling.py

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py

BIN
egs/libricss/SURT/heat.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 298 KiB

View File

@ -0,0 +1,85 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file adds source features as temporal arrays to the mixture manifests.
It looks for manifests in the directory data/manifests.
"""
import logging
from pathlib import Path
import numpy as np
from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy
from tqdm import tqdm
def add_source_feats(num_jobs=1):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
for type_affix in ["full", "ov40"]:
logging.info(f"Adding source features for {type_affix}")
mixed_name_clean = f"train_clean_{type_affix}"
mixed_name_rvb = f"train_rvb_{type_affix}"
logging.info("Reading mixed cuts")
mixed_cuts_clean = load_manifest_lazy(
src_dir / f"cuts_{mixed_name_clean}.jsonl.gz"
)
mixed_cuts_rvb = load_manifest_lazy(src_dir / f"cuts_{mixed_name_rvb}.jsonl.gz")
logging.info("Reading source cuts")
source_cuts = load_manifest(src_dir / "librispeech_cuts_train_trimmed.jsonl.gz")
logging.info("Adding source features to the mixed cuts")
with tqdm() as pbar, CutSet.open_writer(
src_dir / f"cuts_{mixed_name_clean}_sources.jsonl.gz"
) as cut_writer_clean, CutSet.open_writer(
src_dir / f"cuts_{mixed_name_rvb}_sources.jsonl.gz"
) as cut_writer_rvb, LilcomChunkyWriter(
output_dir / f"feats_train_{type_affix}_sources"
) as source_feat_writer:
for cut_clean, cut_rvb in zip(mixed_cuts_clean, mixed_cuts_rvb):
assert cut_rvb.id == cut_clean.id + "_rvb"
# Create source_feats and source_feat_offsets
# (See `lhotse.datasets.K2SurtDataset` for details)
source_feats = []
source_feat_offsets = []
cur_offset = 0
for sup in sorted(
cut_clean.supervisions, key=lambda s: (s.start, s.speaker)
):
source_cut = source_cuts[sup.id]
source_feats.append(source_cut.load_features())
source_feat_offsets.append(cur_offset)
cur_offset += source_cut.num_frames
cut_clean.source_feats = source_feat_writer.store_array(
cut_clean.id, np.concatenate(source_feats, axis=0)
)
cut_clean.source_feat_offsets = source_feat_offsets
cut_writer_clean.write(cut_clean)
cut_rvb.source_feats = cut_clean.source_feats
cut_rvb.source_feat_offsets = cut_clean.source_feat_offsets
cut_writer_rvb.write(cut_rvb)
pbar.update(1)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
add_source_feats()

View File

@ -0,0 +1,105 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the LibriCSS dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
from pathlib import Path
import pyloudnorm as pyln
import torch
import torch.multiprocessing
from lhotse import LilcomChunkyWriter, load_manifest_lazy
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def compute_fbank_libricss():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
logging.info("Reading manifests")
cuts_ihm_mix = load_manifest_lazy(
src_dir / "libricss-ihm-mix_segments_all.jsonl.gz"
)
cuts_sdm = load_manifest_lazy(src_dir / "libricss-sdm_segments_all.jsonl.gz")
for name, cuts in [("ihm-mix", cuts_ihm_mix), ("sdm", cuts_sdm)]:
dev_cuts = cuts.filter(lambda c: "session0" in c.id)
test_cuts = cuts.filter(lambda c: "session0" not in c.id)
# If SDM cuts, apply loudness normalization
if name == "sdm":
dev_cuts = dev_cuts.normalize_loudness(target=-23.0)
test_cuts = test_cuts.normalize_loudness(target=-23.0)
logging.info(f"Extracting fbank features for {name} dev cuts")
_ = dev_cuts.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"libricss-{name}_feats_dev",
manifest_path=src_dir / f"cuts_dev_libricss-{name}.jsonl.gz",
batch_duration=500,
num_workers=2,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
logging.info(f"Extracting fbank features for {name} test cuts")
_ = test_cuts.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / f"libricss-{name}_feats_test",
manifest_path=src_dir / f"cuts_test_libricss-{name}.jsonl.gz",
batch_duration=2000,
num_workers=4,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_libricss()

View File

@ -0,0 +1,111 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the LibriSpeech dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
from pathlib import Path
import torch
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def compute_fbank_librispeech():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_mel_bins = 80
dataset_parts = (
"train-clean-100",
"train-clean-360",
"train-other-500",
)
prefix = "librispeech"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=16000),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
for partition, m in manifests.items():
cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
if (output_dir / cuts_filename).is_file():
logging.info(f"{partition} already exists - skipping.")
continue
logging.info(f"Processing {partition}")
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
cut_set = cut_set.compute_and_store_features_batch(
extractor=extractor,
storage_path=f"{output_dir}/{prefix}_feats_{partition}",
manifest_path=f"{src_dir}/{cuts_filename}",
batch_duration=4000,
num_workers=2,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_librispeech()

View File

@ -0,0 +1,188 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the synthetically mixed LibriSpeech
train and dev sets.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
import random
import warnings
from pathlib import Path
import torch
import torch.multiprocessing
from lhotse import LilcomChunkyWriter, load_manifest
from lhotse.cut import MixedCut, MixTrack, MultiCut
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
from lhotse.utils import fix_random_seed, uuid4
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
def compute_fbank_lsmix():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
logging.info("Reading manifests")
manifests = read_manifests_if_cached(
dataset_parts=["train_clean_full", "train_clean_ov40"],
types=["cuts"],
output_dir=src_dir,
prefix="lsmix",
suffix="jsonl.gz",
lazy=True,
)
cs = {}
cs["clean_full"] = manifests["train_clean_full"]["cuts"]
cs["clean_ov40"] = manifests["train_clean_ov40"]["cuts"]
# only uses RIRs and noises from REVERB challenge
real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter(
lambda r: "RVB2014" in r.id
)
noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter(
lambda r: "RVB2014" in r.id
)
# Apply perturbation to the training cuts
logging.info("Applying perturbation to the training cuts")
cs["rvb_full"] = cs["clean_full"].map(
lambda c: augment(
c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
)
)
cs["rvb_ov40"] = cs["clean_ov40"].map(
lambda c: augment(
c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
)
)
for type_affix in ["full", "ov40"]:
for rvb_affix in ["clean", "rvb"]:
logging.info(
f"Extracting fbank features for {type_affix} {rvb_affix} training cuts"
)
cuts = cs[f"{rvb_affix}_{type_affix}"]
with warnings.catch_warnings():
warnings.simplefilter("ignore")
_ = cuts.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir
/ f"lsmix_feats_train_{rvb_affix}_{type_affix}",
manifest_path=src_dir
/ f"cuts_train_{rvb_affix}_{type_affix}.jsonl.gz",
batch_duration=5000,
num_workers=4,
storage_type=LilcomChunkyWriter,
overwrite=True,
)
def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False):
"""
Given a mixed cut, this function optionally applies the following augmentations:
- Perturbing the SNRs of the tracks (in range [-5, 5] dB)
- Reverberation using a randomly selected RIR
- Adding noise
- Perturbing the loudness (in range [-20, -25] dB)
"""
out_cut = cut.drop_features()
# Perturb the SNRs (optional)
if perturb_snr:
snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))]
for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)):
if i == 0:
# Skip the first track since it is the reference
continue
track.snr = snr
# Reverberate the cut (optional)
if rirs is not None:
# Select an RIR at random
rir = random.choice(rirs)
# Select a channel at random
rir_channel = random.choice(list(range(rir.num_channels)))
# Reverberate the cut
out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel])
# Add noise (optional)
if noises is not None:
# Select a noise recording at random
noise = random.choice(noises).to_cut()
if isinstance(noise, MultiCut):
noise = noise.to_mono()[0]
# Select an SNR at random
snr = random.uniform(10, 30)
# Repeat the noise to match the duration of the cut
noise = repeat_cut(noise, out_cut.duration)
out_cut = MixedCut(
id=out_cut.id,
tracks=[
MixTrack(cut=out_cut, type="MixedCut"),
MixTrack(cut=noise, type="DataCut", snr=snr),
],
)
# Perturb the loudness (optional)
if perturb_loudness:
target_loudness = random.uniform(-20, -25)
out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True)
return out_cut
def repeat_cut(cut, duration):
while cut.duration < duration:
cut = cut.mix(cut, offset_other_by=cut.duration)
return cut.truncate(duration=duration)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
fix_random_seed(42)
compute_fbank_lsmix()

View File

@ -0,0 +1,114 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file computes fbank features of the musan dataset.
It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import logging
from pathlib import Path
import torch
from lhotse import CutSet, LilcomChunkyWriter, combine
from lhotse.features.kaldifeat import (
KaldifeatFbank,
KaldifeatFbankConfig,
KaldifeatFrameOptions,
KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_musan():
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
sampling_rate = 16000
num_mel_bins = 80
dataset_parts = (
"music",
"speech",
"noise",
)
prefix = "musan"
suffix = "jsonl.gz"
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
prefix=prefix,
suffix=suffix,
)
assert manifests is not None
assert len(manifests) == len(dataset_parts), (
len(manifests),
len(dataset_parts),
list(manifests.keys()),
dataset_parts,
)
musan_cuts_path = src_dir / "musan_cuts.jsonl.gz"
if musan_cuts_path.is_file():
logging.info(f"{musan_cuts_path} already exists - skipping")
return
logging.info("Extracting features for Musan")
extractor = KaldifeatFbank(
KaldifeatFbankConfig(
frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
device="cuda",
)
)
# create chunks of Musan with duration 5 - 10 seconds
_ = (
CutSet.from_manifests(
recordings=combine(part["recordings"] for part in manifests.values())
)
.cut_into_windows(10.0)
.filter(lambda c: c.duration > 5)
.compute_and_store_features_batch(
extractor=extractor,
storage_path=output_dir / "musan_feats",
manifest_path=musan_cuts_path,
batch_duration=500,
num_workers=4,
storage_type=LilcomChunkyWriter,
)
)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
compute_fbank_musan()

204
egs/libricss/SURT/prepare.sh Executable file
View File

@ -0,0 +1,204 @@
#!/usr/bin/env bash
set -eou pipefail
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/librispeech
# You can find audio and transcripts for LibriSpeech in this path.
#
# - $dl_dir/libricss
# You can find audio and transcripts for LibriCSS in this path.
#
# - $dl_dir/musan
# This directory contains the following directories downloaded from
# http://www.openslr.org/17/
#
# - music
# - noise
# - speech
#
# - $dl_dir/rirs_noises
# This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/.
#
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
vocab_size=500
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/librispeech,
# you can create a symlink
#
# ln -sfv /path/to/librispeech $dl_dir/librispeech
#
if [ ! -d $dl_dir/librispeech ]; then
lhotse download librispeech $dl_dir/librispeech
fi
# If you have pre-downloaded it to /path/to/libricss,
# you can create a symlink
#
# ln -sfv /path/to/libricss $dl_dir/libricss
#
if [ ! -d $dl_dir/libricss ]; then
lhotse download libricss $dl_dir/libricss
fi
# If you have pre-downloaded it to /path/to/musan,
# you can create a symlink
#
# ln -sfv /path/to/musan $dl_dir/
#
if [ ! -d $dl_dir/musan ]; then
lhotse download musan $dl_dir
fi
# If you have pre-downloaded it to /path/to/rirs_noises,
# you can create a symlink
#
# ln -sfv /path/to/rirs_noises $dl_dir/
#
if [ ! -d $dl_dir/rirs_noises ]; then
lhotse download rirs_noises $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare LibriSpeech manifests"
# We assume that you have downloaded the LibriSpeech corpus
# to $dl_dir/librispeech. We perform text normalization for the transcripts.
# NOTE: Alignments are required for this recipe.
mkdir -p data/manifests
lhotse prepare librispeech -p train-clean-100 -p train-clean-360 -p train-other-500 -p dev-clean \
-j 4 --alignments-dir $dl_dir/libri_alignments/LibriSpeech $dl_dir/librispeech data/manifests/
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Prepare LibriCSS manifests"
# We assume that you have downloaded the LibriCSS corpus
# to $dl_dir/libricss. We perform text normalization for the transcripts.
mkdir -p data/manifests
for mic in sdm ihm-mix; do
lhotse prepare libricss --type $mic --segmented $dl_dir/libricss data/manifests/
done
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare musan manifest and RIRs"
# We assume that you have downloaded the musan corpus
# to $dl_dir/musan
mkdir -p data/manifests
lhotse prepare musan $dl_dir/musan data/manifests
# We assume that you have downloaded the RIRS_NOISES corpus
# to $dl_dir/rirs_noises
lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Extract features for LibriSpeech, trim to alignments, and shuffle the cuts"
python local/compute_fbank_librispeech.py
lhotse combine data/manifests/librispeech_cuts_train* - |\
lhotse cut trim-to-alignments --type word --max-pause 0.2 - - |\
shuf | gzip -c > data/manifests/librispeech_cuts_train_trimmed.jsonl.gz
fi
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
log "Stage 5: Create simulated mixtures from LibriSpeech (train and dev). This may take a while."
# We create a high overlap set which will be used during the model warmup phase, and a
# full training set that will be used for the subsequent training.
gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\
grep -v "0L" | grep -v "OV10" |\
gzip -c > data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz
gunzip -c data/manifests/libricss-sdm_supervisions_all.jsonl.gz |\
grep "OV40" |\
gzip -c > data/manifests/libricss-sdm_supervisions_ov40.jsonl.gz
# Warmup mixtures (100k) based on high overlap (OV40)
log "Generating 100k anechoic train mixtures for warmup"
lhotse workflows simulate-meetings \
--method conversational \
--fit-to-supervisions data/manifests/libricss-sdm_supervisions_ov40.jsonl.gz \
--num-meetings 100000 \
--num-speakers-per-meeting 2,3 \
--max-duration-per-speaker 15.0 \
--max-utterances-per-speaker 3 \
--seed 1234 \
--num-jobs 4 \
data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \
data/manifests/lsmix_cuts_train_clean_ov40.jsonl.gz
# Full training set (2,3 speakers) anechoic
log "Generating anechoic ${part} set (full)"
lhotse workflows simulate-meetings \
--method conversational \
--fit-to-supervisions data/manifests/libricss-sdm_supervisions_all_v1.jsonl.gz \
--num-repeats 1 \
--num-speakers-per-meeting 2,3 \
--max-duration-per-speaker 15.0 \
--max-utterances-per-speaker 3 \
--seed 1234 \
--num-jobs 4 \
data/manifests/librispeech_cuts_train_trimmed.jsonl.gz \
data/manifests/lsmix_cuts_train_clean_full.jsonl.gz
fi
if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
log "Stage 6: Compute fbank features for musan"
mkdir -p data/fbank
python local/compute_fbank_musan.py
fi
if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
log "Stage 7: Compute fbank features for simulated Libri-mix"
mkdir -p data/fbank
python local/compute_fbank_lsmix.py
fi
if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
log "Stage 8: Add source feats to mixtures (useful for auxiliary tasks)"
python local/add_source_feats.py
log "Combining lsmix-clean and lsmix-rvb"
for type in full ov40; do
cat <(gunzip -c data/manifests/cuts_train_clean_${type}_sources.jsonl.gz) \
<(gunzip -c data/manifests/cuts_train_rvb_${type}_sources.jsonl.gz) |\
shuf | gzip -c > data/manifests/cuts_train_comb_${type}_sources.jsonl.gz
done
fi
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
log "Stage 9: Compute fbank features for LibriCSS"
mkdir -p data/fbank
python local/compute_fbank_libricss.py
fi
if [ $stage -le 10 ] && [ $stop_stage -ge 10 ]; then
log "Stage 10: Download LibriSpeech BPE model from HuggingFace."
mkdir -p data/lang_bpe_500
pushd data/lang_bpe_500
wget https://huggingface.co/Zengwei/icefall-asr-librispeech-pruned-transducer-stateless7-streaming-2022-12-29/resolve/main/data/lang_bpe_500/bpe.model
popd
fi

1
egs/libricss/SURT/shared Symbolic link
View File

@ -0,0 +1 @@
../../../icefall/shared

BIN
egs/libricss/SURT/surt.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 112 KiB

View File

@ -429,6 +429,8 @@ def store_transcripts(
texts:
An iterable of tuples. The first element is the cur_id, the second is
the reference transcript and the third element is the predicted result.
If it is a multi-talker ASR system, the ref and hyp may also be lists of
strings.
Returns:
Return None.
"""
@ -886,8 +888,167 @@ def write_error_stats_with_timestamps(
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate), float(mean_delay), float(var_delay)
return tot_err_rate, mean_delay, var_delay
def write_surt_error_stats(
f: TextIO,
test_set_name: str,
results: List[Tuple[str, str]],
enable_log: bool = True,
num_channels: int = 2,
) -> float:
"""Write statistics based on predicted results and reference transcripts for SURT
multi-talker ASR systems. The difference between this and the `write_error_stats`
is that this function finds the optimal speaker-agnostic WER using the ``meeteval``
toolkit.
Args:
f: File to write the statistics to.
test_set_name: Name of the test set.
results: List of tuples containing the utterance ID and the predicted
transcript.
enable_log: Whether to enable logging.
num_channels: Number of output channels/branches. Defaults to 2.
Returns:
Return None.
"""
from meeteval.wer import wer
subs: Dict[Tuple[str, str], int] = defaultdict(int)
ins: Dict[str, int] = defaultdict(int)
dels: Dict[str, int] = defaultdict(int)
ref_lens: List[int] = []
print(
"Search below for sections starting with PER-UTT DETAILS:, "
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
file=f,
)
print("", file=f)
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
# `words` stores counts per word, as follows:
# corr, ref_sub, hyp_sub, ins, dels
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
num_corr = 0
ERR = "*"
for cut_id, ref, hyp in results:
# First compute the optimal assignment of references to output channels
orc_wer = wer.orc_word_error_rate(ref, hyp)
assignment = orc_wer.assignment
refs = [[] for _ in range(num_channels)]
# Assign references to channels
for i, ref_text in zip(assignment, ref):
refs[i] += ref_text.split()
hyps = [hyp_text.split() for hyp_text in hyp]
# Now compute the WER for each channel
for ref_c, hyp_c in zip(refs, hyps):
ref_lens.append(len(ref_c))
ali = kaldialign.align(ref_c, hyp_c, ERR)
for ref_word, hyp_word in ali:
if ref_word == ERR:
ins[hyp_word] += 1
words[hyp_word][3] += 1
elif hyp_word == ERR:
dels[ref_word] += 1
words[ref_word][4] += 1
elif hyp_word != ref_word:
subs[(ref_word, hyp_word)] += 1
words[ref_word][1] += 1
words[hyp_word][2] += 1
else:
words[ref_word][0] += 1
num_corr += 1
combine_successive_errors = True
if combine_successive_errors:
ali = [[[x], [y]] for x, y in ali]
for i in range(len(ali) - 1):
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
ali[i] = [[], []]
ali = [
[
list(filter(lambda a: a != ERR, x)),
list(filter(lambda a: a != ERR, y)),
]
for x, y in ali
]
ali = list(filter(lambda x: x != [[], []], ali))
ali = [
[
ERR if x == [] else " ".join(x),
ERR if y == [] else " ".join(y),
]
for x, y in ali
]
print(
f"{cut_id}:\t"
+ " ".join(
(
ref_word
if ref_word == hyp_word
else f"({ref_word}->{hyp_word})"
for ref_word, hyp_word in ali
)
),
file=f,
)
ref_len = sum(ref_lens)
sub_errs = sum(subs.values())
ins_errs = sum(ins.values())
del_errs = sum(dels.values())
tot_errs = sub_errs + ins_errs + del_errs
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
if enable_log:
logging.info(
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
f"{del_errs} del, {sub_errs} sub ]"
)
print(f"%WER = {tot_err_rate}", file=f)
print(
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
f"{sub_errs} substitutions, over {ref_len} reference "
f"words ({num_corr} correct)",
file=f,
)
print("", file=f)
print("SUBSTITUTIONS: count ref -> hyp", file=f)
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
print(f"{count} {ref} -> {hyp}", file=f)
print("", file=f)
print("DELETIONS: count ref", file=f)
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
print(f"{count} {ref}", file=f)
print("", file=f)
print("INSERTIONS: count hyp", file=f)
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
print(f"{count} {hyp}", file=f)
print("", file=f)
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
for _, word, counts in sorted(
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
):
(corr, ref_sub, hyp_sub, ins, dels) = counts
tot_errs = ref_sub + hyp_sub + ins + dels
ref_count = corr + ref_sub + dels
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
print(f"%WER = {tot_err_rate}", file=f)
return float(tot_err_rate)
class MetricsTracker(collections.defaultdict):