mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-18 21:44:18 +00:00
add train + decode scripts
This commit is contained in:
parent
8d70a2aeca
commit
d80ed9377f
156
egs/ami/SURT/README.md
Normal file
156
egs/ami/SURT/README.md
Normal file
@ -0,0 +1,156 @@
|
||||
# Introduction
|
||||
|
||||
This is a multi-talker ASR recipe for the AMI and ICSI datasets. We train a Streaming
|
||||
Unmixing and Recognition Transducer (SURT) model for the task.
|
||||
|
||||
Please refer to the `egs/libricss/SURT` recipe README for details about the task and the
|
||||
model.
|
||||
|
||||
## Description of the recipe
|
||||
|
||||
### Pre-requisites
|
||||
|
||||
The recipes in this directory need the following packages to be installed:
|
||||
|
||||
- [meeteval](https://github.com/fgnt/meeteval)
|
||||
- [einops](https://github.com/arogozhnikov/einops)
|
||||
|
||||
Additionally, we initialize the model with the pre-trained model from the LibriCSS recipe.
|
||||
Please download this checkpoint (see below) or train the LibriCSS recipe first.
|
||||
|
||||
### Training
|
||||
|
||||
To train the model, run the following from within `egs/ami/SURT`:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
|
||||
python dprnn_zipformer/train.py \
|
||||
--use-fp16 True \
|
||||
--exp-dir dprnn_zipformer/exp/surt_base \
|
||||
--world-size 4 \
|
||||
--max-duration 500 \
|
||||
--max-duration-valid 250 \
|
||||
--max-cuts 200 \
|
||||
--num-buckets 50 \
|
||||
--num-epochs 30 \
|
||||
--enable-spec-aug True \
|
||||
--enable-musan False \
|
||||
--ctc-loss-scale 0.2 \
|
||||
--heat-loss-scale 0.2 \
|
||||
--base-lr 0.004 \
|
||||
--model-init-ckpt exp/libricss_base.pt \
|
||||
--chunk-width-randomization True \
|
||||
--num-mask-encoder-layers 4 \
|
||||
--num-encoder-layers 2,2,2,2,2
|
||||
```
|
||||
|
||||
The above is for SURT-base (~26M). For SURT-large (~38M), use:
|
||||
|
||||
```bash
|
||||
--model-init-ckpt exp/libricss_large.pt \
|
||||
--num-mask-encoder-layers 6 \
|
||||
--num-encoder-layers 2,4,3,2,4 \
|
||||
--model-init-ckpt exp/zipformer_large.pt \
|
||||
```
|
||||
|
||||
**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM.
|
||||
|
||||
### Adaptation
|
||||
|
||||
The training step above only trains on simulated mixtures. For best results, we also
|
||||
adapt the final model on the AMI+ICSI train set. For this, run the following from within
|
||||
`egs/ami/SURT`:
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
python dprnn_zipformer/train_adapt.py \
|
||||
--use-fp16 True \
|
||||
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
|
||||
--world-size 4 \
|
||||
--max-duration 500 \
|
||||
--max-duration-valid 250 \
|
||||
--max-cuts 200 \
|
||||
--num-buckets 50 \
|
||||
--num-epochs 8 \
|
||||
--lr-epochs 2 \
|
||||
--enable-spec-aug True \
|
||||
--enable-musan False \
|
||||
--ctc-loss-scale 0.2 \
|
||||
--base-lr 0.0004 \
|
||||
--model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \
|
||||
--chunk-width-randomization True \
|
||||
--num-mask-encoder-layers 4 \
|
||||
--num-encoder-layers 2,2,2,2,2
|
||||
```
|
||||
|
||||
For SURT-large, use the following config:
|
||||
|
||||
```bash
|
||||
--num-mask-encoder-layers 6 \
|
||||
--num-encoder-layers 2,4,3,2,4 \
|
||||
--model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \
|
||||
--num-epochs 15 \
|
||||
--lr-epochs 4 \
|
||||
```
|
||||
|
||||
|
||||
### Decoding
|
||||
|
||||
To decode the model, run the following from within `egs/ami/SURT`:
|
||||
|
||||
#### Greedy search
|
||||
|
||||
```bash
|
||||
export CUDA_VISIBLE_DEVICES="0"
|
||||
|
||||
python dprnn_zipformer/decode.py \
|
||||
--epoch 20 --avg 1 --use-averaged-model False \
|
||||
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
|
||||
--max-duration 250 \
|
||||
--decoding-method greedy_search
|
||||
```
|
||||
|
||||
#### Beam search
|
||||
|
||||
```bash
|
||||
python dprnn_zipformer/decode.py \
|
||||
--epoch 20 --avg 1 --use-averaged-model False \
|
||||
--exp-dir dprnn_zipformer/exp/surt_base_adapt \
|
||||
--max-duration 250 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
```
|
||||
|
||||
## Results (using beam search)
|
||||
|
||||
**AMI**
|
||||
|
||||
| Model | IHM-Mix | SDM | MDM |
|
||||
|------------|:-------:|:----:|:----:|
|
||||
| SURT-base | 39.8 | 65.4 | 46.6 |
|
||||
| + adapt | 37.4 | 46.9 | 43.7 |
|
||||
| SURT-large | 36.8 | 62.5 | 44.4 |
|
||||
| + adapt | **35.1** | **44.6** | **41.4** |
|
||||
|
||||
**ICSI**
|
||||
|
||||
| Model | IHM-Mix | SDM |
|
||||
|------------|:-------:|:----:|
|
||||
| SURT-base | 28.3 | 60.0 |
|
||||
| + adapt | 26.3 | 33.9 |
|
||||
| SURT-large | 27.8 | 59.7 |
|
||||
| + adapt | **24.4** | **32.3** |
|
||||
|
||||
## Pre-trained models and logs
|
||||
|
||||
* LibriCSS pre-trained model (for initialization): [base](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_base) [large](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_large)
|
||||
|
||||
* Pre-trained models: <https://huggingface.co/desh2608/icefall-surt-ami-dprnn-zipformer>
|
||||
|
||||
* Training logs:
|
||||
- surt_base: <https://tensorboard.dev/experiment/8awy98VZSWegLmH4l2JWSA/>
|
||||
- surt_base_adapt: <https://tensorboard.dev/experiment/aGVgXVzYRDKbGUbPekcNjg/>
|
||||
- surt_large: <https://tensorboard.dev/experiment/ZXMkez0VSYKbPLqRk4clOQ/>
|
||||
- surt_large_adapt: <https://tensorboard.dev/experiment/WLKL1e7bTVyEjSonYSNYwg/>
|
399
egs/ami/SURT/dprnn_zipformer/asr_datamodule.py
Normal file
399
egs/ami/SURT/dprnn_zipformer/asr_datamodule.py
Normal file
@ -0,0 +1,399 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SurtDataset,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class AmiAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 SURT experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/manifests"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration-valid",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-cuts",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Maximum number of cuts in a single batch. You can "
|
||||
"reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help=(
|
||||
"When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available."
|
||||
),
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
sources: bool = False,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=sources,
|
||||
strict=False,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
quadratic_duration=30.0,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
validate = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
)
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=False,
|
||||
strict=False,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration_valid,
|
||||
quadratic_duration=30.0,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SurtDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
)
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
return_sources=False,
|
||||
strict=False,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration_valid,
|
||||
max_cuts=self.args.max_cuts,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def aimix_train_cuts(
|
||||
self,
|
||||
rvb_affix: str = "clean",
|
||||
sources: bool = True,
|
||||
) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
source_affix = "_sources" if sources else ""
|
||||
cs = load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_train_{rvb_affix}{source_affix}.jsonl.gz"
|
||||
)
|
||||
cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
|
||||
return cs
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(
|
||||
self,
|
||||
) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "cuts_train_ami_icsi.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def ami_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet:
|
||||
logging.info(f"About to get AMI {split} {type} cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_ami-{type}_{split}.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def icsi_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet:
|
||||
logging.info(f"About to get ICSI {split} {type} cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / f"cuts_icsi-{type}_{split}.jsonl.gz"
|
||||
)
|
1
egs/ami/SURT/dprnn_zipformer/beam_search.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/beam_search.py
|
622
egs/ami/SURT/dprnn_zipformer/decode.py
Executable file
622
egs/ami/SURT/dprnn_zipformer/decode.py
Executable file
@ -0,0 +1,622 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) greedy search
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--avg 1 \
|
||||
--use-averaged-model false \
|
||||
--exp-dir ./dprnn_zipformer/exp_adapt \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search
|
||||
|
||||
(2) beam search (not recommended)
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--avg 1 \
|
||||
--use-averaged-model false \
|
||||
--exp-dir ./dprnn_zipformer/exp_adapt \
|
||||
--max-duration 600 \
|
||||
--decoding-method beam_search \
|
||||
--beam-size 4
|
||||
|
||||
(3) modified beam search
|
||||
./dprnn_zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--avg 1 \
|
||||
--use-averaged-model false \
|
||||
--exp-dir ./dprnn_zipformer/exp_adapt \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import AmiAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from lhotse.utils import EPSILON
|
||||
from train import add_model_arguments, get_params, get_surt_model
|
||||
|
||||
from icefall import LmScorer, NgramLm
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_surt_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=20,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="dprnn_zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/bpe.model",
|
||||
help="Path to the BPE model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
feature_lens = batch["input_lens"].to(device)
|
||||
|
||||
# Apply the mask encoder
|
||||
B, T, F = feature.shape
|
||||
processed = model.mask_encoder(feature) # B,T,F*num_channels
|
||||
masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
|
||||
x_masked = [feature * m for m in masks]
|
||||
|
||||
# Recognition
|
||||
# Stack the inputs along the batch axis
|
||||
h = torch.cat(x_masked, dim=0)
|
||||
h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
|
||||
encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)
|
||||
|
||||
if model.joint_encoder_layer is not None:
|
||||
encoder_out = model.joint_encoder_layer(encoder_out)
|
||||
|
||||
def _group_channels(hyps: List[str]) -> List[List[str]]:
|
||||
"""
|
||||
Currently we have a batch of size M*B, where M is the number of
|
||||
channels and B is the batch size. We need to group the hypotheses
|
||||
into B groups, each of which contains M hypotheses.
|
||||
|
||||
Example:
|
||||
hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
|
||||
_group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
|
||||
"""
|
||||
assert len(hyps) == B * params.num_channels
|
||||
out_hyps = []
|
||||
for i in range(B):
|
||||
out_hyps.append(hyps[i::B])
|
||||
return out_hyps
|
||||
|
||||
hyps = []
|
||||
if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp)
|
||||
else:
|
||||
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp))
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": _group_channels(hyps)}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: _group_channels(hyps)}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": _group_channels(hyps)}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
cut_ids = [cut.id for cut in batch["cuts"]]
|
||||
cuts_batch = batch["cuts"]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
for cut_id, hyp_words in zip(cut_ids, hyps):
|
||||
# Reference is a list of supervision texts sorted by start time.
|
||||
ref_words = [
|
||||
s.text.strip()
|
||||
for s in sorted(
|
||||
cuts_batch[cut_id].supervisions, key=lambda s: s.start
|
||||
)
|
||||
]
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(cut_ids)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_surt_error_stats(
|
||||
f,
|
||||
f"{test_set_name}-{key}",
|
||||
results,
|
||||
enable_log=True,
|
||||
num_channels=params.num_channels,
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = (
|
||||
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LmScorer.add_arguments(parser)
|
||||
AmiAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"modified_beam_search",
|
||||
), f"Decoding method {params.decoding_method} is not supported."
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-use-averaged-model"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.bpe_model)
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_surt_model(params)
|
||||
assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||
model.encoder.decode_chunk_size,
|
||||
params.decode_chunk_len,
|
||||
)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
ami = AmiAsrDataModule(args)
|
||||
|
||||
# NOTE(@desh2608): we filter segments longer than 120s to avoid OOM errors in decoding.
|
||||
# However, 99.9% of the segments are shorter than 120s, so this should not
|
||||
# substantially affect the results. In future, we will implement an overlapped
|
||||
# inference method to avoid OOM errors.
|
||||
|
||||
test_sets = {}
|
||||
for split in ["dev", "test"]:
|
||||
for type in ["ihm-mix", "sdm", "mdm8-bf"]:
|
||||
test_sets[f"ami-{split}_{type}"] = (
|
||||
ami.ami_cuts(split=split, type=type)
|
||||
.trim_to_supervision_groups(max_pause=0.0)
|
||||
.filter(lambda c: 0.1 < c.duration < 120.0)
|
||||
.to_eager()
|
||||
)
|
||||
|
||||
for split in ["dev", "test"]:
|
||||
for type in ["ihm-mix", "sdm"]:
|
||||
test_sets[f"icsi-{split}_{type}"] = (
|
||||
ami.icsi_cuts(split=split, type=type)
|
||||
.trim_to_supervision_groups(max_pause=0.0)
|
||||
.filter(lambda c: 0.1 < c.duration < 120.0)
|
||||
.to_eager()
|
||||
)
|
||||
|
||||
for test_set, test_cuts in test_sets.items():
|
||||
test_dl = ami.test_dataloaders(test_cuts)
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/ami/SURT/dprnn_zipformer/decoder.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/decoder.py
|
1
egs/ami/SURT/dprnn_zipformer/encoder_interface.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/encoder_interface.py
|
1
egs/ami/SURT/dprnn_zipformer/export.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/export.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/export.py
|
1
egs/ami/SURT/dprnn_zipformer/joiner.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/joiner.py
|
1
egs/ami/SURT/dprnn_zipformer/model.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/model.py
|
1
egs/ami/SURT/dprnn_zipformer/optim.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/optim.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/optim.py
|
1
egs/ami/SURT/dprnn_zipformer/scaling.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/scaling.py
|
1
egs/ami/SURT/dprnn_zipformer/scaling_converter.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/scaling_converter.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/scaling_converter.py
|
150
egs/ami/SURT/dprnn_zipformer/test_model.py
Executable file
150
egs/ami/SURT/dprnn_zipformer/test_model.py
Executable file
@ -0,0 +1,150 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
"""
|
||||
To run this file, do:
|
||||
|
||||
cd icefall/egs/librispeech/ASR
|
||||
python ./pruned_transducer_stateless7_streaming/test_model.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from scaling_converter import convert_scaled_to_non_scaled
|
||||
from train import get_params, get_transducer_model
|
||||
|
||||
|
||||
def test_model():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = "2,4,3,2,4"
|
||||
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||
params.nhead = "8,8,8,8,8"
|
||||
params.encoder_dims = "384,384,384,384,384"
|
||||
params.attention_dims = "192,192,192,192,192"
|
||||
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||
params.cnn_module_kernels = "31,31,31,31,31"
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
params.num_left_chunks = 4
|
||||
params.short_chunk_size = 50
|
||||
params.decode_chunk_len = 32
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
# Test jit script
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptabe.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
print("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
|
||||
|
||||
def test_model_jit_trace():
|
||||
params = get_params()
|
||||
params.vocab_size = 500
|
||||
params.blank_id = 0
|
||||
params.context_size = 2
|
||||
params.num_encoder_layers = "2,4,3,2,4"
|
||||
params.feedforward_dims = "1024,1024,2048,2048,1024"
|
||||
params.nhead = "8,8,8,8,8"
|
||||
params.encoder_dims = "384,384,384,384,384"
|
||||
params.attention_dims = "192,192,192,192,192"
|
||||
params.encoder_unmasked_dims = "256,256,256,256,256"
|
||||
params.zipformer_downsampling_factors = "1,2,4,8,2"
|
||||
params.cnn_module_kernels = "31,31,31,31,31"
|
||||
params.decoder_dim = 512
|
||||
params.joiner_dim = 512
|
||||
params.num_left_chunks = 4
|
||||
params.short_chunk_size = 50
|
||||
params.decode_chunk_len = 32
|
||||
model = get_transducer_model(params)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
convert_scaled_to_non_scaled(model, inplace=True)
|
||||
|
||||
# Test encoder
|
||||
def _test_encoder():
|
||||
encoder = model.encoder
|
||||
assert encoder.decode_chunk_size == params.decode_chunk_len // 2, (
|
||||
encoder.decode_chunk_size,
|
||||
params.decode_chunk_len,
|
||||
)
|
||||
T = params.decode_chunk_len + 7
|
||||
|
||||
x = torch.zeros(1, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||
states = encoder.get_init_state(device=x.device)
|
||||
encoder.__class__.forward = encoder.__class__.streaming_forward
|
||||
traced_encoder = torch.jit.trace(encoder, (x, x_lens, states))
|
||||
|
||||
states1 = encoder.get_init_state(device=x.device)
|
||||
states2 = traced_encoder.get_init_state(device=x.device)
|
||||
for i in range(5):
|
||||
x = torch.randn(1, T, 80, dtype=torch.float32)
|
||||
x_lens = torch.full((1,), T, dtype=torch.int32)
|
||||
y1, _, states1 = encoder.streaming_forward(x, x_lens, states1)
|
||||
y2, _, states2 = traced_encoder(x, x_lens, states2)
|
||||
assert torch.allclose(y1, y2, atol=1e-6), (i, (y1 - y2).abs().mean())
|
||||
|
||||
# Test decoder
|
||||
def _test_decoder():
|
||||
decoder = model.decoder
|
||||
y = torch.zeros(10, decoder.context_size, dtype=torch.int64)
|
||||
need_pad = torch.tensor([False])
|
||||
|
||||
traced_decoder = torch.jit.trace(decoder, (y, need_pad))
|
||||
d1 = decoder(y, need_pad)
|
||||
d2 = traced_decoder(y, need_pad)
|
||||
assert torch.equal(d1, d2), (d1 - d2).abs().mean()
|
||||
|
||||
# Test joiner
|
||||
def _test_joiner():
|
||||
joiner = model.joiner
|
||||
encoder_out_dim = joiner.encoder_proj.weight.shape[1]
|
||||
decoder_out_dim = joiner.decoder_proj.weight.shape[1]
|
||||
encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32)
|
||||
decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32)
|
||||
|
||||
traced_joiner = torch.jit.trace(joiner, (encoder_out, decoder_out))
|
||||
j1 = joiner(encoder_out, decoder_out)
|
||||
j2 = traced_joiner(encoder_out, decoder_out)
|
||||
assert torch.equal(j1, j2), (j1 - j2).abs().mean()
|
||||
|
||||
_test_encoder()
|
||||
_test_decoder()
|
||||
_test_joiner()
|
||||
|
||||
|
||||
def main():
|
||||
test_model()
|
||||
test_model_jit_trace()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1420
egs/ami/SURT/dprnn_zipformer/train.py
Executable file
1420
egs/ami/SURT/dprnn_zipformer/train.py
Executable file
File diff suppressed because it is too large
Load Diff
1411
egs/ami/SURT/dprnn_zipformer/train_adapt.py
Executable file
1411
egs/ami/SURT/dprnn_zipformer/train_adapt.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/ami/SURT/dprnn_zipformer/zipformer.py
Symbolic link
1
egs/ami/SURT/dprnn_zipformer/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../../libricss/SURT/dprnn_zipformer/zipformer.py
|
Loading…
x
Reference in New Issue
Block a user