mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Add pseudo-labeling based semi-supervised training recipe
This commit is contained in:
parent
1730fce688
commit
60bfbffda7
118
egs/librispeech/PL/README.md
Normal file
118
egs/librispeech/PL/README.md
Normal file
@ -0,0 +1,118 @@
|
||||
# Introduction
|
||||
|
||||
This is a pseudo-labeling based semi-supervised ASR recipe for the LibriSpeech dataset. The ASR model is Zipformer Transducer. The labeled data is Labeled data is LibriSpeech train-clean-100. Unlabeled data can be LibriSpeech "train-clean-360 + train-other-500" for conventional semi-supervised learning or TedLium3 training set for unsupervised domain adaptation.
|
||||
|
||||
## Description of the recipe
|
||||
|
||||
### Preparation of data
|
||||
|
||||
The data required in this recipe is the same with LibriSpeech/TedLium3 ASR recipe. And the tokenizer of LibriSpeech is used to build the model. Therefore, we can reuse the `prepare.sh` scripts in those recipes.
|
||||
|
||||
### Supervised training for the seed ASR model
|
||||
|
||||
Firstly, we need to perform supervised training on the LibriSpeech train-clean-100 subset to generate the seed model for the following pseudo-labeling based semi-supervsed training.
|
||||
|
||||
```
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
./zipformer/train_seed.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 70 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir zipformer/exp_seed \
|
||||
--max-duration 1000
|
||||
```
|
||||
|
||||
For better performance of the seed model, we average the checkpoints as follows:
|
||||
|
||||
```
|
||||
./zipformer/generate_averaged_model.py \
|
||||
--epoch 70 \
|
||||
--avg 30 \
|
||||
--exp-dir ./zipformer/exp_seed
|
||||
```
|
||||
|
||||
The above command generates the final seed model `./zipformer/exp_seed/epoch-70-avg-30.pt`
|
||||
|
||||
### Semi-supervised training for the final ASR model
|
||||
|
||||
Then, we peform semi-supervised training with the seed model as the initialization.
|
||||
|
||||
- Conventional semi-supervised learning setting where unlabeled data is "train-clean-360 + train-other-500":
|
||||
|
||||
```
|
||||
./zipformer/train_pl.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 20 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir zipformer/exp_pl_librispeech \
|
||||
--max-duration 1000 \
|
||||
--seed-model-path "zipformer/exp_seed/epoch-70-avg-30.pt" \
|
||||
--unlabeled-dataset "librispeech"
|
||||
```
|
||||
|
||||
- Unsupervised domain adaptation setting where unlabeled data is TedLium3 training set:
|
||||
|
||||
```
|
||||
./zipformer/train_pl.py \
|
||||
--world-size 4 \
|
||||
--num-epochs 20 \
|
||||
--start-epoch 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir zipformer/exp_pl_tedlium \
|
||||
--max-duration 1000 \
|
||||
--seed-model-path "zipformer/exp_seed/epoch-70-avg-30.pt" \
|
||||
--unlabeled-dataset "tedlium"
|
||||
```
|
||||
|
||||
### Decode
|
||||
|
||||
Finally, we decode the ASR model to evaluate the performance.
|
||||
|
||||
- Evaluate on the LibriSpeech dataset:
|
||||
|
||||
```
|
||||
./zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--exp-dir ./zipformer/exp_pl_librispeech \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
--dataset "librispeech"
|
||||
```
|
||||
|
||||
- Evaluate on the TedLium3 dataset:
|
||||
|
||||
```
|
||||
./zipformer/decode.py \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--exp-dir ./zipformer/exp_pl_tedlium \
|
||||
--max-duration 600 \
|
||||
--decoding-method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
--dataset "tedlium"
|
||||
```
|
||||
|
||||
## Results
|
||||
|
||||
- Conventional semi-supervised learning (LibriSpeech 100h/LibriSpeech 860h)
|
||||
|
||||
| Model | test-clean | test-other | comment |
|
||||
|-------------------------|------------|------------|---------------------|
|
||||
| supervised seed model | 5.45 | 13.7 | --epoch 70 --avg 30 |
|
||||
| pseudo-labeling model | 4.33 | 9.61 | --epoch 20 --avg 10 |
|
||||
|
||||
- Unsupervised domain adaptation (LibriSpeech 100h/TedLium3)
|
||||
|
||||
| Model | tedlium3 dev | tedlium3 test | comment |
|
||||
|-------------------------|------------|------------|---------------------|
|
||||
| supervised seed model | 18.29 | 18.16 | --epoch 70 --avg 30 |
|
||||
| pseudo-labeling model | 14.97 | 14.65 | --epoch 20 --avg 10 |
|
||||
|
||||
|
||||
## Pre-trained models and logs
|
||||
|
||||
You can find the pre-trained models, training logs, tensorboard logs, decoding logs and decoding results at <https://huggingface.co/zhu-han/icefall-pl-librispeech-zipformer-medium-2023-08-06>
|
654
egs/librispeech/PL/zipformer/asr_datamodule.py
Normal file
654
egs/librispeech/PL/zipformer/asr_datamodule.py
Normal file
@ -0,0 +1,654 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SimpleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from dataset import K2UnlabeledSpeechRecognitionDataset
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
from lhotse.utils import fastcopy
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
def tedlium_text_process(segment):
|
||||
if segment.text is None:
|
||||
return segment
|
||||
return fastcopy(segment, text=segment.text.upper().replace(" '", "'").replace("<UNK>", "").replace(" ", " "))
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="""Used only when --mini-libri is False.When enabled,
|
||||
use 960h LibriSpeech. Otherwise, use 100h subset.""",
|
||||
)
|
||||
group.add_argument(
|
||||
"--mini-libri",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="True for mini librispeech",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
def train_labeled_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SpeechRecognitionDataset(
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=self.args.num_buckets * 2000,
|
||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
|
||||
def train_unlabeled_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_frame_masks=num_frame_masks*2,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2*2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2UnlabeledSpeechRecognitionDataset(
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2UnlabeledSpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
buffer_size=self.args.num_buckets * 2000,
|
||||
shuffle_buffer_size=self.args.num_buckets * 5000,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SimpleCutSampler.")
|
||||
train_sampler = SimpleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_5_cuts(self) -> CutSet:
|
||||
logging.info("mini_librispeech: About to get train-clean-5 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-5.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_360_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-360 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_other_500_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-other-500 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_all_shuf_cuts(self) -> CutSet:
|
||||
logging.info(
|
||||
"About to get the shuffled train-clean-100, \
|
||||
train-clean-360 and train-other-500 cuts"
|
||||
)
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_unlabeled_shuf_cuts(self) -> CutSet:
|
||||
logging.info(
|
||||
"About to get the shuffled \
|
||||
train-clean-360 and train-other-500 cuts"
|
||||
)
|
||||
train_clean_360_cuts = self.train_clean_360_cuts()
|
||||
train_other_500_cuts = self.train_other_500_cuts()
|
||||
return CutSet.mux(
|
||||
train_clean_360_cuts,
|
||||
train_other_500_cuts,
|
||||
weights=[
|
||||
104014, # len(train_clean_360_cuts)
|
||||
148688, # len(train_other_500_cuts)
|
||||
],
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_2_cuts(self) -> CutSet:
|
||||
logging.info("mini_librispeech: About to get dev-clean-2 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-clean-2.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_tedlium_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "tedlium_cuts_train.jsonl.gz"
|
||||
).map_supervisions(tedlium_text_process)
|
||||
|
||||
@lru_cache()
|
||||
def dev_tedlium_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz"
|
||||
).map_supervisions(tedlium_text_process)
|
||||
|
||||
@lru_cache()
|
||||
def test_tedlium_cuts(self) -> CutSet:
|
||||
logging.info("About to get test cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz"
|
||||
).map_supervisions(tedlium_text_process)
|
||||
|
1
egs/librispeech/PL/zipformer/beam_search.py
Symbolic link
1
egs/librispeech/PL/zipformer/beam_search.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/beam_search.py
|
232
egs/librispeech/PL/zipformer/checkpoint.py
Normal file
232
egs/librispeech/PL/zipformer/checkpoint.py
Normal file
@ -0,0 +1,232 @@
|
||||
# Copyright 2021-2022 Xiaomi Corporation (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from torch import Tensor
|
||||
from torch.cuda.amp import GradScaler
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
|
||||
# use duck typing for LRScheduler since we have different possibilities, see
|
||||
# our class LRScheduler.
|
||||
LRSchedulerType = object
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
filename: Path,
|
||||
model: Union[nn.Module, DDP],
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
model_ema: Optional[nn.Module] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
labeled_sampler: Optional[CutSampler] = None,
|
||||
unlabeled_sampler: Optional[CutSampler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save training information to a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
The checkpoint filename.
|
||||
model:
|
||||
The model to be saved. We only save its `state_dict()`.
|
||||
model_avg:
|
||||
The stored model averaged from the start of training.
|
||||
model_ema:
|
||||
The EMA version of model.
|
||||
params:
|
||||
User defined parameters, e.g., epoch, loss.
|
||||
optimizer:
|
||||
The optimizer to be saved. We only save its `state_dict()`.
|
||||
scheduler:
|
||||
The scheduler to be saved. We only save its `state_dict()`.
|
||||
scalar:
|
||||
The GradScaler to be saved. We only save its `state_dict()`.
|
||||
labeled_sampler:
|
||||
The sampler used in the labeled training dataset. We only save its `state_dict()`.
|
||||
unlabeled_sampler:
|
||||
The sampler used in the unlabeled training dataset. We only save its `state_dict()`.
|
||||
rank:
|
||||
Used in DDP. We save checkpoint only for the node whose rank is 0.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
|
||||
logging.info(f"Saving checkpoint to {filename}")
|
||||
|
||||
if isinstance(model, DDP):
|
||||
model = model.module
|
||||
|
||||
checkpoint = {
|
||||
"model": model.state_dict(),
|
||||
"optimizer": optimizer.state_dict() if optimizer is not None else None,
|
||||
"scheduler": scheduler.state_dict() if scheduler is not None else None,
|
||||
"grad_scaler": scaler.state_dict() if scaler is not None else None,
|
||||
"labeled_sampler": labeled_sampler.state_dict() if labeled_sampler is not None else None,
|
||||
"unlabeled_sampler": unlabeled_sampler.state_dict() if unlabeled_sampler is not None else None,
|
||||
}
|
||||
|
||||
if model_avg is not None:
|
||||
checkpoint["model_avg"] = model_avg.to(torch.float32).state_dict()
|
||||
if model_ema is not None:
|
||||
checkpoint["model_ema"] = model_ema.model.to(torch.float32).state_dict()
|
||||
|
||||
if params:
|
||||
for k, v in params.items():
|
||||
assert k not in checkpoint
|
||||
checkpoint[k] = v
|
||||
|
||||
torch.save(checkpoint, filename)
|
||||
|
||||
|
||||
def load_checkpoint(
|
||||
filename: Path,
|
||||
model: nn.Module,
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
model_ema: Optional[nn.Module] = None,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
labeled_sampler: Optional[CutSampler] = None,
|
||||
unlabeled_sampler: Optional[CutSampler] = None,
|
||||
strict: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
TODO: document it
|
||||
"""
|
||||
logging.info(f"Loading checkpoint from {filename}")
|
||||
checkpoint = torch.load(filename, map_location="cpu")
|
||||
|
||||
if next(iter(checkpoint["model"])).startswith("module."):
|
||||
logging.info("Loading checkpoint saved by DDP")
|
||||
|
||||
dst_state_dict = model.state_dict()
|
||||
src_state_dict = checkpoint["model"]
|
||||
for key in dst_state_dict.keys():
|
||||
src_key = "{}.{}".format("module", key)
|
||||
dst_state_dict[key] = src_state_dict.pop(src_key)
|
||||
assert len(src_state_dict) == 0
|
||||
model.load_state_dict(dst_state_dict, strict=strict)
|
||||
else:
|
||||
model.load_state_dict(checkpoint["model"], strict=strict)
|
||||
|
||||
checkpoint.pop("model")
|
||||
|
||||
if model_avg is not None and "model_avg" in checkpoint:
|
||||
logging.info("Loading averaged model")
|
||||
model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
|
||||
checkpoint.pop("model_avg")
|
||||
|
||||
if model_ema is not None and "model_ema" in checkpoint:
|
||||
logging.info("Loading ema model")
|
||||
model_ema.model.load_state_dict(checkpoint["model_ema"], strict=strict)
|
||||
checkpoint.pop("model_ema")
|
||||
|
||||
def load(name, obj):
|
||||
s = checkpoint.get(name, None)
|
||||
if obj and s:
|
||||
obj.load_state_dict(s)
|
||||
checkpoint.pop(name)
|
||||
|
||||
load("optimizer", optimizer)
|
||||
load("scheduler", scheduler)
|
||||
load("grad_scaler", scaler)
|
||||
load("labeled_sampler", labeled_sampler)
|
||||
load("unlabeled_sampler", unlabeled_sampler)
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def save_checkpoint_with_global_batch_idx(
|
||||
out_dir: Path,
|
||||
global_batch_idx: int,
|
||||
model: Union[nn.Module, DDP],
|
||||
model_avg: Optional[nn.Module] = None,
|
||||
model_ema: Optional[nn.Module] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
optimizer: Optional[Optimizer] = None,
|
||||
scheduler: Optional[LRSchedulerType] = None,
|
||||
scaler: Optional[GradScaler] = None,
|
||||
labeled_sampler: Optional[CutSampler] = None,
|
||||
unlabeled_sampler: Optional[CutSampler] = None,
|
||||
rank: int = 0,
|
||||
):
|
||||
"""Save training info after processing given number of batches.
|
||||
|
||||
Args:
|
||||
out_dir:
|
||||
The directory to save the checkpoint.
|
||||
global_batch_idx:
|
||||
The number of batches processed so far from the very start of the
|
||||
training. The saved checkpoint will have the following filename:
|
||||
|
||||
f'out_dir / checkpoint-{global_batch_idx}.pt'
|
||||
model:
|
||||
The neural network model whose `state_dict` will be saved in the
|
||||
checkpoint.
|
||||
model_avg:
|
||||
The stored model averaged from the start of training.
|
||||
model_ema:
|
||||
The EMA version of model.
|
||||
params:
|
||||
A dict of training configurations to be saved.
|
||||
optimizer:
|
||||
The optimizer used in the training. Its `state_dict` will be saved.
|
||||
scheduler:
|
||||
The learning rate scheduler used in the training. Its `state_dict` will
|
||||
be saved.
|
||||
scaler:
|
||||
The scaler used for mix precision training. Its `state_dict` will
|
||||
be saved.
|
||||
labeled_sampler:
|
||||
The sampler used in the labeled training dataset. We only save its `state_dict()`.
|
||||
unlabeled_sampler:
|
||||
The sampler used in the unlabeled training dataset. We only save its `state_dict()`.
|
||||
rank:
|
||||
The rank ID used in DDP training of the current node. Set it to 0
|
||||
if DDP is not used.
|
||||
"""
|
||||
out_dir = Path(out_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
filename = out_dir / f"checkpoint-{global_batch_idx}.pt"
|
||||
save_checkpoint(
|
||||
filename=filename,
|
||||
model=model,
|
||||
model_avg=model_avg,
|
||||
model_ema=model_ema,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
scaler=scaler,
|
||||
labeled_sampler=labeled_sampler,
|
||||
unlabeled_sampler=unlabeled_sampler,
|
||||
rank=rank,
|
||||
)
|
225
egs/librispeech/PL/zipformer/dataset.py
Normal file
225
egs/librispeech/PL/zipformer/dataset.py
Normal file
@ -0,0 +1,225 @@
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
import torch
|
||||
from torch.utils.data.dataloader import DataLoader, default_collate
|
||||
|
||||
from lhotse import validate
|
||||
from lhotse.cut import CutSet
|
||||
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
|
||||
from lhotse.utils import compute_num_frames, ifnone
|
||||
from lhotse.workarounds import Hdf5MemoryIssueFix
|
||||
import copy
|
||||
|
||||
|
||||
class K2UnlabeledSpeechRecognitionDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
The PyTorch Dataset for the speech recognition task using k2 library.
|
||||
|
||||
This dataset expects to be queried with lists of cut IDs,
|
||||
for which it loads features and automatically collates/batches them.
|
||||
|
||||
To use it with a PyTorch DataLoader, set ``batch_size=None``
|
||||
and provide a :class:`SimpleCutSampler` sampler.
|
||||
|
||||
Each item in this dataset is a dict of:
|
||||
|
||||
.. code-block::
|
||||
|
||||
{
|
||||
'inputs': float tensor with shape determined by :attr:`input_strategy`:
|
||||
- single-channel:
|
||||
- features: (B, T, F)
|
||||
- audio: (B, T)
|
||||
- multi-channel: currently not supported
|
||||
'supervisions': [
|
||||
{
|
||||
'sequence_idx': Tensor[int] of shape (S,)
|
||||
'text': List[str] of len S
|
||||
|
||||
# For feature input strategies
|
||||
'start_frame': Tensor[int] of shape (S,)
|
||||
'num_frames': Tensor[int] of shape (S,)
|
||||
|
||||
# For audio input strategies
|
||||
'start_sample': Tensor[int] of shape (S,)
|
||||
'num_samples': Tensor[int] of shape (S,)
|
||||
|
||||
# Optionally, when return_cuts=True
|
||||
'cut': List[AnyCut] of len S
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Dimension symbols legend:
|
||||
* ``B`` - batch size (number of Cuts)
|
||||
* ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions)
|
||||
* ``T`` - number of frames of the longest Cut
|
||||
* ``F`` - number of features
|
||||
|
||||
The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
return_cuts: bool = False,
|
||||
cut_transforms: List[Callable[[CutSet], CutSet]] = None,
|
||||
input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
|
||||
input_strategy: BatchIO = PrecomputedFeatures(),
|
||||
):
|
||||
"""
|
||||
k2 ASR IterableDataset constructor.
|
||||
|
||||
:param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
|
||||
objects used to create that batch.
|
||||
:param cut_transforms: A list of transforms to be applied on each sampled batch,
|
||||
before converting cuts to an input representation (audio/features).
|
||||
Examples: cut concatenation, noise cuts mixing, etc.
|
||||
:param input_transforms: A list of transforms to be applied on each sampled batch,
|
||||
after the cuts are converted to audio/features.
|
||||
Examples: normalization, SpecAugment, etc.
|
||||
:param input_strategy: Converts cuts into a collated batch of audio/features.
|
||||
By default, reads pre-computed features from disk.
|
||||
"""
|
||||
super().__init__()
|
||||
# Initialize the fields
|
||||
self.return_cuts = return_cuts
|
||||
self.cut_transforms = ifnone(cut_transforms, [])
|
||||
self.input_transforms = ifnone(input_transforms, [])
|
||||
self.input_strategy = input_strategy
|
||||
|
||||
# This attribute is a workaround to constantly growing HDF5 memory
|
||||
# throughout the epoch. It regularly closes open file handles to
|
||||
# reset the internal HDF5 caches.
|
||||
self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)
|
||||
|
||||
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
|
||||
"""
|
||||
Return a new batch, with the batch size automatically determined using the constraints
|
||||
of max_frames and max_cuts.
|
||||
"""
|
||||
validate_for_asr(cuts)
|
||||
|
||||
self.hdf5_fix.update()
|
||||
|
||||
# Sort the cuts by duration so that the first one determines the batch time dimensions.
|
||||
cuts = cuts.sort_by_duration(ascending=False)
|
||||
|
||||
# Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
|
||||
# the supervision boundaries.
|
||||
for tnfm in self.cut_transforms:
|
||||
cuts = tnfm(cuts)
|
||||
|
||||
# Sort the cuts again after transforms
|
||||
cuts = cuts.sort_by_duration(ascending=False)
|
||||
|
||||
# Get a tensor with batched feature matrices, shape (B, T, F)
|
||||
# Collation performs auto-padding, if necessary.
|
||||
input_tpl = self.input_strategy(cuts)
|
||||
if len(input_tpl) == 3:
|
||||
# An input strategy with fault tolerant audio reading mode.
|
||||
# "cuts" may be a subset of the original "cuts" variable,
|
||||
# that only has cuts for which we succesfully read the audio.
|
||||
inputs, _, cuts = input_tpl
|
||||
else:
|
||||
inputs, _ = input_tpl
|
||||
|
||||
# Get a dict of tensors that encode the positional information about supervisions
|
||||
# in the batch of feature matrices. The tensors are named "sequence_idx",
|
||||
# "start_frame/sample" and "num_frames/samples".
|
||||
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
|
||||
|
||||
# Apply all available transforms on the inputs, i.e. either audio or features.
|
||||
# This could be feature extraction, global MVN, SpecAugment, etc.
|
||||
segments = torch.stack(list(supervision_intervals.values()), dim=1)
|
||||
inputs_orig = copy.deepcopy(inputs)
|
||||
for tnfm in self.input_transforms:
|
||||
inputs = tnfm(inputs, supervision_segments=segments)
|
||||
|
||||
batch = {
|
||||
"inputs": inputs,
|
||||
"inputs_orig": inputs_orig,
|
||||
"supervisions": default_collate(
|
||||
[
|
||||
{
|
||||
"text": "",
|
||||
}
|
||||
for sequence_idx, cut in enumerate(cuts)
|
||||
for supervision in cut.supervisions
|
||||
]
|
||||
),
|
||||
}
|
||||
# Update the 'supervisions' field with sequence_idx and start/num frames/samples
|
||||
batch["supervisions"].update(supervision_intervals)
|
||||
if self.return_cuts:
|
||||
batch["supervisions"]["cut"] = [
|
||||
cut for cut in cuts for sup in cut.supervisions
|
||||
]
|
||||
|
||||
has_word_alignments = all(
|
||||
s.alignment is not None and "word" in s.alignment
|
||||
for c in cuts
|
||||
for s in c.supervisions
|
||||
)
|
||||
if has_word_alignments:
|
||||
# TODO: might need to refactor BatchIO API to move the following conditional logic
|
||||
# into these objects (e.g. use like: self.input_strategy.convert_timestamp(),
|
||||
# that returns either num_frames or num_samples depending on the strategy).
|
||||
words, starts, ends = [], [], []
|
||||
frame_shift = cuts[0].frame_shift
|
||||
sampling_rate = cuts[0].sampling_rate
|
||||
if frame_shift is None:
|
||||
try:
|
||||
frame_shift = self.input_strategy.extractor.frame_shift
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. "
|
||||
)
|
||||
for c in cuts:
|
||||
for s in c.supervisions:
|
||||
words.append([aliword.symbol for aliword in s.alignment["word"]])
|
||||
starts.append(
|
||||
[
|
||||
compute_num_frames(
|
||||
aliword.start,
|
||||
frame_shift=frame_shift,
|
||||
sampling_rate=sampling_rate,
|
||||
)
|
||||
for aliword in s.alignment["word"]
|
||||
]
|
||||
)
|
||||
ends.append(
|
||||
[
|
||||
compute_num_frames(
|
||||
aliword.end,
|
||||
frame_shift=frame_shift,
|
||||
sampling_rate=sampling_rate,
|
||||
)
|
||||
for aliword in s.alignment["word"]
|
||||
]
|
||||
)
|
||||
batch["supervisions"]["word"] = words
|
||||
batch["supervisions"]["word_start"] = starts
|
||||
batch["supervisions"]["word_end"] = ends
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
def validate_for_asr(cuts: CutSet) -> None:
|
||||
validate(cuts)
|
||||
tol = 2e-3 # 1ms
|
||||
for cut in cuts:
|
||||
for supervision in cut.supervisions:
|
||||
assert supervision.start >= -tol, (
|
||||
f"Supervisions starting before the cut are not supported for ASR"
|
||||
f" (sup id: {supervision.id}, cut id: {cut.id})"
|
||||
)
|
||||
|
||||
# Supervision start time is relative to Cut ...
|
||||
# https://lhotse.readthedocs.io/en/v0.10_e/cuts.html
|
||||
#
|
||||
# 'supervision.end' is end of supervision inside the Cut
|
||||
assert supervision.end <= cut.duration + tol, (
|
||||
f"Supervisions ending after the cut "
|
||||
f"are not supported for ASR"
|
||||
f" (sup id: {supervision.id}, cut id: {cut.id})"
|
||||
)
|
1069
egs/librispeech/PL/zipformer/decode.py
Executable file
1069
egs/librispeech/PL/zipformer/decode.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/PL/zipformer/decoder.py
Symbolic link
1
egs/librispeech/PL/zipformer/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/decoder.py
|
1
egs/librispeech/PL/zipformer/encoder_interface.py
Symbolic link
1
egs/librispeech/PL/zipformer/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/encoder_interface.py
|
193
egs/librispeech/PL/zipformer/generate_averaged_model.py
Executable file
193
egs/librispeech/PL/zipformer/generate_averaged_model.py
Executable file
@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Yifan Yang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage:
|
||||
(1) use the checkpoint exp_dir/epoch-xxx.pt
|
||||
./zipformer/generate_averaged_model.py \
|
||||
--epoch 28 \
|
||||
--avg 15 \
|
||||
--exp-dir ./zipformer/exp
|
||||
|
||||
It will generate a file `epoch-28-avg-15.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("epoch-28-avg-15.pt")`.
|
||||
|
||||
(2) use the checkpoint exp_dir/checkpoint-iter.pt
|
||||
./zipformer/generate_averaged_model.py \
|
||||
--iter 22000 \
|
||||
--avg 5 \
|
||||
--exp-dir ./zipformer/exp
|
||||
|
||||
It will generate a file `iter-22000-avg-5.pt` in the given `exp_dir`.
|
||||
You can later load it by `torch.load("iter-22000-avg-5.pt")`.
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from train_seed import add_model_arguments, get_model, get_params
|
||||
|
||||
from icefall.checkpoint import average_checkpoints_with_averaged_model, find_checkpoints
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=9,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="zipformer/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
default="data/lang_bpe_500/tokens.txt",
|
||||
help="Path to the tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
print("Script started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
print(f"Device: {device}")
|
||||
|
||||
symbol_table = k2.SymbolTable.from_file(params.tokens)
|
||||
params.blank_id = symbol_table["<blk>"]
|
||||
params.unk_id = symbol_table["<unk>"]
|
||||
params.vocab_size = len(symbol_table) - 2
|
||||
|
||||
print("About to create model")
|
||||
model = get_model(params)
|
||||
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
print(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
filename = params.exp_dir / f"iter-{params.iter}-avg-{params.avg}.pt"
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
print(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
filename = params.exp_dir / f"epoch-{params.epoch}-avg-{params.avg}.pt"
|
||||
torch.save({"model": model.state_dict()}, filename)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
print(f"Number of model parameters: {num_param}")
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/librispeech/PL/zipformer/joiner.py
Symbolic link
1
egs/librispeech/PL/zipformer/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/joiner.py
|
360
egs/librispeech/PL/zipformer/model.py
Normal file
360
egs/librispeech/PL/zipformer/model.py
Normal file
@ -0,0 +1,360 @@
|
||||
# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask
|
||||
|
||||
|
||||
class AsrModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
encoder_embed: nn.Module,
|
||||
encoder: EncoderInterface,
|
||||
decoder: Optional[nn.Module] = None,
|
||||
joiner: Optional[nn.Module] = None,
|
||||
encoder_dim: int = 384,
|
||||
decoder_dim: int = 512,
|
||||
vocab_size: int = 500,
|
||||
use_transducer: bool = True,
|
||||
use_ctc: bool = False,
|
||||
):
|
||||
"""A joint CTC & Transducer ASR model.
|
||||
|
||||
- Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
|
||||
- Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
|
||||
- Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)
|
||||
|
||||
Args:
|
||||
encoder_embed:
|
||||
It is a Convolutional 2D subsampling module. It converts
|
||||
an input of shape (N, T, idim) to an output of of shape
|
||||
(N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
|
||||
encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
|
||||
`logit_lens` of shape (N,).
|
||||
decoder:
|
||||
It is the prediction network in the paper. Its input shape
|
||||
is (N, U) and its output shape is (N, U, decoder_dim).
|
||||
It should contain one attribute: `blank_id`.
|
||||
It is used when use_transducer is True.
|
||||
joiner:
|
||||
It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
|
||||
Its output shape is (N, T, U, vocab_size). Note that its output contains
|
||||
unnormalized probs, i.e., not processed by log-softmax.
|
||||
It is used when use_transducer is True.
|
||||
use_transducer:
|
||||
Whether use transducer head. Default: True.
|
||||
use_ctc:
|
||||
Whether use CTC head. Default: False.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
assert (
|
||||
use_transducer or use_ctc
|
||||
), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"
|
||||
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
|
||||
self.encoder_embed = encoder_embed
|
||||
self.encoder = encoder
|
||||
|
||||
self.use_transducer = use_transducer
|
||||
if use_transducer:
|
||||
# Modules for Transducer head
|
||||
assert decoder is not None
|
||||
assert hasattr(decoder, "blank_id")
|
||||
assert joiner is not None
|
||||
|
||||
self.decoder = decoder
|
||||
self.joiner = joiner
|
||||
|
||||
self.simple_am_proj = ScaledLinear(
|
||||
encoder_dim, vocab_size, initial_scale=0.25
|
||||
)
|
||||
self.simple_lm_proj = ScaledLinear(
|
||||
decoder_dim, vocab_size, initial_scale=0.25
|
||||
)
|
||||
else:
|
||||
assert decoder is None
|
||||
assert joiner is None
|
||||
|
||||
self.use_ctc = use_ctc
|
||||
if use_ctc:
|
||||
# Modules for CTC head
|
||||
self.ctc_output = nn.Sequential(
|
||||
nn.Dropout(p=0.1),
|
||||
nn.Linear(encoder_dim, vocab_size),
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
def forward_encoder(
|
||||
self, x: torch.Tensor, x_lens: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute encoder outputs.
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
|
||||
Returns:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
"""
|
||||
# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
|
||||
x, x_lens = self.encoder_embed(x, x_lens)
|
||||
# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
|
||||
|
||||
src_key_padding_mask = make_pad_mask(x_lens)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
|
||||
|
||||
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def forward_ctc(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
target_lengths: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute CTC loss.
|
||||
Args:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
targets:
|
||||
Target Tensor of shape (sum(target_lengths)). The targets are assumed
|
||||
to be un-padded and concatenated within 1 dimension.
|
||||
"""
|
||||
# Compute CTC log-prob
|
||||
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
ctc_loss = torch.nn.functional.ctc_loss(
|
||||
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||
targets=targets.cpu(),
|
||||
input_lengths=encoder_out_lens.cpu(),
|
||||
target_lengths=target_lengths.cpu(),
|
||||
reduction="sum",
|
||||
)
|
||||
return ctc_loss
|
||||
|
||||
def forward_transducer(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
y_lens: torch.Tensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute Transducer loss.
|
||||
Args:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
"""
|
||||
# Now for the decoder, i.e., the prediction network
|
||||
blank_id = self.decoder.blank_id
|
||||
sos_y = add_sos(y, sos_id=blank_id)
|
||||
|
||||
# sos_y_padded: [B, S + 1], start with SOS.
|
||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||
|
||||
# decoder_out: [B, S + 1, decoder_dim]
|
||||
decoder_out = self.decoder(sos_y_padded)
|
||||
|
||||
# Note: y does not start with SOS
|
||||
# y_padded : [B, S]
|
||||
y_padded = y.pad(mode="constant", padding_value=0)
|
||||
|
||||
y_padded = y_padded.to(torch.int64)
|
||||
boundary = torch.zeros(
|
||||
(encoder_out.size(0), 4),
|
||||
dtype=torch.int64,
|
||||
device=encoder_out.device,
|
||||
)
|
||||
boundary[:, 2] = y_lens
|
||||
boundary[:, 3] = encoder_out_lens
|
||||
|
||||
lm = self.simple_lm_proj(decoder_out)
|
||||
am = self.simple_am_proj(encoder_out)
|
||||
|
||||
# if self.training and random.random() < 0.25:
|
||||
# lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04)
|
||||
# if self.training and random.random() < 0.25:
|
||||
# am = penalize_abs_values_gt(am, 30.0, 1.0e-04)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
|
||||
lm=lm.float(),
|
||||
am=am.float(),
|
||||
symbols=y_padded,
|
||||
termination_symbol=blank_id,
|
||||
lm_only_scale=lm_scale,
|
||||
am_only_scale=am_scale,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
return_grad=True,
|
||||
)
|
||||
|
||||
# ranges : [B, T, prune_range]
|
||||
ranges = k2.get_rnnt_prune_ranges(
|
||||
px_grad=px_grad,
|
||||
py_grad=py_grad,
|
||||
boundary=boundary,
|
||||
s_range=prune_range,
|
||||
)
|
||||
|
||||
# am_pruned : [B, T, prune_range, encoder_dim]
|
||||
# lm_pruned : [B, T, prune_range, decoder_dim]
|
||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||
am=self.joiner.encoder_proj(encoder_out),
|
||||
lm=self.joiner.decoder_proj(decoder_out),
|
||||
ranges=ranges,
|
||||
)
|
||||
|
||||
# logits : [B, T, prune_range, vocab_size]
|
||||
|
||||
# project_input=False since we applied the decoder's input projections
|
||||
# prior to do_rnnt_pruning (this is an optimization for speed).
|
||||
logits = self.joiner(am_pruned, lm_pruned, project_input=False)
|
||||
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
pruned_loss = k2.rnnt_loss_pruned(
|
||||
logits=logits.float(),
|
||||
symbols=y_padded,
|
||||
ranges=ranges,
|
||||
termination_symbol=blank_id,
|
||||
boundary=boundary,
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
return simple_loss, pruned_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
prune_range: int = 5,
|
||||
am_scale: float = 0.0,
|
||||
lm_scale: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
y:
|
||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||
utterance.
|
||||
prune_range:
|
||||
The prune range for rnnt loss, it means how many symbols(context)
|
||||
we are considering for each frame to compute the loss.
|
||||
am_scale:
|
||||
The scale to smooth the loss with am (output of encoder network)
|
||||
part
|
||||
lm_scale:
|
||||
The scale to smooth the loss with lm (output of predictor network)
|
||||
part
|
||||
Returns:
|
||||
Return the transducer losses and CTC loss,
|
||||
in form of (simple_loss, pruned_loss, ctc_loss)
|
||||
|
||||
Note:
|
||||
Regarding am_scale & lm_scale, it will make the loss-function one of
|
||||
the form:
|
||||
lm_scale * lm_probs + am_scale * am_probs +
|
||||
(1-lm_scale-am_scale) * combined_probs
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
|
||||
|
||||
device = x.device
|
||||
|
||||
# Compute encoder outputs
|
||||
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
|
||||
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
if self.use_transducer:
|
||||
# Compute transducer loss
|
||||
simple_loss, pruned_loss = self.forward_transducer(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
y=y.to(device),
|
||||
y_lens=y_lens,
|
||||
prune_range=prune_range,
|
||||
am_scale=am_scale,
|
||||
lm_scale=lm_scale,
|
||||
)
|
||||
else:
|
||||
simple_loss = torch.empty(0)
|
||||
pruned_loss = torch.empty(0)
|
||||
|
||||
if self.use_ctc:
|
||||
# Compute CTC loss
|
||||
targets = y.values
|
||||
ctc_loss = self.forward_ctc(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
targets=targets,
|
||||
target_lengths=y_lens,
|
||||
)
|
||||
else:
|
||||
ctc_loss = torch.empty(0)
|
||||
|
||||
return simple_loss, pruned_loss, ctc_loss
|
1104
egs/librispeech/PL/zipformer/optim.py
Normal file
1104
egs/librispeech/PL/zipformer/optim.py
Normal file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/PL/zipformer/scaling.py
Symbolic link
1
egs/librispeech/PL/zipformer/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/scaling.py
|
1
egs/librispeech/PL/zipformer/subsampling.py
Symbolic link
1
egs/librispeech/PL/zipformer/subsampling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/subsampling.py
|
1651
egs/librispeech/PL/zipformer/train_pl.py
Executable file
1651
egs/librispeech/PL/zipformer/train_pl.py
Executable file
File diff suppressed because it is too large
Load Diff
1366
egs/librispeech/PL/zipformer/train_seed.py
Executable file
1366
egs/librispeech/PL/zipformer/train_seed.py
Executable file
File diff suppressed because it is too large
Load Diff
1
egs/librispeech/PL/zipformer/zipformer.py
Symbolic link
1
egs/librispeech/PL/zipformer/zipformer.py
Symbolic link
@ -0,0 +1 @@
|
||||
../../ASR/zipformer/zipformer.py
|
Loading…
x
Reference in New Issue
Block a user