mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
SURT recipe for AMI and ICSI (#1133)
* merge upstream
* add SURT model and training
* add libricss decoding
* add chunk width randomization
* decode SURT with libricss
* initial commit for zipformer_ctc
* remove unwanted changes
* remove changes to other recipe
* fix zipformer softlink
* fix for JIT export
* add missing file
* fix symbolic links
* update results
* clean commit for SURT recipe
* training libricss surt model
* remove unwanted files
* remove unwanted changes
* remove changes in librispeech
* change some files to symlinks
* remove unwanted changes in utils
* add export script
* add README
* minor fix in README
* add assets for README
* replace some files with symlinks
* remove unused decoding methods
* initial commit for SURT AMI recipe
* fix symlink
* add train + decode scripts
* add missing symlink
* change files to symlink
* change file type
This commit is contained in:
parent ffe816e2a8
commit 41b16d7838
156 egs/ami/SURT/README.md Normal file
@@ -0,0 +1,156 @@
# Introduction

This is a multi-talker ASR recipe for the AMI and ICSI datasets. We train a Streaming
Unmixing and Recognition Transducer (SURT) model for the task.

Please refer to the `egs/libricss/SURT` recipe README for details about the task and the
model.

## Description of the recipe

### Pre-requisites

The recipes in this directory need the following packages to be installed:

- [meeteval](https://github.com/fgnt/meeteval)
- [einops](https://github.com/arogozhnikov/einops)

Additionally, we initialize the model with the pre-trained model from the LibriCSS recipe.
Please download this checkpoint (see below) or train the LibriCSS recipe first.
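
A minimal setup sketch follows, assuming a working icefall/Lhotse environment. The checkpoint filename inside the Hugging Face repository is an assumption; adjust the `cp` source path to whatever the `exp/surt_base` directory actually contains (the destination name `exp/libricss_base.pt` matches the `--model-init-ckpt` flag used below).

```bash
# Install the extra dependencies required by this recipe.
pip install meeteval einops

# Fetch the LibriCSS pre-trained checkpoint used for initialization.
# NOTE: the source filename below is an assumption; check the repository
# contents and adjust. The destination matches --model-init-ckpt below.
git lfs install
git clone https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer
mkdir -p exp
cp icefall-surt-libricss-dprnn-zipformer/exp/surt_base/epoch-30.pt exp/libricss_base.pt
```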

### Training

To train the model, run the following from within `egs/ami/SURT`:

```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"

python dprnn_zipformer/train.py \
  --use-fp16 True \
  --exp-dir dprnn_zipformer/exp/surt_base \
  --world-size 4 \
  --max-duration 500 \
  --max-duration-valid 250 \
  --max-cuts 200 \
  --num-buckets 50 \
  --num-epochs 30 \
  --enable-spec-aug True \
  --enable-musan False \
  --ctc-loss-scale 0.2 \
  --heat-loss-scale 0.2 \
  --base-lr 0.004 \
  --model-init-ckpt exp/libricss_base.pt \
  --chunk-width-randomization True \
  --num-mask-encoder-layers 4 \
  --num-encoder-layers 2,2,2,2,2
```

The above is for SURT-base (~26M). For SURT-large (~38M), use:

```bash
  --model-init-ckpt exp/libricss_large.pt \
  --num-mask-encoder-layers 6 \
  --num-encoder-layers 2,4,3,2,4 \
```

**NOTE:** You may need to decrease the `--max-duration` for SURT-large to avoid OOM.

### Adaptation

The training step above only trains on simulated mixtures. For best results, we also
adapt the final model on the AMI+ICSI train set. For this, run the following from within
`egs/ami/SURT`:

```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"

python dprnn_zipformer/train_adapt.py \
  --use-fp16 True \
  --exp-dir dprnn_zipformer/exp/surt_base_adapt \
  --world-size 4 \
  --max-duration 500 \
  --max-duration-valid 250 \
  --max-cuts 200 \
  --num-buckets 50 \
  --num-epochs 8 \
  --lr-epochs 2 \
  --enable-spec-aug True \
  --enable-musan False \
  --ctc-loss-scale 0.2 \
  --base-lr 0.0004 \
  --model-init-ckpt dprnn_zipformer/exp/surt_base/epoch-30.pt \
  --chunk-width-randomization True \
  --num-mask-encoder-layers 4 \
  --num-encoder-layers 2,2,2,2,2
```

For SURT-large, use the following config:

```bash
  --num-mask-encoder-layers 6 \
  --num-encoder-layers 2,4,3,2,4 \
  --model-init-ckpt dprnn_zipformer/exp/surt_large/epoch-30.pt \
  --num-epochs 15 \
  --lr-epochs 4 \
```

### Decoding

To decode the model, run the following from within `egs/ami/SURT`:

#### Greedy search

```bash
export CUDA_VISIBLE_DEVICES="0"

python dprnn_zipformer/decode.py \
  --epoch 20 --avg 1 --use-averaged-model False \
  --exp-dir dprnn_zipformer/exp/surt_base_adapt \
  --max-duration 250 \
  --decoding-method greedy_search
```

#### Beam search

```bash
python dprnn_zipformer/decode.py \
  --epoch 20 --avg 1 --use-averaged-model False \
  --exp-dir dprnn_zipformer/exp/surt_base_adapt \
  --max-duration 250 \
  --decoding-method modified_beam_search \
  --beam-size 4
```
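
Once you are happy with a checkpoint, it can be exported with `dprnn_zipformer/export.py`, added in this commit as a symlink to the LibriCSS SURT recipe. The flags below follow the common icefall export interface and are an assumption here; run `python dprnn_zipformer/export.py --help` for the authoritative list.

```bash
# Hedged sketch of exporting the adapted model; flags assumed from the
# common icefall export interface, not verified against this script.
python dprnn_zipformer/export.py \
  --epoch 20 --avg 1 --use-averaged-model False \
  --exp-dir dprnn_zipformer/exp/surt_base_adapt \
  --bpe-model data/lang_bpe_500/bpe.model
```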

## Results (using beam search)

**AMI**

| Model      | IHM-Mix  | SDM      | MDM      |
|------------|:--------:|:--------:|:--------:|
| SURT-base  | 39.8     | 65.4     | 46.6     |
| + adapt    | 37.4     | 46.9     | 43.7     |
| SURT-large | 36.8     | 62.5     | 44.4     |
| + adapt    | **35.1** | **44.6** | **41.4** |

**ICSI**

| Model      | IHM-Mix  | SDM      |
|------------|:--------:|:--------:|
| SURT-base  | 28.3     | 60.0     |
| + adapt    | 26.3     | 33.9     |
| SURT-large | 27.8     | 59.7     |
| + adapt    | **24.4** | **32.3** |
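
The WERs above are produced by `decode.py` via `write_surt_error_stats`, which handles the multi-channel SURT output. If you want to re-score stored transcripts outside the recipe, the meeteval toolkit listed in the pre-requisites provides a standalone ORC-WER CLI for STM files; the file names below are placeholders you would generate from the `recogs-*` outputs.

```bash
# Hedged sketch: standalone ORC-WER scoring with meeteval.
# ref.stm and hyp.stm are hypothetical files derived from the
# recogs-* transcripts stored by decode.py.
meeteval-wer orcwer -r ref.stm -h hyp.stm
```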

## Pre-trained models and logs

* LibriCSS pre-trained model (for initialization): [base](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_base) [large](https://huggingface.co/desh2608/icefall-surt-libricss-dprnn-zipformer/tree/main/exp/surt_large)

* Pre-trained models: <https://huggingface.co/desh2608/icefall-surt-ami-dprnn-zipformer>

* Training logs:
  - surt_base: <https://tensorboard.dev/experiment/8awy98VZSWegLmH4l2JWSA/>
  - surt_base_adapt: <https://tensorboard.dev/experiment/aGVgXVzYRDKbGUbPekcNjg/>
  - surt_large: <https://tensorboard.dev/experiment/ZXMkez0VSYKbPLqRk4clOQ/>
  - surt_large_adapt: <https://tensorboard.dev/experiment/WLKL1e7bTVyEjSonYSNYwg/>
399 egs/ami/SURT/dprnn_zipformer/asr_datamodule.py Normal file
@@ -0,0 +1,399 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutMix,
    DynamicBucketingSampler,
    K2SurtDataset,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class AmiAsrDataModule:
    """
    DataModule for k2 SURT experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g., AMI dev and test).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - augmentation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/manifests"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--max-duration-valid",
            type=int,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--max-cuts",
            type=int,
            default=100,
            help="Maximum number of cuts in a single batch. You can "
            "reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help=(
                "When enabled, use on-the-fly cut mixing and feature "
                "extraction. Will drop existing precomputed feature manifests "
                "if available."
            ),
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )

        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=True,
            help="When enabled, use SpecAugment for training dataset.",
        )

        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=80,
            help="Used only when --enable-spec-aug is True. "
            "It specifies the factor for time warping in SpecAugment. "
            "Larger values mean more warping. "
            "A value less than 1 means to disable time warp.",
        )

        group.add_argument(
            "--enable-musan",
            type=str2bool,
            default=True,
            help="When enabled, select noise from MUSAN and mix it "
            "with training dataset.",
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
        sources: bool = False,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        transforms = []
        if self.args.enable_musan:
            logging.info("Enable MUSAN")
            logging.info("About to get Musan cuts")
            cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
            transforms.append(
                CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
            )
        else:
            logging.info("Disable MUSAN")

        input_transforms = []
        if self.args.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
            # Set the value of num_frame_masks according to Lhotse's version.
            # In different Lhotse's versions, the default of num_frame_masks is
            # different.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=self.args.spec_aug_time_warp_factor,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        train = K2SurtDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
            else PrecomputedFeatures(),
            cut_transforms=transforms,
            input_transforms=input_transforms,
            return_cuts=self.args.return_cuts,
            return_sources=sources,
            strict=False,
        )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                quadratic_duration=30.0,
                max_cuts=self.args.max_cuts,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                max_cuts=self.args.max_cuts,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        transforms = []

        logging.info("About to create dev dataset")
        validate = K2SurtDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
            else PrecomputedFeatures(),
            cut_transforms=transforms,
            return_cuts=self.args.return_cuts,
            return_sources=False,
            strict=False,
        )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration_valid,
            quadratic_duration=30.0,
            max_cuts=self.args.max_cuts,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = K2SurtDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
            else PrecomputedFeatures(),
            return_cuts=self.args.return_cuts,
            return_sources=False,
            strict=False,
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration_valid,
            max_cuts=self.args.max_cuts,
            shuffle=False,
        )

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )
        return test_dl

    @lru_cache()
    def aimix_train_cuts(
        self,
        rvb_affix: str = "clean",
        sources: bool = True,
    ) -> CutSet:
        logging.info("About to get train cuts")
        source_affix = "_sources" if sources else ""
        cs = load_manifest_lazy(
            self.args.manifest_dir / f"cuts_train_{rvb_affix}{source_affix}.jsonl.gz"
        )
        cs = cs.filter(lambda c: c.duration >= 1.0 and c.duration <= 30.0)
        return cs

    @lru_cache()
    def train_cuts(
        self,
    ) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "cuts_train_ami_icsi.jsonl.gz"
        )

    @lru_cache()
    def ami_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet:
        logging.info(f"About to get AMI {split} {type} cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / f"cuts_ami-{type}_{split}.jsonl.gz"
        )

    @lru_cache()
    def icsi_cuts(self, split: str = "dev", type: str = "sdm") -> CutSet:
        logging.info(f"About to get ICSI {split} {type} cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / f"cuts_icsi-{type}_{split}.jsonl.gz"
        )
1 egs/ami/SURT/dprnn_zipformer/beam_search.py Symbolic link
@@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/beam_search.py
622 egs/ami/SURT/dprnn_zipformer/decode.py Executable file
@@ -0,0 +1,622 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./dprnn_zipformer/decode.py \
    --epoch 20 \
    --avg 1 \
    --use-averaged-model false \
    --exp-dir ./dprnn_zipformer/exp_adapt \
    --max-duration 600 \
    --decoding-method greedy_search

(2) beam search (not recommended)
./dprnn_zipformer/decode.py \
    --epoch 20 \
    --avg 1 \
    --use-averaged-model false \
    --exp-dir ./dprnn_zipformer/exp_adapt \
    --max-duration 600 \
    --decoding-method beam_search \
    --beam-size 4

(3) modified beam search
./dprnn_zipformer/decode.py \
    --epoch 20 \
    --avg 1 \
    --use-averaged-model false \
    --exp-dir ./dprnn_zipformer/exp_adapt \
    --max-duration 600 \
    --decoding-method modified_beam_search \
    --beam-size 4
"""


import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import AmiAsrDataModule
from beam_search import (
    beam_search,
    greedy_search,
    greedy_search_batch,
    modified_beam_search,
)
from lhotse.utils import EPSILON
from train import add_model_arguments, get_params, get_surt_model

from icefall import LmScorer, NgramLm
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
    find_checkpoints,
    load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    setup_logger,
    store_transcripts,
    str2bool,
    write_surt_error_stats,
)


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=20,
        help="""It specifies the checkpoint to use for decoding.
        Note: Epoch counts from 1.
        You can specify --avg to use more checkpoints for model averaging.""",
    )

    parser.add_argument(
        "--iter",
        type=int,
        default=0,
        help="""If positive, --epoch is ignored and it
        will use the checkpoint exp_dir/checkpoint-iter.pt.
        You can specify --avg to use more checkpoints for model averaging.
        """,
    )

    parser.add_argument(
        "--avg",
        type=int,
        default=1,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch' and '--iter'",
    )

    parser.add_argument(
        "--use-averaged-model",
        type=str2bool,
        default=True,
        help="Whether to load averaged model. Currently it only supports "
        "using --epoch. If True, it would decode with the averaged model "
        "over the epoch range from `epoch-avg` (excluded) to `epoch`. "
        "Actually only the models with epoch number of `epoch-avg` and "
        "`epoch` are loaded for averaging. ",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="dprnn_zipformer/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        help="""Possible values are:
        - greedy_search
        - beam_search
        - modified_beam_search
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )
    parser.add_argument(
        "--max-sym-per-frame",
        type=int,
        default=1,
        help="""Maximum number of symbols per frame.
        Used only when --decoding_method is greedy_search""",
    )

    add_model_arguments(parser)

    return parser


def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

        - key: It indicates the setting used for decoding. For example,
               if greedy_search is used, it would be "greedy_search".
               If beam search with a beam size of 7 is used, it would be
               "beam_7".
        - value: It contains the decoding result. `len(value)` equals to
                 batch size. `value[i]` is the decoding result for the i-th
                 utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3

    feature = feature.to(device)
    feature_lens = batch["input_lens"].to(device)

    # Apply the mask encoder
    B, T, F = feature.shape
    processed = model.mask_encoder(feature)  # B,T,F*num_channels
    masks = processed.view(B, T, F, params.num_channels).unbind(dim=-1)
    x_masked = [feature * m for m in masks]

    # Recognition
    # Stack the inputs along the batch axis
    h = torch.cat(x_masked, dim=0)
    h_lens = torch.cat([feature_lens for _ in range(params.num_channels)], dim=0)
    encoder_out, encoder_out_lens = model.encoder(x=h, x_lens=h_lens)

    if model.joint_encoder_layer is not None:
        encoder_out = model.joint_encoder_layer(encoder_out)

    def _group_channels(hyps: List[str]) -> List[List[str]]:
        """
        Currently we have a batch of size M*B, where M is the number of
        channels and B is the batch size. We need to group the hypotheses
        into B groups, each of which contains M hypotheses.

        Example:
            hyps = ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']
            _group_channels(hyps) = [['a1', 'a2'], ['b1', 'b2'], ['c1', 'c2']]
        """
        assert len(hyps) == B * params.num_channels
        out_hyps = []
        for i in range(B):
            out_hyps.append(hyps[i::B])
        return out_hyps

    hyps = []
    if params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp)
    elif params.decoding_method == "modified_beam_search":
        hyp_tokens = modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=params.beam_size,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp)
    else:
        batch_size = encoder_out.size(0)

        for i in range(batch_size):
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on
            if params.decoding_method == "greedy_search":
                hyp = greedy_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    max_sym_per_frame=params.max_sym_per_frame,
                )
            elif params.decoding_method == "beam_search":
                hyp = beam_search(
                    model=model,
                    encoder_out=encoder_out_i,
                    beam=params.beam_size,
                )
            else:
                raise ValueError(
                    f"Unsupported decoding method: {params.decoding_method}"
                )
            hyps.append(sp.decode(hyp))

    if params.decoding_method == "greedy_search":
        return {"greedy_search": _group_channels(hyps)}
    elif "fast_beam_search" in params.decoding_method:
        key = f"beam_{params.beam}_"
        key += f"max_contexts_{params.max_contexts}_"
        key += f"max_states_{params.max_states}"
        if "nbest" in params.decoding_method:
            key += f"_num_paths_{params.num_paths}_"
            key += f"nbest_scale_{params.nbest_scale}"
        if "LG" in params.decoding_method:
            key += f"_ngram_lm_scale_{params.ngram_lm_scale}"

        return {key: _group_channels(hyps)}
    else:
        return {f"beam_size_{params.beam_size}": _group_channels(hyps)}


def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
    Returns:
      Return a dict, whose key may be "greedy_search" if greedy search
      is used, or it may be "beam_7" if beam size of 7 is used.
      Its value is a list of tuples. Each tuple contains three elements:
      the cut ID, the reference transcript, and the predicted result.
    """
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    if params.decoding_method == "greedy_search":
        log_interval = 50
    else:
        log_interval = 20

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        cut_ids = [cut.id for cut in batch["cuts"]]
        cuts_batch = batch["cuts"]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            sp=sp,
            batch=batch,
        )

        for name, hyps in hyps_dict.items():
            this_batch = []
            for cut_id, hyp_words in zip(cut_ids, hyps):
                # Reference is a list of supervision texts sorted by start time.
                ref_words = [
                    s.text.strip()
                    for s in sorted(
                        cuts_batch[cut_id].supervisions, key=lambda s: s.start
                    )
                ]
                this_batch.append((cut_id, ref_words, hyp_words))

            results[name].extend(this_batch)

        num_cuts += len(cut_ids)

        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
    return results


def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = (
            params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        results = sorted(results)
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = (
            params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_surt_error_stats(
                f,
                f"{test_set_name}-{key}",
                results,
                enable_log=True,
                num_channels=params.num_channels,
            )
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = (
        params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)


@torch.no_grad()
def main():
    parser = get_parser()
    LmScorer.add_arguments(parser)
    AmiAsrDataModule.add_arguments(parser)
    args = parser.parse_args()
    args.exp_dir = Path(args.exp_dir)

    params = get_params()
    params.update(vars(args))

    assert params.decoding_method in (
        "greedy_search",
        "beam_search",
        "modified_beam_search",
    ), f"Decoding method {params.decoding_method} is not supported."
    params.res_dir = params.exp_dir / params.decoding_method

    if params.iter > 0:
        params.suffix = f"iter-{params.iter}-avg-{params.avg}"
    else:
        params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"

    if "beam_search" in params.decoding_method:
        params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"

    setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
    logging.info("Decoding started")

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_surt_model(params)
    assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, (
        model.encoder.decode_chunk_size,
        params.decode_chunk_len,
    )

    if not params.use_averaged_model:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
        elif params.avg == 1:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
        else:
            start = params.epoch - params.avg + 1
            filenames = []
            for i in range(start, params.epoch + 1):
                if i >= 1:
                    filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
            logging.info(f"averaging {filenames}")
            model.to(device)
            model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        if params.iter > 0:
            filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
                : params.avg + 1
            ]
            if len(filenames) == 0:
                raise ValueError(
                    f"No checkpoints found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            elif len(filenames) < params.avg + 1:
                raise ValueError(
                    f"Not enough checkpoints ({len(filenames)}) found for"
                    f" --iter {params.iter}, --avg {params.avg}"
                )
            filename_start = filenames[-1]
            filename_end = filenames[0]
            logging.info(
                "Calculating the averaged model over iteration checkpoints"
                f" from {filename_start} (excluded) to {filename_end}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        else:
            assert params.avg > 0, params.avg
            start = params.epoch - params.avg
            assert start >= 1, start
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )

    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    ami = AmiAsrDataModule(args)

    # NOTE(@desh2608): we filter segments longer than 120s to avoid OOM errors in decoding.
    # However, 99.9% of the segments are shorter than 120s, so this should not
    # substantially affect the results. In future, we will implement an overlapped
    # inference method to avoid OOM errors.

    test_sets = {}
    for split in ["dev", "test"]:
        for type in ["ihm-mix", "sdm", "mdm8-bf"]:
            test_sets[f"ami-{split}_{type}"] = (
                ami.ami_cuts(split=split, type=type)
                .trim_to_supervision_groups(max_pause=0.0)
                .filter(lambda c: 0.1 < c.duration < 120.0)
                .to_eager()
            )

    for split in ["dev", "test"]:
        for type in ["ihm-mix", "sdm"]:
            test_sets[f"icsi-{split}_{type}"] = (
                ami.icsi_cuts(split=split, type=type)
                .trim_to_supervision_groups(max_pause=0.0)
                .filter(lambda c: 0.1 < c.duration < 120.0)
                .to_eager()
            )

    for test_set, test_cuts in test_sets.items():
        test_dl = ami.test_dataloaders(test_cuts)
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            sp=sp,
        )

        save_results(
            params=params,
            test_set_name=test_set,
            results_dict=results_dict,
        )

    logging.info("Done!")


if __name__ == "__main__":
    main()
1 egs/ami/SURT/dprnn_zipformer/decoder.py Symbolic link
@@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/decoder.py

1 egs/ami/SURT/dprnn_zipformer/dprnn.py Symbolic link
@@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/dprnn.py

1 egs/ami/SURT/dprnn_zipformer/encoder_interface.py Symbolic link
@@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/encoder_interface.py

1 egs/ami/SURT/dprnn_zipformer/export.py Symbolic link
@@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/export.py

1 egs/ami/SURT/dprnn_zipformer/joiner.py Symbolic link
@@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/joiner.py

1 egs/ami/SURT/dprnn_zipformer/model.py Symbolic link
@@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/model.py

1 egs/ami/SURT/dprnn_zipformer/optim.py Symbolic link
@@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/optim.py

1 egs/ami/SURT/dprnn_zipformer/scaling.py Symbolic link
@@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/scaling.py

1 egs/ami/SURT/dprnn_zipformer/scaling_converter.py Symbolic link
@@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/scaling_converter.py

1 egs/ami/SURT/dprnn_zipformer/test_model.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/pruned_transducer_stateless7_streaming/test_model.py
1420 egs/ami/SURT/dprnn_zipformer/train.py Executable file
File diff suppressed because it is too large

1411 egs/ami/SURT/dprnn_zipformer/train_adapt.py Executable file
File diff suppressed because it is too large

1 egs/ami/SURT/dprnn_zipformer/zipformer.py Symbolic link
@@ -0,0 +1 @@
../../../libricss/SURT/dprnn_zipformer/zipformer.py
78 egs/ami/SURT/local/add_source_feats.py Executable file
@@ -0,0 +1,78 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file adds source features as temporal arrays to the mixture manifests.
It looks for manifests in the directory data/manifests.
"""
import logging
from pathlib import Path

import numpy as np
from lhotse import CutSet, LilcomChunkyWriter, load_manifest, load_manifest_lazy
from tqdm import tqdm


def add_source_feats():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    logging.info("Reading mixed cuts")
    mixed_cuts_clean = load_manifest_lazy(src_dir / "cuts_train_clean.jsonl.gz")
    mixed_cuts_reverb = load_manifest_lazy(src_dir / "cuts_train_reverb.jsonl.gz")

    logging.info("Reading source cuts")
    source_cuts = load_manifest(src_dir / "ihm_cuts_train_trimmed.jsonl.gz")

    logging.info("Adding source features to the mixed cuts")
    pbar = tqdm(total=len(mixed_cuts_clean), desc="Adding source features")
    with CutSet.open_writer(
        src_dir / "cuts_train_clean_sources.jsonl.gz"
    ) as cut_writer_clean, CutSet.open_writer(
        src_dir / "cuts_train_reverb_sources.jsonl.gz"
    ) as cut_writer_reverb, LilcomChunkyWriter(
        output_dir / "feats_train_clean_sources"
    ) as source_feat_writer:
        for cut_clean, cut_reverb in zip(mixed_cuts_clean, mixed_cuts_reverb):
            assert cut_reverb.id == cut_clean.id + "_rvb"
            source_feats = []
            source_feat_offsets = []
            cur_offset = 0
            for sup in sorted(
                cut_clean.supervisions, key=lambda s: (s.start, s.speaker)
            ):
                source_cut = source_cuts[sup.id]
                source_feats.append(source_cut.load_features())
                source_feat_offsets.append(cur_offset)
                cur_offset += source_cut.num_frames
            cut_clean.source_feats = source_feat_writer.store_array(
                cut_clean.id, np.concatenate(source_feats, axis=0)
            )
            cut_clean.source_feat_offsets = source_feat_offsets
            cut_writer_clean.write(cut_clean)
            # Also write the reverb cut
            cut_reverb.source_feats = cut_clean.source_feats
            cut_reverb.source_feat_offsets = cut_clean.source_feat_offsets
            cut_writer_reverb.write(cut_reverb)
            pbar.update(1)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    add_source_feats()
185 egs/ami/SURT/local/compute_fbank_aimix.py Executable file
@@ -0,0 +1,185 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the synthetically mixed AMI and ICSI
train set.
It looks for manifests in the directory data/manifests.

The generated fbank features are saved in data/fbank.
"""
import logging
import random
import warnings
from pathlib import Path

import torch
import torch.multiprocessing
import torchaudio
from lhotse import (
    AudioSource,
    LilcomChunkyWriter,
    Recording,
    load_manifest,
    load_manifest_lazy,
)
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
from lhotse.cut import MixedCut, MixTrack, MultiCut
from lhotse.features.kaldifeat import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    KaldifeatFrameOptions,
    KaldifeatMelOptions,
)
from lhotse.utils import fix_random_seed, uuid4
from tqdm import tqdm

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
torchaudio.set_audio_backend("soundfile")
set_ffmpeg_torchaudio_info_enabled(False)


def compute_fbank_aimix():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    sampling_rate = 16000
    num_mel_bins = 80

    extractor = KaldifeatFbank(
        KaldifeatFbankConfig(
            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
            device="cuda",
        )
    )

    logging.info("Reading manifests")
    train_cuts = load_manifest_lazy(src_dir / "ai-mix_cuts_clean_full.jsonl.gz")

    # only use RIRs and noises from the REVERB challenge
    real_rirs = load_manifest(src_dir / "real-rir_recordings_all.jsonl.gz").filter(
        lambda r: "RVB2014" in r.id
    )
    noises = load_manifest(src_dir / "iso-noise_recordings_all.jsonl.gz").filter(
        lambda r: "RVB2014" in r.id
    )

    # Apply perturbation to the training cuts
    logging.info("Applying perturbation to the training cuts")
    train_cuts_rvb = train_cuts.map(
        lambda c: augment(
            c, perturb_snr=True, rirs=real_rirs, noises=noises, perturb_loudness=True
        )
    )

    logging.info("Extracting fbank features for training cuts")
    _ = train_cuts.compute_and_store_features_batch(
        extractor=extractor,
        storage_path=output_dir / "ai-mix_feats_clean",
        manifest_path=src_dir / "cuts_train_clean.jsonl.gz",
        batch_duration=5000,
        num_workers=4,
        storage_type=LilcomChunkyWriter,
        overwrite=True,
    )

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        _ = train_cuts_rvb.compute_and_store_features_batch(
            extractor=extractor,
            storage_path=output_dir / "ai-mix_feats_reverb",
            manifest_path=src_dir / "cuts_train_reverb.jsonl.gz",
            batch_duration=5000,
            num_workers=4,
            storage_type=LilcomChunkyWriter,
            overwrite=True,
        )


def augment(cut, perturb_snr=False, rirs=None, noises=None, perturb_loudness=False):
    """
    Given a mixed cut, this function optionally applies the following augmentations:
    - Perturbing the SNRs of the tracks (in the range [-5, 5] dB)
    - Reverberation using a randomly selected RIR
    - Adding noise
    - Perturbing the loudness (in the range [-25, -20] dB)
    """
    out_cut = cut.drop_features()

    # Perturb the SNRs (optional)
    if perturb_snr:
        snrs = [random.uniform(-5, 5) for _ in range(len(cut.tracks))]
        for i, (track, snr) in enumerate(zip(out_cut.tracks, snrs)):
            if i == 0:
                # Skip the first track since it is the reference
                continue
            track.snr = snr

    # Reverberate the cut (optional)
    if rirs is not None:
        # Select an RIR at random
        rir = random.choice(rirs)
        # Select a channel at random
        rir_channel = random.choice(list(range(rir.num_channels)))
        # Reverberate the cut
        out_cut = out_cut.reverb_rir(rir_recording=rir, rir_channels=[rir_channel])

    # Add noise (optional)
    if noises is not None:
        # Select a noise recording at random
        noise = random.choice(noises).to_cut()
        if isinstance(noise, MultiCut):
            noise = noise.to_mono()[0]
        # Select an SNR at random
        snr = random.uniform(10, 30)
        # Repeat the noise to match the duration of the cut
        noise = repeat_cut(noise, out_cut.duration)
        out_cut = MixedCut(
            id=out_cut.id,
            tracks=[
                MixTrack(cut=out_cut, type="MixedCut"),
                MixTrack(cut=noise, type="DataCut", snr=snr),
            ],
        )

    # Perturb the loudness (optional)
    if perturb_loudness:
        target_loudness = random.uniform(-25, -20)
        out_cut = out_cut.normalize_loudness(target_loudness, mix_first=True)
    return out_cut


def repeat_cut(cut, duration):
    while cut.duration < duration:
        cut = cut.mix(cut, offset_other_by=cut.duration)
    return cut.truncate(duration=duration)


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    fix_random_seed(42)
    compute_fbank_aimix()
94 egs/ami/SURT/local/compute_fbank_ami.py Executable file
@@ -0,0 +1,94 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the AMI dataset.
We compute features for full recordings (i.e., without trimming to supervisions).
This way we can create arbitrary segmentations later.

The generated fbank features are saved in data/fbank.
"""
import logging
import math
from pathlib import Path

import torch
import torch.multiprocessing
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    KaldifeatFrameOptions,
    KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached

# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")


def compute_fbank_ami():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    sampling_rate = 16000
    num_mel_bins = 80

    extractor = KaldifeatFbank(
        KaldifeatFbankConfig(
            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
            device="cuda",
        )
    )

    logging.info("Reading manifests")
    manifests = {}
    for part in ["ihm-mix", "sdm", "mdm8-bf"]:
        manifests[part] = read_manifests_if_cached(
            dataset_parts=["train", "dev", "test"],
            output_dir=src_dir,
            prefix=f"ami-{part}",
            suffix="jsonl.gz",
        )

    for part in ["ihm-mix", "sdm", "mdm8-bf"]:
        for split in ["train", "dev", "test"]:
            logging.info(f"Processing {part} {split}")
            cuts = CutSet.from_manifests(
                **manifests[part][split]
            ).compute_and_store_features_batch(
                extractor=extractor,
                storage_path=output_dir / f"ami-{part}_{split}_feats",
                manifest_path=src_dir / f"cuts_ami-{part}_{split}.jsonl.gz",
                batch_duration=5000,
                num_workers=4,
                storage_type=LilcomChunkyWriter,
            )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    compute_fbank_ami()
95
egs/ami/SURT/local/compute_fbank_icsi.py
Executable file
95
egs/ami/SURT/local/compute_fbank_icsi.py
Executable file
@ -0,0 +1,95 @@
|
||||
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the ICSI dataset.
We compute features for full recordings (i.e., without trimming to supervisions).
This way we can create arbitrary segmentations later.

The generated fbank features are saved in data/fbank.
"""
import logging
from pathlib import Path

import torch
import torch.multiprocessing
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.features.kaldifeat import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    KaldifeatFrameOptions,
    KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached

# Torch's multithreaded behavior needs to be disabled, or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")


def compute_fbank_icsi():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    sampling_rate = 16000
    num_mel_bins = 80

    extractor = KaldifeatFbank(
        KaldifeatFbankConfig(
            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
            device="cuda",
        )
    )

    logging.info("Reading manifests")
    manifests = {}
    for part in ["ihm-mix", "sdm"]:
        manifests[part] = read_manifests_if_cached(
            dataset_parts=["train"],
            output_dir=src_dir,
            prefix=f"icsi-{part}",
            suffix="jsonl.gz",
        )

    for part in ["ihm-mix", "sdm"]:
        for split in ["train"]:
            logging.info(f"Processing {part} {split}")
            _ = CutSet.from_manifests(
                **manifests[part][split]
            ).compute_and_store_features_batch(
                extractor=extractor,
                storage_path=output_dir / f"icsi-{part}_{split}_feats",
                manifest_path=src_dir / f"cuts_icsi-{part}_{split}.jsonl.gz",
                batch_duration=5000,
                num_workers=4,
                storage_type=LilcomChunkyWriter,
                overwrite=True,
            )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    compute_fbank_icsi()
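Both feature-extraction scripts hard-code `device="cuda"`. On a machine without a GPU, the same configuration can be constructed with a CPU device instead; the sketch below is a hedged variant for that case (extraction will be much slower), not part of the recipe itself.

```python
# Hypothetical CPU fallback for the extractor used above: identical
# configuration, but with device="cpu" (much slower extraction).
from lhotse.features.kaldifeat import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    KaldifeatFrameOptions,
    KaldifeatMelOptions,
)

extractor = KaldifeatFbank(
    KaldifeatFbankConfig(
        frame_opts=KaldifeatFrameOptions(sampling_rate=16000),
        mel_opts=KaldifeatMelOptions(num_bins=80),
        device="cpu",  # assumption: no CUDA device is available
    )
)
```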
101
egs/ami/SURT/local/compute_fbank_ihm.py
Executable file
@ -0,0 +1,101 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file computes fbank features of the trimmed sub-segments which will be
used for simulating the training mixtures.

The generated fbank features are saved in data/fbank.
"""
import logging
from pathlib import Path

import torch
import torch.multiprocessing
import torchaudio
from lhotse import CutSet, LilcomChunkyWriter
from lhotse.audio import set_ffmpeg_torchaudio_info_enabled
from lhotse.features.kaldifeat import (
    KaldifeatFbank,
    KaldifeatFbankConfig,
    KaldifeatFrameOptions,
    KaldifeatMelOptions,
)
from lhotse.recipes.utils import read_manifests_if_cached

# Torch's multithreaded behavior needs to be disabled, or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
torch.multiprocessing.set_sharing_strategy("file_system")
torchaudio.set_audio_backend("soundfile")
set_ffmpeg_torchaudio_info_enabled(False)


def compute_fbank_ihm():
    src_dir = Path("data/manifests")
    output_dir = Path("data/fbank")

    sampling_rate = 16000
    num_mel_bins = 80

    extractor = KaldifeatFbank(
        KaldifeatFbankConfig(
            frame_opts=KaldifeatFrameOptions(sampling_rate=sampling_rate),
            mel_opts=KaldifeatMelOptions(num_bins=num_mel_bins),
            device="cuda",
        )
    )

    logging.info("Reading manifests")
    manifests = {}
    for data in ["ami", "icsi"]:
        manifests[data] = read_manifests_if_cached(
            dataset_parts=["train"],
            output_dir=src_dir,
            types=["recordings", "supervisions"],
            prefix=f"{data}-ihm",
            suffix="jsonl.gz",
        )

    logging.info("Computing features")
    for data in ["ami", "icsi"]:
        cs = CutSet.from_manifests(**manifests[data]["train"])
        cs = cs.trim_to_supervisions(keep_overlapping=False)
        cs = cs.normalize_loudness(target=-23.0, affix_id=False)
        cs = cs + cs.perturb_speed(0.9) + cs.perturb_speed(1.1)
        _ = cs.compute_and_store_features_batch(
            extractor=extractor,
            storage_path=output_dir / f"{data}-ihm_train_feats",
            manifest_path=src_dir / f"{data}-ihm_cuts_train.jsonl.gz",
            batch_duration=5000,
            num_workers=4,
            storage_type=LilcomChunkyWriter,
            overwrite=True,
        )


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    compute_fbank_ihm()
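The augmentation chain in `compute_fbank_ihm.py` (trim to supervisions, normalize loudness, then add 0.9x and 1.1x speed-perturbed copies) triples the pool of source utterances available for mixture simulation. A small sketch of just the combination step, assuming `cs` is any lhotse `CutSet`:

```python
from lhotse import CutSet


def with_speed_perturbation(cs: CutSet) -> CutSet:
    # Combining the original cuts with 0.9x and 1.1x speed-perturbed copies
    # triples the set; perturb_speed returns lazily transformed copies and,
    # by default, suffixes the cut ids (e.g. "_sp0.9") to keep them unique.
    return cs + cs.perturb_speed(0.9) + cs.perturb_speed(1.1)
```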
146
egs/ami/SURT/local/prepare_ami_train_cuts.py
Executable file
@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file creates AMI train segments.
"""
import logging
from pathlib import Path

from lhotse import load_manifest_lazy
from lhotse.cut import CutSet
from lhotse.utils import EPSILON, add_durations
from tqdm import tqdm


def cut_into_windows(cuts: CutSet, duration: float):
    """
    This function takes a CutSet and cuts each cut into windows of roughly
    `duration` seconds. By "roughly" we mean that a window boundary may be
    adjusted: a supervision crossing the boundary is kept in the current
    window if more than half of it lies inside the window, and is otherwise
    deferred to the next window.
    """
    res = []
    with tqdm() as pbar:
        for cut in cuts:
            pbar.update(1)
            sups = cut.index_supervisions()[cut.id]
            sr = cut.sampling_rate
            start = 0.0
            end = duration
            num_tries = 0
            while start < cut.duration and num_tries < 2:
                # Find the supervisions that are cut by the window endpoint
                hitlist = [iv for iv in sups.at(end) if iv.begin < end]
                # If there are no such supervisions, emit the window as-is
                if not hitlist:
                    res.append(
                        cut.truncate(
                            offset=start,
                            duration=add_durations(end, -start, sampling_rate=sr),
                            keep_excessive_supervisions=False,
                        )
                    )
                    # Update the start and end for the next window
                    start = end
                    end = add_durations(end, duration, sampling_rate=sr)
                else:
                    # Find the fraction of each hit supervision's duration
                    # that lies inside the window
                    ratios = [
                        add_durations(end, -iv.begin, sampling_rate=sr) / iv.length()
                        for iv in hitlist
                    ]
                    # We retain the supervisions that have >50% of their duration
                    # in the window, and discard the others
                    retained = []
                    discarded = []
                    for iv, ratio in zip(hitlist, ratios):
                        if ratio > 0.5:
                            retained.append(iv)
                        else:
                            discarded.append(iv)
                    cur_end = max(iv.end for iv in retained) if retained else end
                    res.append(
                        cut.truncate(
                            offset=start,
                            duration=add_durations(cur_end, -start, sampling_rate=sr),
                            keep_excessive_supervisions=False,
                        )
                    )
                    # For the next window, we start at the earliest discarded supervision
                    next_start = min(iv.begin for iv in discarded) if discarded else end
                    next_end = add_durations(next_start, duration, sampling_rate=sr)
                    # It may happen that next_start is the same as start, in which case
                    # we will advance the window anyway
                    if next_start == start:
                        logging.warning(
                            f"Next start is the same as start: {next_start} == {start} for cut {cut.id}"
                        )
                        start = end + EPSILON
                        end = add_durations(start, duration, sampling_rate=sr)
                        num_tries += 1
                    else:
                        start = next_start
                        end = next_end
    return CutSet.from_cuts(res)


def prepare_train_cuts():
    src_dir = Path("data/manifests")

    logging.info("Loading the manifests")
    train_cuts_ihm = load_manifest_lazy(
        src_dir / "cuts_ami-ihm-mix_train.jsonl.gz"
    ).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
    train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_ami-sdm_train.jsonl.gz").map(
        lambda c: c.with_id(f"{c.id}_sdm")
    )
    train_cuts_mdm = load_manifest_lazy(
        src_dir / "cuts_ami-mdm8-bf_train.jsonl.gz"
    ).map(lambda c: c.with_id(f"{c.id}_mdm8-bf"))

    # Combine all cuts into one CutSet
    train_cuts = train_cuts_ihm + train_cuts_sdm + train_cuts_mdm

    train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
    train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)

    # Combine the two segmentations
    train_all = train_cuts_1 + train_cuts_2

    # At this point, some of the cuts may be very long. We cut them into
    # windows of roughly 30 seconds.
    logging.info("Cutting the segments into windows of 30 seconds")
    train_all_30 = cut_into_windows(train_all, duration=30.0)
    logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")

    # Show statistics
    train_all_30.describe(full=True)

    # Save the cuts (the windowed version, so that no cut is overly long)
    logging.info("Saving the cuts")
    train_all_30.to_file(src_dir / "cuts_train_ami.jsonl.gz")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    prepare_train_cuts()
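The `cut_into_windows()` helper above is also usable on its own. A hypothetical standalone use, assuming an existing cut manifest and an output path chosen purely for illustration:

```python
# Window an existing cut manifest into ~30 s chunks and save the result.
# The input/output paths here are illustrative, not part of the recipe.
from lhotse import load_manifest_lazy
from prepare_ami_train_cuts import cut_into_windows

cuts = load_manifest_lazy("data/manifests/cuts_train_ami.jsonl.gz")
windows = cut_into_windows(cuts, duration=30.0)
windows.to_file("data/manifests/cuts_train_ami_30s.jsonl.gz")
```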
67
egs/ami/SURT/local/prepare_icsi_train_cuts.py
Executable file
@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright 2022 Johns Hopkins University (authors: Desh Raj)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This file creates ICSI train segments.
"""
import logging
from pathlib import Path

from lhotse import load_manifest_lazy
from prepare_ami_train_cuts import cut_into_windows


def prepare_train_cuts():
    src_dir = Path("data/manifests")

    logging.info("Loading the manifests")
    train_cuts_ihm = load_manifest_lazy(
        src_dir / "cuts_icsi-ihm-mix_train.jsonl.gz"
    ).map(lambda c: c.with_id(f"{c.id}_ihm-mix"))
    train_cuts_sdm = load_manifest_lazy(src_dir / "cuts_icsi-sdm_train.jsonl.gz").map(
        lambda c: c.with_id(f"{c.id}_sdm")
    )

    # Combine all cuts into one CutSet
    train_cuts = train_cuts_ihm + train_cuts_sdm

    train_cuts_1 = train_cuts.trim_to_supervision_groups(max_pause=0.5)
    train_cuts_2 = train_cuts.trim_to_supervision_groups(max_pause=0.0)

    # Combine the two segmentations
    train_all = train_cuts_1 + train_cuts_2

    # At this point, some of the cuts may be very long. We cut them into
    # windows of roughly 30 seconds.
    logging.info("Cutting the segments into windows of 30 seconds")
    train_all_30 = cut_into_windows(train_all, duration=30.0)
    logging.info(f"Number of cuts after cutting into windows: {len(train_all_30)}")

    # Show statistics
    train_all_30.describe(full=True)

    # Save the cuts (the windowed version, so that no cut is overly long)
    logging.info("Saving the cuts")
    train_all_30.to_file(src_dir / "cuts_train_icsi.jsonl.gz")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    prepare_train_cuts()
1
egs/ami/SURT/local/prepare_lang_bpe.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/local/prepare_lang_bpe.py
1
egs/ami/SURT/local/train_bpe_model.py
Symbolic link
@ -0,0 +1 @@
../../../librispeech/ASR/local/train_bpe_model.py
195
egs/ami/SURT/prepare.sh
Executable file
@ -0,0 +1,195 @@
#!/usr/bin/env bash

set -eou pipefail

stage=-1
stop_stage=100

# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
#  - $dl_dir/amicorpus
#      You can find audio and transcripts for AMI in this path.
#
#  - $dl_dir/icsi
#      You can find audio and transcripts for ICSI in this path.
#
#  - $dl_dir/rirs_noises
#      This directory contains the RIRS_NOISES corpus downloaded from https://openslr.org/28/.
#
dl_dir=$PWD/download

. shared/parse_options.sh || exit 1

# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
vocab_size=500

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

log "dl_dir: $dl_dir"

if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
  log "Stage 0: Download data"

  # If you have pre-downloaded it to /path/to/amicorpus,
  # you can create a symlink
  #
  #   ln -sfv /path/to/amicorpus $dl_dir/amicorpus
  #
  if [ ! -d $dl_dir/amicorpus ]; then
    for mic in ihm ihm-mix sdm mdm8-bf; do
      lhotse download ami --mic $mic $dl_dir/amicorpus
    done
  fi

  # If you have pre-downloaded it to /path/to/icsi,
  # you can create a symlink
  #
  #   ln -sfv /path/to/icsi $dl_dir/icsi
  #
  if [ ! -d $dl_dir/icsi ]; then
    lhotse download icsi $dl_dir/icsi
  fi

  # If you have pre-downloaded it to /path/to/rirs_noises,
  # you can create a symlink
  #
  #   ln -sfv /path/to/rirs_noises $dl_dir/
  #
  if [ ! -d $dl_dir/rirs_noises ]; then
    lhotse download rirs_noises $dl_dir
  fi
fi

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  log "Stage 1: Prepare AMI manifests"
  # We assume that you have downloaded the AMI corpus
  # to $dl_dir/amicorpus. We perform text normalization for the transcripts.
  mkdir -p data/manifests
  for mic in ihm ihm-mix sdm mdm8-bf; do
    log "Preparing AMI manifest for $mic"
    lhotse prepare ami --mic $mic --max-words-per-segment 30 --merge-consecutive $dl_dir/amicorpus data/manifests/
  done
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  log "Stage 2: Prepare ICSI manifests"
  # We assume that you have downloaded the ICSI corpus
  # to $dl_dir/icsi. We perform text normalization for the transcripts.
  mkdir -p data/manifests
  log "Preparing ICSI manifest"
  for mic in ihm ihm-mix sdm; do
    lhotse prepare icsi --mic $mic $dl_dir/icsi data/manifests/
  done
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  log "Stage 3: Prepare RIRs"
  # We assume that you have downloaded the RIRS_NOISES corpus
  # to $dl_dir/rirs_noises
  lhotse prepare rir-noise -p real_rir -p iso_noise $dl_dir/rirs_noises data/manifests
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Extract features for AMI and ICSI recordings"
  python local/compute_fbank_ami.py
  python local/compute_fbank_icsi.py
fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Create sources for simulating mixtures"
  # In the following script, we speed-perturb the IHM recordings and extract features.
  python local/compute_fbank_ihm.py
  lhotse combine data/manifests/ami-ihm_cuts_train.jsonl.gz \
    data/manifests/icsi-ihm_cuts_train.jsonl.gz - |\
    lhotse cut trim-to-alignments --type word --max-pause 0.5 - - |\
    lhotse filter 'duration<=12.0' - - |\
    shuf | gzip -c > data/manifests/ihm_cuts_train_trimmed.jsonl.gz
fi

if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Create training mixtures"
  lhotse workflows simulate-meetings \
    --method conversational \
    --same-spk-pause 0.5 \
    --diff-spk-pause 0.5 \
    --diff-spk-overlap 1.0 \
    --prob-diff-spk-overlap 0.8 \
    --num-meetings 200000 \
    --num-speakers-per-meeting 2,3 \
    --max-duration-per-speaker 15.0 \
    --max-utterances-per-speaker 3 \
    --seed 1234 \
    --num-jobs 2 \
    data/manifests/ihm_cuts_train_trimmed.jsonl.gz \
    data/manifests/ai-mix_cuts_clean.jsonl.gz

  python local/compute_fbank_aimix.py

  # Add source features to the manifest (will be used for masking loss)
  # This may take ~2 hours.
  python local/add_source_feats.py

  # Combine clean and reverb
  cat <(gunzip -c data/manifests/cuts_train_clean_sources.jsonl.gz) \
    <(gunzip -c data/manifests/cuts_train_reverb_sources.jsonl.gz) |\
    shuf | gzip -c > data/manifests/cuts_train_comb_sources.jsonl.gz
fi

if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
  log "Stage 7: Create training mixtures from real sessions"
  python local/prepare_ami_train_cuts.py
  python local/prepare_icsi_train_cuts.py

  # Combine AMI and ICSI
  cat <(gunzip -c data/manifests/cuts_train_ami.jsonl.gz) \
    <(gunzip -c data/manifests/cuts_train_icsi.jsonl.gz) |\
    shuf | gzip -c > data/manifests/cuts_train_ami_icsi.jsonl.gz
fi

if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
  log "Stage 8: Dump transcripts for BPE model training (using AMI and ICSI)."
  mkdir -p data/lm
  cat <(gunzip -c data/manifests/ami-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
    <(gunzip -c data/manifests/icsi-sdm_supervisions_train.jsonl.gz | jq '.text' | sed 's:"::g') \
    > data/lm/transcript_words.txt
fi

if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
  log "Stage 9: Prepare BPE based lang (combining AMI and ICSI)"

  lang_dir=data/lang_bpe_${vocab_size}
  mkdir -p $lang_dir

  # Add special words to words.txt
  echo "<eps> 0" > $lang_dir/words.txt
  echo "!SIL 1" >> $lang_dir/words.txt
  echo "<UNK> 2" >> $lang_dir/words.txt

  # Add regular words to words.txt
  cat data/lm/transcript_words.txt | grep -o -E '\w+' | sort -u | awk '{print $0,NR+2}' >> $lang_dir/words.txt

  # Add remaining special word symbols expected by LM scripts.
  num_words=$(cat $lang_dir/words.txt | wc -l)
  echo "<s> ${num_words}" >> $lang_dir/words.txt
  num_words=$(cat $lang_dir/words.txt | wc -l)
  echo "</s> ${num_words}" >> $lang_dir/words.txt
  num_words=$(cat $lang_dir/words.txt | wc -l)
  echo "#0 ${num_words}" >> $lang_dir/words.txt

  ./local/train_bpe_model.py \
    --lang-dir $lang_dir \
    --vocab-size $vocab_size \
    --transcript data/lm/transcript_words.txt

  if [ ! -f $lang_dir/L_disambig.pt ]; then
    ./local/prepare_lang_bpe.py --lang-dir $lang_dir
  fi
fi
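Stage 8 relies on `jq` being installed. If it is not available, a rough Python equivalent of the transcript dump (a sketch following the same paths as `prepare.sh`, not part of the recipe) would be:

```python
# Dump supervision texts from the AMI and ICSI SDM train manifests into a
# single transcript file for BPE training (rough equivalent of Stage 8).
import gzip
import json

manifests = [
    "data/manifests/ami-sdm_supervisions_train.jsonl.gz",
    "data/manifests/icsi-sdm_supervisions_train.jsonl.gz",
]
with open("data/lm/transcript_words.txt", "w") as out:
    for path in manifests:
        with gzip.open(path, "rt") as f:
            for line in f:
                sup = json.loads(line)
                if "text" in sup:  # skip supervisions without a transcript
                    print(sup["text"], file=out)
```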
1
egs/ami/SURT/shared
Symbolic link
@ -0,0 +1 @@
../../../icefall/shared