From 45a122555c9fc2e32be3bdd7af7a6343f47baf3b Mon Sep 17 00:00:00 2001 From: root Date: Thu, 2 May 2024 09:40:40 +0900 Subject: [PATCH] remove outdated recipes --- .../pruned_transducer_stateless2/__init__.py | 0 .../pruned_transducer_stateless2/aishell.py | 377 --- .../asr_datamodule.py | 394 --- .../beam_search.py | 2942 ----------------- .../pruned_transducer_stateless2/conformer.py | 1600 --------- .../pruned_transducer_stateless2/decode.py | 690 ---- .../decode_aishell.py | 547 --- .../decode_stream.py | 1 - .../pruned_transducer_stateless2/decoder.py | 122 - .../encoder_interface.py | 43 - .../export-onnx.py | 518 --- .../pruned_transducer_stateless2/export.py | 238 -- .../pruned_transducer_stateless2/finetune.py | 1054 ------ .../jit_pretrained.py | 336 -- .../pruned_transducer_stateless2/joiner.py | 67 - .../ASR/pruned_transducer_stateless2/lstmp.py | 102 - .../ASR/pruned_transducer_stateless2/model.py | 207 -- .../onnx_check.py | 303 -- .../onnx_pretrained.py | 428 --- .../ASR/pruned_transducer_stateless2/optim.py | 319 -- .../pretrained.py | 339 -- .../pruned_transducer_stateless2/scaling.py | 1014 ------ .../scaling_converter.py | 320 -- .../streaming_beam_search.py | 1 - .../streaming_decode.py | 1 - .../test_model.py | 1 - .../pruned_transducer_stateless2/tokenizer.py | 1 - .../ASR/pruned_transducer_stateless2/train.py | 1061 ------ .../asr_datamodule.py | 1 - .../beam_search.py | 1 - .../decode.py | 846 ----- .../decode_stream.py | 1 - .../decoder.py | 1 - .../do_not_use_it_directly.py | 1261 ------- .../encoder_interface.py | 1 - .../export.py | 250 -- .../joiner.py | 1 - .../model.py | 1 - .../optim.py | 1 - .../pretrained.py | 347 -- .../scaling.py | 1 - .../scaling_converter.py | 1 - .../streaming-ncnn-decode.py | 1 - .../streaming_beam_search.py | 1 - .../streaming_decode.py | 597 ---- .../tokenizer.py | 1 - .../train.py | 1259 ------- .../zipformer.py | 1 - .../zipformer_for_ncnn_export_only.py | 1 - 49 files changed, 17601 deletions(-) delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/__init__.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/aishell.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/beam_search.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/conformer.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless2/decode.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless2/decode_aishell.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless2/decode_stream.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/decoder.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/encoder_interface.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless2/export-onnx.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless2/export.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless2/finetune.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/joiner.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/lstmp.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/model.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless2/onnx_check.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/optim.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless2/pretrained.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/scaling.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless2/scaling_converter.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless2/streaming_decode.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless2/test_model.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless2/tokenizer.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless2/train.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/decode.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/export.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/model.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/optim.py delete mode 100644 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/tokenizer.py delete mode 100755 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/train.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py delete mode 120000 egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/__init__.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/aishell.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/aishell.py deleted file mode 100644 index 6abe6c084..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/aishell.py +++ /dev/null @@ -1,377 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional - -from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class AishellAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--drop-last", - type=str2bool, - default=True, - help="Whether to drop last batch. Used by sampler.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - group.add_argument( - "--enable-musan", - type=str2bool, - default=True, - help="When enabled, select noise from MUSAN and mix it" - "with training dataset. ", - ) - - def train_dataloaders( - self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - - transforms = [] - if self.args.enable_musan: - logging.info("Enable MUSAN") - transforms.append( - CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - ) - else: - logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - drop_last=self.args.drop_last, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_sampler.load_state_dict(sampler_state_dict) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - valid_dl = DataLoader( - validate, - sampler=valid_sampler, - batch_size=None, - num_workers=2, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - cuts_train = load_manifest_lazy( - self.args.manifest_dir / "aishell_cuts_train.jsonl.gz" - ) - return cuts_train - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz") - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz") diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py deleted file mode 100644 index c4116c1b8..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ /dev/null @@ -1,394 +0,0 @@ -# Copyright 2021 Piotr Żelasko -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import argparse -import inspect -import logging -from functools import lru_cache -from pathlib import Path -from typing import Any, Dict, List, Optional - -import torch -from lhotse import ( - CutSet, - Fbank, - FbankConfig, - load_manifest, - load_manifest_lazy, - set_caching_enabled, -) -from lhotse.dataset import ( - CutConcatenate, - CutMix, - DynamicBucketingSampler, - K2SpeechRecognitionDataset, - PrecomputedFeatures, - SimpleCutSampler, - SpecAugment, -) -from lhotse.dataset.input_strategies import OnTheFlyFeatures -from lhotse.utils import fix_random_seed -from torch.utils.data import DataLoader - -from icefall.utils import str2bool - - -class _SeedWorkers: - def __init__(self, seed: int): - self.seed = seed - - def __call__(self, worker_id: int): - fix_random_seed(self.seed + worker_id) - - -class ReazonSpeechAsrDataModule: - """ - DataModule for k2 ASR experiments. - It assumes there is always one train and valid dataloader, - but there can be multiple test dataloaders (e.g. LibriSpeech test-clean - and test-other). - It contains all the common data pipeline modules used in ASR - experiments, e.g.: - - dynamic batch size, - - bucketing samplers, - - cut concatenation, - - augmentation, - - on-the-fly feature extraction - This class should be derived for specific corpora used in ASR tasks. - """ - - def __init__(self, args: argparse.Namespace): - self.args = args - - @classmethod - def add_arguments(cls, parser: argparse.ArgumentParser): - group = parser.add_argument_group( - title="ASR data related options", - description="These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc.", - ) - group.add_argument( - "--manifest-dir", - type=Path, - default=Path("data/fbank"), - help="Path to directory with train/valid/test cuts.", - ) - group.add_argument( - "--max-duration", - type=int, - default=200.0, - help="Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM.", - ) - group.add_argument( - "--bucketing-sampler", - type=str2bool, - default=True, - help="When enabled, the batches will come from buckets of " - "similar duration (saves padding frames).", - ) - group.add_argument( - "--num-buckets", - type=int, - default=30, - help="The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets).", - ) - group.add_argument( - "--concatenate-cuts", - type=str2bool, - default=False, - help="When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding.", - ) - group.add_argument( - "--duration-factor", - type=float, - default=1.0, - help="Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch.", - ) - group.add_argument( - "--gap", - type=float, - default=1.0, - help="The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used.", - ) - group.add_argument( - "--on-the-fly-feats", - type=str2bool, - default=False, - help="When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available.", - ) - group.add_argument( - "--shuffle", - type=str2bool, - default=True, - help="When enabled (=default), the examples will be " - "shuffled for each epoch.", - ) - group.add_argument( - "--return-cuts", - type=str2bool, - default=True, - help="When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it.", - ) - - group.add_argument( - "--num-workers", - type=int, - default=2, - help="The number of training dataloader workers that " - "collect the batches.", - ) - - group.add_argument( - "--enable-spec-aug", - type=str2bool, - default=True, - help="When enabled, use SpecAugment for training dataset.", - ) - - group.add_argument( - "--spec-aug-time-warp-factor", - type=int, - default=80, - help="Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp.", - ) - - # group.add_argument( - # "--enable-musan", - # type=str2bool, - # default=True, - # help="When enabled, select noise from MUSAN and mix it" - # "with training dataset. ", - # ) - - def train_dataloaders( - self, - cuts_train: CutSet, - sampler_state_dict: Optional[Dict[str, Any]] = None, - ) -> DataLoader: - """ - Args: - cuts_train: - CutSet for training. - sampler_state_dict: - The state dict for the training sampler. - """ - # logging.info("About to get Musan cuts") - # cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") - - transforms = [] - # if self.args.enable_musan: - # logging.info("Enable MUSAN") - # transforms.append( - # CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True) - # ) - # else: - # logging.info("Disable MUSAN") - - if self.args.concatenate_cuts: - logging.info( - f"Using cut concatenation with duration factor " - f"{self.args.duration_factor} and gap {self.args.gap}." - ) - # Cut concatenation should be the first transform in the list, - # so that if we e.g. mix noise in, it will fill the gaps between - # different utterances. - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - input_transforms = [] - if self.args.enable_spec_aug: - logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") - # Set the value of num_frame_masks according to Lhotse's version. - # In different Lhotse's versions, the default of num_frame_masks is - # different. - num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature( - SpecAugment.__init__ - ).parameters["num_frame_masks"] - if num_frame_masks_parameter.default == 1: - num_frame_masks = 2 - logging.info(f"Num frame mask: {num_frame_masks}") - input_transforms.append( - SpecAugment( - time_warp_factor=self.args.spec_aug_time_warp_factor, - num_frame_masks=num_frame_masks, - features_mask_size=27, - num_feature_masks=2, - frames_mask_size=100, - ) - ) - else: - logging.info("Disable SpecAugment") - - logging.info("About to create train dataset") - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.on_the_fly_feats: - # NOTE: the PerturbSpeed transform should be added only if we - # remove it from data prep stage. - # Add on-the-fly speed perturbation; since originally it would - # have increased epoch size by 3, we will apply prob 2/3 and use - # 3x more epochs. - # Speed perturbation probably should come first before - # concatenation, but in principle the transforms order doesn't have - # to be strict (e.g. could be randomized) - # transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa - # Drop feats to be on the safe side. - train = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - input_transforms=input_transforms, - return_cuts=self.args.return_cuts, - ) - - if self.args.bucketing_sampler: - logging.info("Using DynamicBucketingSampler.") - train_sampler = DynamicBucketingSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - num_buckets=self.args.num_buckets, - buffer_size=300000, - drop_last=True, - ) - else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - ) - logging.info("About to create train dataloader") - - # 'seed' is derived from the current random state, which will have - # previously been set in the main process. - seed = torch.randint(0, 100000, ()).item() - worker_init_fn = _SeedWorkers(seed) - - train_dl = DataLoader( - train, - sampler=train_sampler, - batch_size=None, - num_workers=self.args.num_workers, - persistent_workers=False, - worker_init_fn=worker_init_fn, - ) - - if sampler_state_dict is not None: - logging.info("Loading sampler state dict") - train_dl.sampler.load_state_dict(sampler_state_dict) - - return train_dl - - def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: - transforms = [] - if self.args.concatenate_cuts: - transforms = [ - CutConcatenate( - duration_factor=self.args.duration_factor, gap=self.args.gap - ) - ] + transforms - - logging.info("About to create dev dataset") - if self.args.on_the_fly_feats: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), - return_cuts=self.args.return_cuts, - ) - else: - validate = K2SpeechRecognitionDataset( - cut_transforms=transforms, - return_cuts=self.args.return_cuts, - ) - - valid_sampler = DynamicBucketingSampler( - cuts_valid, - max_duration=self.args.max_duration, - shuffle=False, - ) - logging.info("About to create dev dataloader") - - valid_dl = DataLoader( - validate, - batch_size=None, - sampler=valid_sampler, - num_workers=self.args.num_workers, - persistent_workers=False, - ) - - return valid_dl - - def test_dataloaders(self, cuts: CutSet) -> DataLoader: - logging.info("About to create test dataset") - test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else PrecomputedFeatures(), - return_cuts=self.args.return_cuts, - ) - sampler = DynamicBucketingSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - ) - - test_dl = DataLoader( - test, - batch_size=None, - sampler=sampler, - num_workers=self.args.num_workers, - ) - return test_dl - - @lru_cache() - def train_cuts(self) -> CutSet: - logging.info("About to get train cuts") - return load_manifest_lazy(self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz") - - @lru_cache() - def valid_cuts(self) -> CutSet: - logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz") - - @lru_cache() - def test_cuts(self) -> List[CutSet]: - logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz") diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/beam_search.py deleted file mode 100644 index 7fcd242fc..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/beam_search.py +++ /dev/null @@ -1,2942 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Xiaoyu Yang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import warnings -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Tuple, Union - -import k2 -import sentencepiece as spm -import torch -from torch import nn - -from icefall import ContextGraph, ContextState, NgramLm, NgramLmStateCost -from icefall.decode import Nbest, one_best_decoding -from icefall.lm_wrapper import LmScorer -from icefall.rnn_lm.model import RnnLmModel -from icefall.transformer_lm.model import TransformerLM -from icefall.utils import ( - DecodingResults, - add_eos, - add_sos, - get_texts, - get_texts_with_timestamp, -) - - -def fast_beam_search_one_best( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, - ilme_scale: float = 0.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, - allow_partial: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - the shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ilme_scale=ilme_scale, - allow_partial=allow_partial, - blank_penalty=blank_penalty, - ) - - best_path = one_best_decoding(lattice) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_LG( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - blank_penalty: float = 0.0, - ilme_scale: float = 0.0, - return_timestamps: bool = False, - allow_partial: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - allow_partial=allow_partial, - blank_penalty=blank_penalty, - ilme_scale=ilme_scale, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # The following code is modified from nbest.intersect() - word_fsa = k2.invert(nbest.fsa) - if hasattr(lattice, "aux_labels"): - # delete token IDs as it is not needed - del word_fsa.aux_labels - word_fsa.scores.zero_() - word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) - path_to_utt_map = nbest.shape.row_ids(1) - - if hasattr(lattice, "aux_labels"): - # lattice has token IDs as labels and word IDs as aux_labels. - # inv_lattice has word IDs as labels and token IDs as aux_labels - inv_lattice = k2.invert(lattice) - inv_lattice = k2.arc_sort(inv_lattice) - else: - inv_lattice = k2.arc_sort(lattice) - - if inv_lattice.shape[0] == 1: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=torch.zeros_like(path_to_utt_map), - sorted_match_a=True, - ) - else: - path_lattice = k2.intersect_device( - inv_lattice, - word_fsa_with_epsilon_loops, - b_to_a_map=path_to_utt_map, - sorted_match_a=True, - ) - - # path_lattice has word IDs as labels and token IDs as aux_labels - path_lattice = k2.top_sort(k2.connect(path_lattice)) - tot_scores = path_lattice.get_tot_scores( - use_double_scores=use_double_scores, - log_semiring=True, # Note: we always use True - ) - # See https://github.com/k2-fsa/icefall/pull/420 for why - # we always use log_semiring=True - - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - best_hyp_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, best_hyp_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - nbest_scale: float = 0.5, - use_double_scores: bool = True, - temperature: float = 1.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, - allow_partial: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - The process to get the results is: - - (1) Use fast beam search to get a lattice - - (2) Select `num_paths` paths from the lattice using k2.random_paths() - - (3) Unique the selected paths - - (4) Intersect the selected paths with the lattice and compute the - shortest path from the intersection result - - (5) The path with the largest score is used as the decoding output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - use_double_scores: - True to use double precision for computation. False to use - single precision. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - blank_penalty=blank_penalty, - temperature=temperature, - allow_partial=allow_partial, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - max_indexes = nbest.tot_scores().argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search_nbest_oracle( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - num_paths: int, - ref_texts: List[List[int]], - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, - allow_partial: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - A lattice is first obtained using fast beam search, and then - we select `num_paths` linear paths from the lattice. The path - that has the minimum edit distance with the given reference transcript - is used as the output. - - This is the best result we can achieve for any nbest based rescoring - methods. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - num_paths: - Number of paths to extract from the decoded lattice. - ref_texts: - A list-of-list of integers containing the reference transcripts. - If the decoding_graph is a trivial_graph, the integer ID is the - BPE token ID. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - allow_partial=allow_partial, - blank_penalty=blank_penalty, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - - hyps = nbest.build_levenshtein_graphs() - refs = k2.levenshtein_graph(ref_texts, device=hyps.device) - - levenshtein_alignment = k2.levenshtein_alignment( - refs=refs, - hyps=hyps, - hyp_to_ref_map=nbest.shape.row_ids(1), - sorted_match_ref=True, - ) - - tot_scores = levenshtein_alignment.get_tot_scores( - use_double_scores=False, log_semiring=False - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - - max_indexes = ragged_tot_scores.argmax() - - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - return get_texts(best_path) - else: - return get_texts_with_timestamp(best_path) - - -def fast_beam_search( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - temperature: float = 1.0, - subtract_ilme: bool = False, - ilme_scale: float = 0.1, - allow_partial: bool = False, - blank_penalty: float = 0.0, -) -> k2.Fsa: - """It limits the maximum number of symbols per frame to 1. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi.. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - temperature: - Softmax temperature. - Returns: - Return an FsaVec with axes [utt][state][arc] containing the decoded - lattice. Note: When the input graph is a TrivialGraph, the returned - lattice is actually an acceptor. - """ - assert encoder_out.ndim == 3 - - context_size = model.decoder.context_size - vocab_size = model.decoder.vocab_size - - B, T, C = encoder_out.shape - - config = k2.RnntDecodingConfig( - vocab_size=vocab_size, - decoder_history_len=context_size, - beam=beam, - max_contexts=max_contexts, - max_states=max_states, - ) - individual_streams = [] - for i in range(B): - individual_streams.append(k2.RnntDecodingStream(decoding_graph)) - decoding_streams = k2.RnntDecodingStreams(individual_streams, config) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # shape is a RaggedShape of shape (B, context) - # contexts is a Tensor of shape (shape.NumElements(), context_size) - shape, contexts = decoding_streams.get_contexts() - # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 - contexts = contexts.to(torch.int64) - # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) - decoder_out = model.decoder(contexts, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # current_encoder_out is of shape - # (shape.NumElements(), 1, joiner_dim) - # fmt: off - current_encoder_out = torch.index_select( - encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) - ) - # fmt: on - logits = model.joiner( - current_encoder_out.unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - logits = logits.squeeze(1).squeeze(1) - - if blank_penalty != 0: - logits[:, 0] -= blank_penalty - - log_probs = (logits / temperature).log_softmax(dim=-1) - - if ilme_scale != 0: - ilme_logits = model.joiner( - torch.zeros_like( - current_encoder_out, device=current_encoder_out.device - ).unsqueeze(2), - decoder_out.unsqueeze(1), - project_input=False, - ) - ilme_logits = ilme_logits.squeeze(1).squeeze(1) - if blank_penalty != 0: - ilme_logits[:, 0] -= blank_penalty - ilme_log_probs = (ilme_logits / temperature).log_softmax(dim=-1) - log_probs -= ilme_scale * ilme_log_probs - - decoding_streams.advance(log_probs) - decoding_streams.terminate_and_flush_to_streams() - lattice = decoding_streams.format_output( - encoder_out_lens.tolist(), allow_partial=allow_partial - ) - - return lattice - - -def greedy_search( - model: nn.Module, - encoder_out: torch.Tensor, - max_sym_per_frame: int, - blank_penalty: float = 0.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """Greedy search for a single utterance. - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - max_sym_per_frame: - Maximum number of symbols per frame. If it is set to 0, the WER - would be 100%. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - - blank_id = model.decoder.blank_id - context_size = model.decoder.context_size - unk_id = getattr(model, "unk_id", blank_id) - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [-1] * (context_size - 1) + [blank_id], device=device, dtype=torch.int64 - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - hyp = [blank_id] * context_size - - # timestamp[i] is the frame index after subsampling - # on which hyp[i] is decoded - timestamp = [] - - # Maximum symbols per utterance. - max_sym_per_utt = 1000 - - # symbols per frame - sym_per_frame = 0 - - # symbols per utterance decoded so far - sym_per_utt = 0 - - while t < T and sym_per_utt < max_sym_per_utt: - if sym_per_frame >= max_sym_per_frame: - sym_per_frame = 0 - t += 1 - continue - - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits is (1, 1, 1, vocab_size) - - if blank_penalty != 0: - logits[:, :, :, 0] -= blank_penalty - - y = logits.argmax().item() - if y not in (blank_id, unk_id): - hyp.append(y) - timestamp.append(t) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sym_per_utt += 1 - sym_per_frame += 1 - else: - sym_per_frame = 0 - t += 1 - hyp = hyp[context_size:] # remove blanks - - if not return_timestamps: - return hyp - else: - return DecodingResults( - hyps=[hyp], - timestamps=[timestamp], - ) - - -def greedy_search_batch( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - blank_penalty: float = 0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C), where N >= 1. - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = next(model.parameters()).device - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] - - # timestamp[n][i] is the frame index after subsampling - # on which hyp[n][i] is decoded - timestamps = [[] for _ in range(N)] - # scores[n][i] is the logits on which hyp[n][i] is decoded - scores = [[] for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out: (N, 1, decoder_out_dim) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = model.joiner( - current_encoder_out, decoder_out.unsqueeze(1), project_input=False - ) - # logits'shape (batch_size, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) - assert logits.ndim == 2, logits.shape - - if blank_penalty != 0: - logits[:, 0] -= blank_penalty - - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v not in (blank_id, unk_id): - hyps[i].append(v) - timestamps[i].append(t) - scores[i].append(logits[i, v].item()) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - ans_timestamps = [] - ans_scores = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(timestamps[unsorted_indices[i]]) - ans_scores.append(scores[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - scores=ans_scores, - ) - - -@dataclass -class Hypothesis: - # The predicted tokens so far. - # Newly predicted tokens are appended to `ys`. - ys: List[int] - - # The log prob of ys. - # It contains only one entry. - log_prob: torch.Tensor - - # timestamp[i] is the frame index after subsampling - # on which ys[i] is decoded - timestamp: List[int] = field(default_factory=list) - - # the lm score for next token given the current ys - lm_score: Optional[torch.Tensor] = None - - # the RNNLM states (h and c in LSTM) - state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None - - # N-gram LM state - state_cost: Optional[NgramLmStateCost] = None - - # Context graph state - context_state: Optional[ContextState] = None - - @property - def key(self) -> str: - """Return a string representation of self.ys""" - return "_".join(map(str, self.ys)) - - -class HypothesisList(object): - def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: - """ - Args: - data: - A dict of Hypotheses. Its key is its `value.key`. - """ - if data is None: - self._data = {} - else: - self._data = data - - @property - def data(self) -> Dict[str, Hypothesis]: - return self._data - - def add(self, hyp: Hypothesis) -> None: - """Add a Hypothesis to `self`. - - If `hyp` already exists in `self`, its probability is updated using - `log-sum-exp` with the existed one. - - Args: - hyp: - The hypothesis to be added. - """ - key = hyp.key - if key in self: - old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) - else: - self._data[key] = hyp - - def get_most_probable(self, length_norm: bool = False) -> Hypothesis: - """Get the most probable hypothesis, i.e., the one with - the largest `log_prob`. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - Returns: - Return the hypothesis that has the largest `log_prob`. - """ - if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) - else: - return max(self._data.values(), key=lambda hyp: hyp.log_prob) - - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": - """Return the top-k hypothesis. - - Args: - length_norm: - If True, the `log_prob` of a hypothesis is normalized by the - number of tokens in it. - """ - hyps = list(self._data.items()) - - if length_norm: - hyps = sorted( - hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True - )[:k] - else: - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] - - ans = HypothesisList(dict(hyps)) - return ans - - def __contains__(self, key: str): - return key in self._data - - def __iter__(self): - return iter(self._data.values()) - - def __len__(self) -> int: - return len(self._data) - - def __str__(self) -> str: - s = [] - for key in self: - s.append(key) - return ", ".join(s) - - -def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: - """Return a ragged shape with axes [utt][num_hyps]. - - Args: - hyps: - len(hyps) == batch_size. It contains the current hypothesis for - each utterance in the batch. - Returns: - Return a ragged shape with 2 axes [utt][num_hyps]. Note that - the shape is on CPU. - """ - num_hyps = [len(h) for h in hyps] - - # torch.cumsum() is inclusive sum, so we put a 0 at the beginning - # to get exclusive sum later. - num_hyps.insert(0, 0) - - num_hyps = torch.tensor(num_hyps) - row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) - ans = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=row_splits[-1].item() - ) - return ans - - -def modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - context_graph: Optional[ContextGraph] = None, - beam: int = 4, - temperature: float = 1.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - context_state=None if context_graph is None else context_graph.root, - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - if blank_penalty != 0: - logits[:, 0] -= blank_penalty - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - context_score = 0 - new_context_state = None if context_graph is None else hyp.context_state - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - if context_graph is not None: - ( - context_score, - new_context_state, - ) = context_graph.forward_one_step(hyp.context_state, new_token) - - new_log_prob = topk_log_probs[k] + context_score - - new_hyp = Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - context_state=new_context_state, - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # finalize context_state, if the matched contexts do not reach final state - # we need to add the score on the corresponding backoff arc - if context_graph is not None: - finalized_B = [HypothesisList() for _ in range(len(B))] - for i, hyps in enumerate(B): - for hyp in list(hyps): - context_score, new_context_state = context_graph.finalize( - hyp.context_state - ) - finalized_B[i].add( - Hypothesis( - ys=hyp.ys, - log_prob=hyp.log_prob + context_score, - timestamp=hyp.timestamp, - context_state=new_context_state, - ) - ) - B = finalized_B - - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) - - -def modified_beam_search_lm_rescore( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LM: LmScorer, - lm_scale_list: List[int], - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - Rescore the final results with RNNLM and return the one with the highest score - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - LM: - A neural network language model - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # get the am_scores for n-best list - hyps_shape = get_hyps_shape(B) - am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) - am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) - - # now LM rescore - # prepare input data to LM - candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] - possible_seqs = k2.RaggedTensor(candidate_seqs) - row_splits = possible_seqs.shape.row_splits(1) - sentence_token_lengths = row_splits[1:] - row_splits[:-1] - possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) - possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) - sentence_token_lengths += 1 - - x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) - y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) - x = x.to(device).to(torch.int64) - y = y.to(device).to(torch.int64) - sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) - - lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) - assert lm_scores.ndim == 2 - lm_scores = -1 * lm_scores.sum(dim=1) - - ans = {} - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - - # get the best hyp with different lm_scale - for lm_scale in lm_scale_list: - key = f"nnlm_scale_{lm_scale:.2f}" - tot_scores = am_scores.values + lm_scores * lm_scale - ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) - max_indexes = ragged_tot_scores.argmax().tolist() - unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] - hyps = [] - for idx in unsorted_indices: - hyps.append(unsorted_hyps[idx]) - - ans[key] = hyps - return ans - - -def modified_beam_search_lm_rescore_LODR( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LM: LmScorer, - LODR_lm: NgramLm, - sp: spm.SentencePieceProcessor, - lm_scale_list: List[int], - beam: int = 4, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Union[List[List[int]], DecodingResults]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - Rescore the final results with RNNLM and return the one with the highest score - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - LM: - A neural network language model - return_timestamps: - Whether to return timestamps. - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - - new_log_prob = topk_log_probs[k] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # get the am_scores for n-best list - hyps_shape = get_hyps_shape(B) - am_scores = torch.tensor([hyp.log_prob.item() for b in B for hyp in b]) - am_scores = k2.RaggedTensor(value=am_scores, shape=hyps_shape).to(device) - - # now LM rescore - # prepare input data to LM - candidate_seqs = [hyp.ys[context_size:] for b in B for hyp in b] - possible_seqs = k2.RaggedTensor(candidate_seqs) - row_splits = possible_seqs.shape.row_splits(1) - sentence_token_lengths = row_splits[1:] - row_splits[:-1] - possible_seqs_with_sos = add_sos(possible_seqs, sos_id=1) - possible_seqs_with_eos = add_eos(possible_seqs, eos_id=1) - sentence_token_lengths += 1 - - x = possible_seqs_with_sos.pad(mode="constant", padding_value=blank_id) - y = possible_seqs_with_eos.pad(mode="constant", padding_value=blank_id) - x = x.to(device).to(torch.int64) - y = y.to(device).to(torch.int64) - sentence_token_lengths = sentence_token_lengths.to(device).to(torch.int64) - - lm_scores = LM.lm(x=x, y=y, lengths=sentence_token_lengths) - assert lm_scores.ndim == 2 - lm_scores = -1 * lm_scores.sum(dim=1) - - # now LODR scores - import math - - LODR_scores = [] - for seq in candidate_seqs: - tokens = " ".join(sp.id_to_piece(seq)) - LODR_scores.append(LODR_lm.score(tokens)) - LODR_scores = torch.tensor(LODR_scores).to(device) * math.log( - 10 - ) # arpa scores are 10-based - assert lm_scores.shape == LODR_scores.shape - - ans = {} - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - - LODR_scale_list = [0.05 * i for i in range(1, 20)] - # get the best hyp with different lm_scale and lodr_scale - for lm_scale in lm_scale_list: - for lodr_scale in LODR_scale_list: - key = f"nnlm_scale_{lm_scale:.2f}_lodr_scale_{lodr_scale:.2f}" - tot_scores = ( - am_scores.values / lm_scale + lm_scores - LODR_scores * lodr_scale - ) - ragged_tot_scores = k2.RaggedTensor(shape=am_scores.shape, value=tot_scores) - max_indexes = ragged_tot_scores.argmax().tolist() - unsorted_hyps = [candidate_seqs[idx] for idx in max_indexes] - hyps = [] - for idx in unsorted_indices: - hyps.append(unsorted_hyps[idx]) - - ans[key] = hyps - return ans - - -def _deprecated_modified_beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - beam: int = 4, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """It limits the maximum number of symbols per frame to 1. - - It decodes only one utterance at a time. We keep it only for reference. - The function :func:`modified_beam_search` should be preferred as it - supports batch decoding. - - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - T = encoder_out.size(1) - - B = HypothesisList() - B.add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - timestamp=[], - ) - ) - encoder_out = model.joiner.encoder_proj(encoder_out) - - for t in range(T): - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # current_encoder_out is of shape (1, 1, 1, encoder_out_dim) - # fmt: on - A = list(B) - B = HypothesisList() - - ys_log_probs = torch.cat([hyp.log_prob.reshape(1, 1) for hyp in A]) - # ys_log_probs is of shape (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyp in A], - device=device, - dtype=torch.int64, - ) - # decoder_input is of shape (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_output is of shape (num_hyps, 1, 1, joiner_dim) - - current_encoder_out = current_encoder_out.expand( - decoder_out.size(0), 1, 1, -1 - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) - # logits is of shape (num_hyps, 1, 1, vocab_size) - logits = logits.squeeze(1).squeeze(1) - - # now logits is of shape (num_hyps, vocab_size) - log_probs = logits.log_softmax(dim=-1) - - log_probs.add_(ys_log_probs) - - log_probs = log_probs.reshape(-1) - topk_log_probs, topk_indexes = log_probs.topk(beam) - - # topk_hyp_indexes are indexes into `A` - topk_hyp_indexes = topk_indexes // logits.size(-1) - topk_token_indexes = topk_indexes % logits.size(-1) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = topk_hyp_indexes.tolist() - topk_token_indexes = topk_token_indexes.tolist() - - for i in range(len(topk_hyp_indexes)): - hyp = A[topk_hyp_indexes[i]] - new_ys = hyp.ys[:] - new_timestamp = hyp.timestamp[:] - new_token = topk_token_indexes[i] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - new_timestamp.append(t) - new_log_prob = topk_log_probs[i] - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp - ) - B.add(new_hyp) - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -def beam_search( - model: nn.Module, - encoder_out: torch.Tensor, - beam: int = 4, - temperature: float = 1.0, - blank_penalty: float = 0.0, - return_timestamps: bool = False, -) -> Union[List[int], DecodingResults]: - """ - It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf - - espnet/nets/beam_search_transducer.py#L247 is used as a reference. - - Args: - model: - An instance of `Transducer`. - encoder_out: - A tensor of shape (N, T, C) from the encoder. Support only N==1 for now. - beam: - Beam size. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - - Returns: - If return_timestamps is False, return the decoded result. - Else, return a DecodingResults object containing - decoded result and corresponding timestamps. - """ - assert encoder_out.ndim == 3 - - # support only batch_size == 1 for now - assert encoder_out.size(0) == 1, encoder_out.size(0) - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - - device = next(model.parameters()).device - - decoder_input = torch.tensor( - [blank_id] * context_size, - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - T = encoder_out.size(1) - t = 0 - - B = HypothesisList() - B.add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], log_prob=0.0, timestamp=[] - ) - ) - - max_sym_per_utt = 20000 - - sym_per_utt = 0 - - decoder_cache: Dict[str, torch.Tensor] = {} - - while t < T and sym_per_utt < max_sym_per_utt: - # fmt: off - current_encoder_out = encoder_out[:, t:t+1, :].unsqueeze(2) - # fmt: on - A = B - B = HypothesisList() - - joint_cache: Dict[str, torch.Tensor] = {} - - # TODO(fangjun): Implement prefix search to update the `log_prob` - # of hypotheses in A - - while True: - y_star = A.get_most_probable() - A.remove(y_star) - - cached_key = y_star.key - - if cached_key not in decoder_cache: - decoder_input = torch.tensor( - [y_star.ys[-context_size:]], - device=device, - dtype=torch.int64, - ).reshape(1, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False) - decoder_out = model.joiner.decoder_proj(decoder_out) - decoder_cache[cached_key] = decoder_out - else: - decoder_out = decoder_cache[cached_key] - - cached_key += f"-t-{t}" - if cached_key not in joint_cache: - logits = model.joiner( - current_encoder_out, - decoder_out.unsqueeze(1), - project_input=False, - ) - - if blank_penalty != 0: - logits[:, :, :, 0] -= blank_penalty - - # TODO(fangjun): Scale the blank posterior - log_prob = (logits / temperature).log_softmax(dim=-1) - # log_prob is (1, 1, 1, vocab_size) - log_prob = log_prob.squeeze() - # Now log_prob is (vocab_size,) - joint_cache[cached_key] = log_prob - else: - log_prob = joint_cache[cached_key] - - # First, process the blank symbol - skip_log_prob = log_prob[blank_id] - new_y_star_log_prob = y_star.log_prob + skip_log_prob - - # ys[:] returns a copy of ys - B.add( - Hypothesis( - ys=y_star.ys[:], - log_prob=new_y_star_log_prob, - timestamp=y_star.timestamp[:], - ) - ) - - # Second, process other non-blank labels - values, indices = log_prob.topk(beam + 1) - for i, v in zip(indices.tolist(), values.tolist()): - if i in (blank_id, unk_id): - continue - new_ys = y_star.ys + [i] - new_log_prob = y_star.log_prob + v - new_timestamp = y_star.timestamp + [t] - A.add( - Hypothesis( - ys=new_ys, - log_prob=new_log_prob, - timestamp=new_timestamp, - ) - ) - - # Check whether B contains more than "beam" elements more probable - # than the most probable in A - A_most_probable = A.get_most_probable() - - kept_B = B.filter(A_most_probable.log_prob) - - if len(kept_B) >= beam: - B = kept_B.topk(beam) - break - - t += 1 - - best_hyp = B.get_most_probable(length_norm=True) - ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks - - if not return_timestamps: - return ys - else: - return DecodingResults(hyps=[ys], timestamps=[best_hyp.timestamp]) - - -def fast_beam_search_with_nbest_rescoring( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model. The shortest path within the - lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - ans: Dict[str, Union[List[List[int]], DecodingResults]] = {} - for s in ngram_lm_scale_list: - key = f"ngram_lm_scale_{s}" - tot_scores = am_scores.values + s * ngram_lm_scores - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) - - return ans - - -def fast_beam_search_with_nbest_rnn_rescoring( - model: nn.Module, - decoding_graph: k2.Fsa, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - beam: float, - max_states: int, - max_contexts: int, - ngram_lm_scale_list: List[float], - num_paths: int, - G: k2.Fsa, - sp: spm.SentencePieceProcessor, - word_table: k2.SymbolTable, - rnn_lm_model: torch.nn.Module, - rnn_lm_scale_list: List[float], - oov_word: str = "", - use_double_scores: bool = True, - nbest_scale: float = 0.5, - temperature: float = 1.0, - return_timestamps: bool = False, -) -> Dict[str, Union[List[List[int]], DecodingResults]]: - """It limits the maximum number of symbols per frame to 1. - A lattice is first obtained using fast beam search, num_path are selected - and rescored using a given language model and a rnn-lm. - The shortest path within the lattice is used as the final output. - - Args: - model: - An instance of `Transducer`. - decoding_graph: - Decoding graph used for decoding, may be a TrivialGraph or a LG. - encoder_out: - A tensor of shape (N, T, C) from the encoder. - encoder_out_lens: - A tensor of shape (N,) containing the number of frames in `encoder_out` - before padding. - beam: - Beam value, similar to the beam used in Kaldi. - max_states: - Max states per stream per frame. - max_contexts: - Max contexts pre stream per frame. - ngram_lm_scale_list: - A list of floats representing LM score scales. - num_paths: - Number of paths to extract from the decoded lattice. - G: - An FsaVec containing only a single FSA. It is an n-gram LM. - sp: - The BPE model. - word_table: - The word symbol table. - rnn_lm_model: - A rnn-lm model used for LM rescoring - rnn_lm_scale_list: - A list of floats representing RNN score scales. - oov_word: - OOV words are replaced with this word. - use_double_scores: - True to use double precision for computation. False to use - single precision. - nbest_scale: - It's the scale applied to the lattice.scores. A smaller value - yields more unique paths. - temperature: - Softmax temperature. - return_timestamps: - Whether to return timestamps. - Returns: - Return the decoded result in a dict, where the key has the form - 'ngram_lm_scale_xx' and the value is the decoded results - optionally with timestamps. `xx` is the ngram LM scale value - used during decoding, i.e., 0.1. - """ - lattice = fast_beam_search( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=beam, - max_states=max_states, - max_contexts=max_contexts, - temperature=temperature, - ) - - nbest = Nbest.from_lattice( - lattice=lattice, - num_paths=num_paths, - use_double_scores=use_double_scores, - nbest_scale=nbest_scale, - ) - # at this point, nbest.fsa.scores are all zeros. - - nbest = nbest.intersect(lattice) - # Now nbest.fsa.scores contains acoustic scores - - am_scores = nbest.tot_scores() - - # Now we need to compute the LM scores of each path. - # (1) Get the token IDs of each Path. We assume the decoding_graph - # is an acceptor, i.e., lattice is also an acceptor - tokens_shape = nbest.fsa.arcs.shape().remove_axis(1) # [path][arc] - - tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.labels.contiguous()) - tokens = tokens.remove_values_leq(0) # remove -1 and 0 - - token_list: List[List[int]] = tokens.tolist() - word_list: List[List[str]] = sp.decode(token_list) - - assert isinstance(oov_word, str), oov_word - assert oov_word in word_table, oov_word - oov_word_id = word_table[oov_word] - - word_ids_list: List[List[int]] = [] - - for words in word_list: - this_word_ids = [] - for w in words.split(): - if w in word_table: - this_word_ids.append(word_table[w]) - else: - this_word_ids.append(oov_word_id) - word_ids_list.append(this_word_ids) - - word_fsas = k2.linear_fsa(word_ids_list, device=lattice.device) - word_fsas_with_self_loops = k2.add_epsilon_self_loops(word_fsas) - - num_unique_paths = len(word_ids_list) - - b_to_a_map = torch.zeros( - num_unique_paths, - dtype=torch.int32, - device=lattice.device, - ) - - rescored_word_fsas = k2.intersect_device( - a_fsas=G, - b_fsas=word_fsas_with_self_loops, - b_to_a_map=b_to_a_map, - sorted_match_a=True, - ret_arc_maps=False, - ) - - rescored_word_fsas = k2.remove_epsilon_self_loops(rescored_word_fsas) - rescored_word_fsas = k2.top_sort(k2.connect(rescored_word_fsas)) - ngram_lm_scores = rescored_word_fsas.get_tot_scores( - use_double_scores=True, - log_semiring=False, - ) - - # Now RNN-LM - blank_id = model.decoder.blank_id - sos_id = sp.piece_to_id("sos_id") - eos_id = sp.piece_to_id("eos_id") - - sos_tokens = add_sos(tokens, sos_id) - tokens_eos = add_eos(tokens, eos_id) - sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] - - x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) - y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) - - x_tokens = x_tokens.to(torch.int64) - y_tokens = y_tokens.to(torch.int64) - sentence_lengths = sentence_lengths.to(torch.int64) - - rnn_lm_nll = rnn_lm_model(x=x_tokens, y=y_tokens, lengths=sentence_lengths) - assert rnn_lm_nll.ndim == 2 - assert rnn_lm_nll.shape[0] == len(token_list) - rnn_lm_scores = -1 * rnn_lm_nll.sum(dim=1) - - ans: Dict[str, List[List[int]]] = {} - for n_scale in ngram_lm_scale_list: - for rnn_scale in rnn_lm_scale_list: - key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" - tot_scores = ( - am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores - ) - ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) - max_indexes = ragged_tot_scores.argmax() - best_path = k2.index_fsa(nbest.fsa, max_indexes) - - if not return_timestamps: - ans[key] = get_texts(best_path) - else: - ans[key] = get_texts_with_timestamp(best_path) - - return ans - - -def modified_beam_search_ngram_rescoring( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - ngram_lm: NgramLm, - ngram_lm_scale: float, - beam: int = 4, - temperature: float = 1.0, -) -> List[List[int]]: - """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. - - Args: - model: - The transducer model. - encoder_out: - Output from the encoder. Its shape is (N, T, C). - encoder_out_lens: - A 1-D tensor of shape (N,), containing number of valid frames in - encoder_out before padding. - beam: - Number of active paths during the beam search. - temperature: - Softmax temperature. - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - lm_scale = ngram_lm_scale - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state_cost=NgramLmStateCost(ngram_lm), - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [ - hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale - for hyps in A - for hyp in hyps - ] - ) # (num_hyps, 1) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) - - # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor - # as index, so we use `to(torch.int64)` below. - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - vocab_size = log_probs.size(-1) - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_ys = hyp.ys[:] - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - new_ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - else: - state_cost = hyp.state_cost - - # We only keep AM scores in new_hyp.log_prob - new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale - - new_hyp = Hypothesis( - ys=new_ys, log_prob=new_log_prob, state_cost=state_cost - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def modified_beam_search_LODR( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LODR_lm: NgramLm, - LODR_lm_scale: float, - LM: LmScorer, - beam: int = 4, - context_graph: Optional[ContextGraph] = None, -) -> List[List[int]]: - """This function implements LODR (https://arxiv.org/abs/2203.16776) with - `modified_beam_search`. It uses a bi-gram language model as the estimate - of the internal language model and subtracts its score during shallow fusion - with an external language model. This implementation uses a RNNLM as the - external language model. - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - LODR_lm: - A low order n-gram LM, whose score will be subtracted during shallow fusion - LODR_lm_scale: - The scale of the LODR_lm - LM: - A neural net LM, e.g an RNNLM or transformer LM - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert LM is not None - lm_scale = LM.lm_scale - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = getattr(LM, "sos_id", 1) - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - lens = torch.tensor([1]).to(device) - init_score, init_states = LM.score_token(sos_token, lens) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, # state of the NN LM - lm_score=init_score.reshape(-1), - state_cost=NgramLmStateCost( - LODR_lm - ), # state of the source domain ngram - context_state=None if context_graph is None else context_graph.root, - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - LM will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - if LM.lm_type == "rnn": - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - else: - # for transformer LM - token_list.append( - [sos_id] + hyp.ys[context_size:] + [new_token] - ) - - # forward NN LM to get new states and scores - if len(token_list) != 0: - x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) - if LM.lm_type == "rnn": - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - state = (hs, cs) - else: - # for transformer LM - tokens_list = [torch.tensor(tokens) for tokens in token_list] - tokens_to_score = ( - torch.nn.utils.rnn.pad_sequence( - tokens_list, batch_first=True, padding_value=0.0 - ) - .to(device) - .to(torch.int64) - ) - - state = None - - scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - # current score of hyp - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - - context_score = 0 - new_context_state = None if context_graph is None else hyp.context_state - if new_token not in (blank_id, unk_id): - if context_graph is not None: - ( - context_score, - new_context_state, - ) = context_graph.forward_one_step(hyp.context_state, new_token) - - ys.append(new_token) - state_cost = hyp.state_cost.forward_one_step(new_token) - - # calculate the score of the latest token - current_ngram_score = state_cost.lm_score - hyp.state_cost.lm_score - - assert current_ngram_score <= 0.0, ( - state_cost.lm_score, - hyp.state_cost.lm_score, - ) - # score = score + TDLM_score - LODR_score - # LODR_LM_scale should be a negative number here - hyp_log_prob += ( - lm_score[new_token] * lm_scale - + LODR_lm_scale * current_ngram_score - + context_score - ) # add the lm score - - lm_score = scores[count] - if LM.lm_type == "rnn": - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - else: - state_cost = hyp.state_cost - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - state_cost=state_cost, - context_state=new_context_state, - ) - B[i].add(new_hyp) - - B = B + finalized_B - - # finalize context_state, if the matched contexts do not reach final state - # we need to add the score on the corresponding backoff arc - if context_graph is not None: - finalized_B = [HypothesisList() for _ in range(len(B))] - for i, hyps in enumerate(B): - for hyp in list(hyps): - context_score, new_context_state = context_graph.finalize( - hyp.context_state - ) - finalized_B[i].add( - Hypothesis( - ys=hyp.ys, - log_prob=hyp.log_prob + context_score, - timestamp=hyp.timestamp, - context_state=new_context_state, - ) - ) - B = finalized_B - - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -def modified_beam_search_lm_shallow_fusion( - model: nn.Module, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - LM: LmScorer, - beam: int = 4, - return_timestamps: bool = False, -) -> List[List[int]]: - """Modified_beam_search + NN LM shallow fusion - - Args: - model (Transducer): - The transducer model - encoder_out (torch.Tensor): - Encoder output in (N,T,C) - encoder_out_lens (torch.Tensor): - A 1-D tensor of shape (N,), containing the number of - valid frames in encoder_out before padding. - sp: - Sentence piece generator. - LM (LmScorer): - A neural net LM, e.g RNN or Transformer - beam (int, optional): - Beam size. Defaults to 4. - - Returns: - Return a list-of-list of token IDs. ans[i] is the decoding results - for the i-th utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - assert LM is not None - lm_scale = LM.lm_scale - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = model.decoder.blank_id - sos_id = getattr(LM, "sos_id", 1) - unk_id = getattr(model, "unk_id", blank_id) - context_size = model.decoder.context_size - device = next(model.parameters()).device - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - # get initial lm score and lm state by scoring the "sos" token - sos_token = torch.tensor([[sos_id]]).to(torch.int64).to(device) - lens = torch.tensor([1]).to(device) - init_score, init_states = LM.score_token(sos_token, lens) - - B = [HypothesisList() for _ in range(N)] - for i in range(N): - B[i].add( - Hypothesis( - ys=[-1] * (context_size - 1) + [blank_id], - log_prob=torch.zeros(1, dtype=torch.float32, device=device), - state=init_states, - lm_score=init_score.reshape(-1), - timestamp=[], - ) - ) - - encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) - - offset = 0 - finalized_B = [] - for t, batch_size in enumerate(batch_size_list): - start = offset - end = offset + batch_size - current_encoder_out = encoder_out.data[start:end] # get batch - current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) - # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) - offset = end - - finalized_B = B[batch_size:] + finalized_B - B = B[:batch_size] - - hyps_shape = get_hyps_shape(B).to(device) - - A = [list(b) for b in B] - B = [HypothesisList() for _ in range(batch_size)] - - ys_log_probs = torch.cat( - [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] - ) - - lm_scores = torch.cat( - [hyp.lm_score.reshape(1, -1) for hyps in A for hyp in hyps] - ) - - decoder_input = torch.tensor( - [hyp.ys[-context_size:] for hyps in A for hyp in hyps], - device=device, - dtype=torch.int64, - ) # (num_hyps, context_size) - - decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) - decoder_out = model.joiner.decoder_proj(decoder_out) - - current_encoder_out = torch.index_select( - current_encoder_out, - dim=0, - index=hyps_shape.row_ids(1).to(torch.int64), - ) # (num_hyps, 1, 1, encoder_out_dim) - - logits = model.joiner( - current_encoder_out, - decoder_out, - project_input=False, - ) # (num_hyps, 1, 1, vocab_size) - - logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - - log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) - - log_probs.add_(ys_log_probs) - - vocab_size = log_probs.size(-1) - - log_probs = log_probs.reshape(-1) - - row_splits = hyps_shape.row_splits(1) * vocab_size - log_probs_shape = k2.ragged.create_ragged_shape2( - row_splits=row_splits, cached_tot_size=log_probs.numel() - ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) - """ - for all hyps with a non-blank new token, score this token. - It is a little confusing here because this for-loop - looks very similar to the one below. Here, we go through all - top-k tokens and only add the non-blanks ones to the token_list. - `LM` will score those tokens given the LM states. Note that - the variable `scores` is the LM score after seeing the new - non-blank token. - """ - token_list = [] # a list of list - hs = [] - cs = [] - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - new_token = topk_token_indexes[k] - if new_token not in (blank_id, unk_id): - if LM.lm_type == "rnn": - token_list.append([new_token]) - # store the LSTM states - hs.append(hyp.state[0]) - cs.append(hyp.state[1]) - else: - # for transformer LM - token_list.append( - [sos_id] + hyp.ys[context_size:] + [new_token] - ) - - if len(token_list) != 0: - x_lens = torch.tensor([len(tokens) for tokens in token_list]).to(device) - if LM.lm_type == "rnn": - tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) - ) - hs = torch.cat(hs, dim=1).to(device) - cs = torch.cat(cs, dim=1).to(device) - state = (hs, cs) - else: - # for transformer LM - tokens_list = [torch.tensor(tokens) for tokens in token_list] - tokens_to_score = ( - torch.nn.utils.rnn.pad_sequence( - tokens_list, batch_first=True, padding_value=0.0 - ) - .to(device) - .to(torch.int64) - ) - - state = None - - scores, lm_states = LM.score_token(tokens_to_score, x_lens, state) - - count = 0 # index, used to locate score and lm states - for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) - - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - topk_hyp_indexes = (topk_indexes // vocab_size).tolist() - topk_token_indexes = (topk_indexes % vocab_size).tolist() - - for k in range(len(topk_hyp_indexes)): - hyp_idx = topk_hyp_indexes[k] - hyp = A[i][hyp_idx] - - ys = hyp.ys[:] - - lm_score = hyp.lm_score - state = hyp.state - - hyp_log_prob = topk_log_probs[k] # get score of current hyp - new_token = topk_token_indexes[k] - new_timestamp = hyp.timestamp[:] - if new_token not in (blank_id, unk_id): - ys.append(new_token) - new_timestamp.append(t) - - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score - - lm_score = scores[count] - if LM.lm_type == "rnn": - state = ( - lm_states[0][:, count, :].unsqueeze(1), - lm_states[1][:, count, :].unsqueeze(1), - ) - count += 1 - - new_hyp = Hypothesis( - ys=ys, - log_prob=hyp_log_prob, - state=state, - lm_score=lm_score, - timestamp=new_timestamp, - ) - B[i].add(new_hyp) - - B = B + finalized_B - best_hyps = [b.get_most_probable(length_norm=True) for b in B] - - sorted_ans = [h.ys[context_size:] for h in best_hyps] - sorted_timestamps = [h.timestamp for h in best_hyps] - ans = [] - ans_timestamps = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) - - if not return_timestamps: - return ans - else: - return DecodingResults( - hyps=ans, - timestamps=ans_timestamps, - ) diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/conformer.py deleted file mode 100644 index ab46e233b..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/conformer.py +++ /dev/null @@ -1,1600 +0,0 @@ -#!/usr/bin/env python3 -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import math -import warnings -from typing import List, Optional, Tuple - -import torch -from encoder_interface import EncoderInterface -from scaling import ( - ActivationBalancer, - BasicNorm, - DoubleSwish, - ScaledConv1d, - ScaledConv2d, - ScaledLinear, -) -from torch import Tensor, nn - -from icefall.utils import is_jit_tracing, make_pad_mask, subsequent_chunk_mask - - -class Conformer(EncoderInterface): - """ - Args: - num_features (int): Number of input features - subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension, also the output dimension - nhead (int): number of head - dim_feedforward (int): feedforward dimention - num_encoder_layers (int): number of encoder layers - dropout (float): dropout rate - layer_dropout (float): layer-dropout rate. - cnn_module_kernel (int): Kernel size of convolution module - vgg_frontend (bool): whether to use vgg frontend. - dynamic_chunk_training (bool): whether to use dynamic chunk training, if - you want to train a streaming model, this is expected to be True. - When setting True, it will use a masking strategy to make the attention - see only limited left and right context. - short_chunk_threshold (float): a threshold to determinize the chunk size - to be used in masking training, if the randomly generated chunk size - is greater than ``max_len * short_chunk_threshold`` (max_len is the - max sequence length of current batch) then it will use - full context in training (i.e. with chunk size equals to max_len). - This will be used only when dynamic_chunk_training is True. - short_chunk_size (int): see docs above, if the randomly generated chunk - size equals to or less than ``max_len * short_chunk_threshold``, the - chunk size will be sampled uniformly from 1 to short_chunk_size. - This also will be used only when dynamic_chunk_training is True. - num_left_chunks (int): the left context (in chunks) attention can see, the - chunk size is decided by short_chunk_threshold and short_chunk_size. - A minus value means seeing full left context. - This also will be used only when dynamic_chunk_training is True. - causal (bool): Whether to use causal convolution in conformer encoder - layer. This MUST be True when using dynamic_chunk_training. - """ - - def __init__( - self, - num_features: int, - subsampling_factor: int = 4, - d_model: int = 256, - nhead: int = 4, - dim_feedforward: int = 2048, - num_encoder_layers: int = 12, - dropout: float = 0.1, - layer_dropout: float = 0.075, - cnn_module_kernel: int = 31, - dynamic_chunk_training: bool = False, - short_chunk_threshold: float = 0.75, - short_chunk_size: int = 25, - num_left_chunks: int = -1, - causal: bool = False, - ) -> None: - super(Conformer, self).__init__() - - self.num_features = num_features - self.subsampling_factor = subsampling_factor - if subsampling_factor != 4: - raise NotImplementedError("Support only 'subsampling_factor=4'.") - - # self.encoder_embed converts the input of shape (N, T, num_features) - # to the shape (N, T//subsampling_factor, d_model). - # That is, it does two things simultaneously: - # (1) subsampling: T -> T//subsampling_factor - # (2) embedding: num_features -> d_model - self.encoder_embed = Conv2dSubsampling(num_features, d_model) - - self.encoder_layers = num_encoder_layers - self.d_model = d_model - self.cnn_module_kernel = cnn_module_kernel - self.causal = causal - self.dynamic_chunk_training = dynamic_chunk_training - self.short_chunk_threshold = short_chunk_threshold - self.short_chunk_size = short_chunk_size - self.num_left_chunks = num_left_chunks - - self.encoder_pos = RelPositionalEncoding(d_model, dropout) - - encoder_layer = ConformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, - layer_dropout, - cnn_module_kernel, - causal, - ) - self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers) - self._init_state: List[torch.Tensor] = [torch.empty(0)] - - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0 - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - warmup: - A floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - Returns: - Return a tuple containing 2 tensors: - - embeddings: its shape is (batch_size, output_seq_len, d_model) - - lengths, a tensor of shape (batch_size,) containing the number - of frames in `embeddings` before padding. - """ - x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - # Caution: We assume the subsampling factor is 4! - - # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning - # - # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - lengths = (((x_lens - 1) >> 1) - 1) >> 1 - - if not is_jit_tracing(): - assert x.size(0) == lengths.max().item() - - src_key_padding_mask = make_pad_mask(lengths, x.size(0)) - - if self.dynamic_chunk_training: - assert ( - self.causal - ), "Causal convolution is required for streaming conformer." - max_len = x.size(0) - chunk_size = torch.randint(1, max_len, (1,)).item() - if chunk_size > (max_len * self.short_chunk_threshold): - chunk_size = max_len - else: - chunk_size = chunk_size % self.short_chunk_size + 1 - - mask = ~subsequent_chunk_mask( - size=x.size(0), - chunk_size=chunk_size, - num_left_chunks=self.num_left_chunks, - device=x.device, - ) - x = self.encoder( - x, - pos_emb, - mask=mask, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) # (T, N, C) - else: - x = self.encoder( - x, - pos_emb, - mask=None, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) # (T, N, C) - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - return x, lengths - - @torch.jit.export - def get_init_state( - self, left_context: int, device: torch.device - ) -> List[torch.Tensor]: - """Return the initial cache state of the model. - - Args: - left_context: The left context size (in frames after subsampling). - - Returns: - Return the initial state of the model, it is a list containing two - tensors, the first one is the cache for attentions which has a shape - of (num_encoder_layers, left_context, encoder_dim), the second one - is the cache of conv_modules which has a shape of - (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). - - NOTE: the returned tensors are on the given device. - """ - if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: - # Note: It is OK to share the init state as it is - # not going to be modified by the model - return self._init_state - - init_states: List[torch.Tensor] = [ - torch.zeros( - ( - self.encoder_layers, - left_context, - self.d_model, - ), - device=device, - ), - torch.zeros( - ( - self.encoder_layers, - self.cnn_module_kernel - 1, - self.d_model, - ), - device=device, - ), - ] - - self._init_state = init_states - - return init_states - - @torch.jit.export - def streaming_forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - states: Optional[List[Tensor]] = None, - processed_lens: Optional[Tensor] = None, - left_context: int = 64, - right_context: int = 4, - chunk_size: int = 16, - simulate_streaming: bool = False, - warmup: float = 1.0, - ) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]: - """ - Args: - x: - The input tensor. Its shape is (batch_size, seq_len, feature_dim). - x_lens: - A tensor of shape (batch_size,) containing the number of frames in - `x` before padding. - states: - The decode states for previous frames which contains the cached data. - It has two elements, the first element is the attn_cache which has - a shape of (encoder_layers, left_context, batch, attention_dim), - the second element is the conv_cache which has a shape of - (encoder_layers, cnn_module_kernel-1, batch, conv_dim). - Note: states will be modified in this function. - processed_lens: - How many frames (after subsampling) have been processed for each sequence. - left_context: - How many previous frames the attention can see in current chunk. - Note: It's not that each individual frame has `left_context` frames - of left context, some have more. - right_context: - How many future frames the attention can see in current chunk. - Note: It's not that each individual frame has `right_context` frames - of right context, some have more. - chunk_size: - The chunk size for decoding, this will be used to simulate streaming - decoding using masking. - simulate_streaming: - If setting True, it will use a masking strategy to simulate streaming - fashion (i.e. every chunk data only see limited left context and - right context). The whole sequence is supposed to be send at a time - When using simulate_streaming. - warmup: - A floating point value that gradually increases from 0 throughout - training; when it is >= 1.0 we are "fully warmed up". It is used - to turn modules on sequentially. - Returns: - Return a tuple containing 2 tensors: - - logits, its shape is (batch_size, output_seq_len, output_dim) - - logit_lens, a tensor of shape (batch_size,) containing the number - of frames in `logits` before padding. - - decode_states, the updated states including the information - of current chunk. - """ - - # x: [N, T, C] - # Caution: We assume the subsampling factor is 4! - - # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning - # - # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 - lengths = (((x_lens - 1) >> 1) - 1) >> 1 - - if not simulate_streaming: - assert states is not None - assert processed_lens is not None - assert ( - len(states) == 2 - and states[0].shape - == (self.encoder_layers, left_context, x.size(0), self.d_model) - and states[1].shape - == ( - self.encoder_layers, - self.cnn_module_kernel - 1, - x.size(0), - self.d_model, - ) - ), f"""The length of states MUST be equal to 2, and the shape of - first element should be {(self.encoder_layers, left_context, x.size(0), self.d_model)}, - given {states[0].shape}. the shape of second element should be - {(self.encoder_layers, self.cnn_module_kernel - 1, x.size(0), self.d_model)}, - given {states[1].shape}.""" - - lengths -= 2 # we will cut off 1 frame on each side of encoder_embed output - - src_key_padding_mask = make_pad_mask(lengths) - - processed_mask = torch.arange(left_context, device=x.device).expand( - x.size(0), left_context - ) - processed_lens = processed_lens.view(x.size(0), 1) - processed_mask = (processed_lens <= processed_mask).flip(1) - - src_key_padding_mask = torch.cat( - [processed_mask, src_key_padding_mask], dim=1 - ) - - embed = self.encoder_embed(x) - - # cut off 1 frame on each size of embed as they see the padding - # value which causes a training and decoding mismatch. - embed = embed[:, 1:-1, :] - - embed, pos_enc = self.encoder_pos(embed, left_context) - embed = embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) - - x, states = self.encoder.chunk_forward( - embed, - pos_enc, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - states=states, - left_context=left_context, - right_context=right_context, - ) # (T, B, F) - if right_context > 0: - x = x[0:-right_context, ...] - lengths -= right_context - else: - assert states is None - states = [] # just to make torch.script.jit happy - # this branch simulates streaming decoding using mask as we are - # using in training time. - src_key_padding_mask = make_pad_mask(lengths) - x = self.encoder_embed(x) - x, pos_emb = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) - - assert x.size(0) == lengths.max().item() - - if chunk_size < 0: - # use full attention - chunk_size = x.size(0) - left_context = -1 - - num_left_chunks = -1 - if left_context >= 0: - assert left_context % chunk_size == 0 - num_left_chunks = left_context // chunk_size - - mask = ~subsequent_chunk_mask( - size=x.size(0), - chunk_size=chunk_size, - num_left_chunks=num_left_chunks, - device=x.device, - ) - x = self.encoder( - x, - pos_emb, - mask=mask, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) # (T, N, C) - - x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) - - return x, lengths, states - - -class ConformerEncoderLayer(nn.Module): - """ - ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks. - See: "Conformer: Convolution-augmented Transformer for Speech Recognition" - - Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - cnn_module_kernel (int): Kernel size of convolution module. - causal (bool): Whether to use causal convolution in conformer encoder - layer. This MUST be True when using dynamic_chunk_training and streaming decoding. - - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = encoder_layer(src, pos_emb) - """ - - def __init__( - self, - d_model: int, - nhead: int, - dim_feedforward: int = 2048, - dropout: float = 0.1, - layer_dropout: float = 0.075, - cnn_module_kernel: int = 31, - causal: bool = False, - ) -> None: - super(ConformerEncoderLayer, self).__init__() - - self.layer_dropout = layer_dropout - - self.d_model = d_model - - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) - - self.feed_forward = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), - ) - - self.feed_forward_macaron = nn.Sequential( - ScaledLinear(d_model, dim_feedforward), - ActivationBalancer(channel_dim=-1), - DoubleSwish(), - nn.Dropout(dropout), - ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), - ) - - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) - - self.norm_final = BasicNorm(d_model) - - # try to ensure the output is close to zero-mean (or at least, zero-median). - self.balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 - ) - - self.dropout = nn.Dropout(dropout) - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - src_mask: Optional[Tensor] = None, - warmup: float = 1.0, - ) -> Tensor: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - src_key_padding_mask: the mask for the src keys per batch (optional). - src_mask: the mask for the src sequence (optional). - warmup: controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - src_orig = src - - warmup_scale = min(0.1 + warmup, 1.0) - # alpha = 1.0 means fully use this encoder layer, 0.0 would mean - # completely bypass it. - if self.training: - alpha = ( - warmup_scale - if torch.rand(()).item() <= (1.0 - self.layer_dropout) - else 0.1 - ) - else: - alpha = 1.0 - - # macaron style feed forward module - src = src + self.dropout(self.feed_forward_macaron(src)) - - # multi-headed self-attention module - src_att = self.self_attn( - src, - src, - src, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - )[0] - - src = src + self.dropout(src_att) - - # convolution module - conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) - src = src + self.dropout(conv) - - # feed forward module - src = src + self.dropout(self.feed_forward(src)) - - src = self.norm_final(self.balancer(src)) - - if alpha != 1.0: - src = alpha * src + (1 - alpha) * src_orig - - return src - - @torch.jit.export - def chunk_forward( - self, - src: Tensor, - pos_emb: Tensor, - states: List[Tensor], - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - warmup: float = 1.0, - left_context: int = 0, - right_context: int = 0, - ) -> Tuple[Tensor, List[Tensor]]: - """ - Pass the input through the encoder layer. - - Args: - src: the sequence to the encoder layer (required). - pos_emb: Positional embedding tensor (required). - states: - The decode states for previous frames which contains the cached data. - It has two elements, the first element is the attn_cache which has - a shape of (left_context, batch, attention_dim), - the second element is the conv_cache which has a shape of - (cnn_module_kernel-1, batch, conv_dim). - Note: states will be modified in this function. - src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - warmup: controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - left_context: - How many previous frames the attention can see in current chunk. - Note: It's not that each individual frame has `left_context` frames - of left context, some have more. - right_context: - How many future frames the attention can see in current chunk. - Note: It's not that each individual frame has `right_context` frames - of right context, some have more. - - Shape: - src: (S, N, E). - pos_emb: (N, 2*(S+left_context)-1, E). - src_mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, N is the batch size, E is the feature number - """ - - assert not self.training - assert len(states) == 2 - assert states[0].shape == (left_context, src.size(1), src.size(2)) - - # macaron style feed forward module - src = src + self.dropout(self.feed_forward_macaron(src)) - - # We put the attention cache this level (i.e. before linear transformation) - # to save memory consumption, when decoding in streaming fashion, the - # batch size would be thousands (for 32GB machine), if we cache key & val - # separately, it needs extra several GB memory. - # TODO(WeiKang): Move cache to self_attn level (i.e. cache key & val - # separately) if needed. - key = torch.cat([states[0], src], dim=0) - val = key - if right_context > 0: - states[0] = key[ - -(left_context + right_context) : -right_context, ... # noqa - ] - else: - states[0] = key[-left_context:, ...] - - # multi-headed self-attention module - src_att = self.self_attn( - src, - key, - val, - pos_emb=pos_emb, - attn_mask=src_mask, - key_padding_mask=src_key_padding_mask, - left_context=left_context, - )[0] - - src = src + self.dropout(src_att) - - # convolution module - conv, conv_cache = self.conv_module(src, states[1], right_context) - states[1] = conv_cache - - src = src + self.dropout(conv) - - # feed forward module - src = src + self.dropout(self.feed_forward(src)) - - src = self.norm_final(self.balancer(src)) - - return src, states - - -class ConformerEncoder(nn.Module): - r"""ConformerEncoder is a stack of N encoder layers - - Args: - encoder_layer: an instance of the ConformerEncoderLayer() class (required). - num_layers: the number of sub-encoder-layers in the encoder (required). - - Examples:: - >>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8) - >>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6) - >>> src = torch.rand(10, 32, 512) - >>> pos_emb = torch.rand(32, 19, 512) - >>> out = conformer_encoder(src, pos_emb) - """ - - def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None: - super().__init__() - self.layers = nn.ModuleList( - [copy.deepcopy(encoder_layer) for i in range(num_layers)] - ) - self.num_layers = num_layers - - def forward( - self, - src: Tensor, - pos_emb: Tensor, - src_key_padding_mask: Optional[Tensor] = None, - mask: Optional[Tensor] = None, - warmup: float = 1.0, - ) -> Tensor: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). - src_key_padding_mask: the mask for the src keys per batch (optional). - mask: the mask for the src sequence (optional). - warmup: controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - - Shape: - src: (S, N, E). - pos_emb: (N, 2*S-1, E) - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - """ - output = src - - for layer_index, mod in enumerate(self.layers): - output = mod( - output, - pos_emb, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - ) - - return output - - @torch.jit.export - def chunk_forward( - self, - src: Tensor, - pos_emb: Tensor, - states: List[Tensor], - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - warmup: float = 1.0, - left_context: int = 0, - right_context: int = 0, - ) -> Tuple[Tensor, List[Tensor]]: - r"""Pass the input through the encoder layers in turn. - - Args: - src: the sequence to the encoder (required). - pos_emb: Positional embedding tensor (required). - states: - The decode states for previous frames which contains the cached data. - It has two elements, the first element is the attn_cache which has - a shape of (encoder_layers, left_context, batch, attention_dim), - the second element is the conv_cache which has a shape of - (encoder_layers, cnn_module_kernel-1, batch, conv_dim). - Note: states will be modified in this function. - mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). - warmup: controls selective bypass of of layers; if < 1.0, we will - bypass layers more frequently. - left_context: - How many previous frames the attention can see in current chunk. - Note: It's not that each individual frame has `left_context` frames - of left context, some have more. - right_context: - How many future frames the attention can see in current chunk. - Note: It's not that each individual frame has `right_context` frames - of right context, some have more. - Shape: - src: (S, N, E). - pos_emb: (N, 2*(S+left_context)-1, E). - mask: (S, S). - src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number - - """ - assert not self.training - assert len(states) == 2 - assert states[0].shape == ( - self.num_layers, - left_context, - src.size(1), - src.size(2), - ) - assert states[1].size(0) == self.num_layers - - output = src - - for layer_index, mod in enumerate(self.layers): - cache = [states[0][layer_index], states[1][layer_index]] - output, cache = mod.chunk_forward( - output, - pos_emb, - states=cache, - src_mask=mask, - src_key_padding_mask=src_key_padding_mask, - warmup=warmup, - left_context=left_context, - right_context=right_context, - ) - states[0][layer_index] = cache[0] - states[1][layer_index] = cache[1] - - return output, states - - -class RelPositionalEncoding(torch.nn.Module): - """Relative positional encoding module. - - See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py - - Args: - d_model: Embedding dimension. - dropout_rate: Dropout rate. - max_len: Maximum input length. - - """ - - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: - """Construct an PositionalEncoding object.""" - super(RelPositionalEncoding, self).__init__() - if is_jit_tracing(): - # 10k frames correspond to ~100k ms, e.g., 100 seconds, i.e., - # It assumes that the maximum input won't have more than - # 10k frames. - # - # TODO(fangjun): Use torch.jit.script() for this module - max_len = 10000 - - self.d_model = d_model - self.dropout = torch.nn.Dropout(p=dropout_rate) - self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - - def extend_pe(self, x: Tensor, left_context: int = 0) -> None: - """Reset the positional encodings.""" - x_size_1 = x.size(1) + left_context - if self.pe is not None: - # self.pe contains both positive and negative parts - # the length of self.pe is 2 * input_len - 1 - if self.pe.size(1) >= x_size_1 * 2 - 1: - # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): - self.pe = self.pe.to(dtype=x.dtype, device=x.device) - return - # Suppose `i` means to the position of query vector and `j` means the - # position of key vector. We use position relative positions when keys - # are to the left (i>j) and negative relative positions otherwise (i Tuple[Tensor, Tensor]: - """Add positional encoding. - - Args: - x (torch.Tensor): Input tensor (batch, time, `*`). - left_context (int): left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - - Returns: - torch.Tensor: Encoded tensor (batch, time, `*`). - torch.Tensor: Encoded tensor (batch, 2*time-1, `*`). - - """ - if isinstance(left_context, torch.Tensor): - left_context = left_context.item() - self.extend_pe(x, left_context) - x_size_1 = x.size(1) + left_context - pos_emb = self.pe[ - :, - self.pe.size(1) // 2 - - x_size_1 - + 1 : self.pe.size(1) // 2 # noqa E203 - + x.size(1), - ] - return self.dropout(x), self.dropout(pos_emb) - - -class RelPositionMultiheadAttention(nn.Module): - r"""Multi-Head Attention layer with relative position encoding - - See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" - - Args: - embed_dim: total dimension of the model. - num_heads: parallel attention heads. - dropout: a Dropout layer on attn_output_weights. Default: 0.0. - - Examples:: - - >>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads) - >>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb) - """ - - def __init__( - self, - embed_dim: int, - num_heads: int, - dropout: float = 0.0, - ) -> None: - super(RelPositionMultiheadAttention, self).__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - assert ( - self.head_dim * num_heads == self.embed_dim - ), "embed_dim must be divisible by num_heads" - - self.in_proj = ScaledLinear(embed_dim, 3 * embed_dim, bias=True) - self.out_proj = ScaledLinear( - embed_dim, embed_dim, bias=True, initial_scale=0.25 - ) - - # linear transformation for positional encoding. - self.linear_pos = ScaledLinear(embed_dim, embed_dim, bias=False) - # these two learnable bias are used in matrix c and matrix d - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim)) - self.pos_bias_u_scale = nn.Parameter(torch.zeros(()).detach()) - self.pos_bias_v_scale = nn.Parameter(torch.zeros(()).detach()) - self._reset_parameters() - - def _pos_bias_u(self): - return self.pos_bias_u * self.pos_bias_u_scale.exp() - - def _pos_bias_v(self): - return self.pos_bias_v * self.pos_bias_v_scale.exp() - - def _reset_parameters(self) -> None: - nn.init.normal_(self.pos_bias_u, std=0.01) - nn.init.normal_(self.pos_bias_v, std=0.01) - - def forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = False, - attn_mask: Optional[Tensor] = None, - left_context: int = 0, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. When given a binary mask and a value is True, - the corresponding value on the attention layer will be ignored. When given - a byte mask and a value is non-zero, the corresponding value on the attention - layer will be ignored - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - left_context (int): left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - - Shape: - - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the position - with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - return self.multi_head_attention_forward( - query, - key, - value, - pos_emb, - self.embed_dim, - self.num_heads, - self.in_proj.get_weight(), - self.in_proj.get_bias(), - self.dropout, - self.out_proj.get_weight(), - self.out_proj.get_bias(), - training=self.training, - key_padding_mask=key_padding_mask, - need_weights=need_weights, - attn_mask=attn_mask, - left_context=left_context, - ) - - def rel_shift(self, x: Tensor, left_context: int = 0) -> Tensor: - """Compute relative positional encoding. - - Args: - x: Input tensor (batch, head, time1, 2*time1-1+left_context). - time1 means the length of query vector. - left_context (int): left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - - Returns: - Tensor: tensor of shape (batch, head, time1, time2) - (note: time2 has the same value as time1, but it is for - the key, while time1 is for the query). - """ - (batch_size, num_heads, time1, n) = x.shape - - time2 = time1 + left_context - if not is_jit_tracing(): - assert ( - n == left_context + 2 * time1 - 1 - ), f"{n} == {left_context} + 2 * {time1} - 1" - - if is_jit_tracing(): - rows = torch.arange(start=time1 - 1, end=-1, step=-1) - cols = torch.arange(time2) - rows = rows.repeat(batch_size * num_heads).unsqueeze(-1) - indexes = rows + cols - - x = x.reshape(-1, n) - x = torch.gather(x, dim=1, index=indexes) - x = x.reshape(batch_size, num_heads, time1, time2) - return x - else: - # Note: TorchScript requires explicit arg for stride() - batch_stride = x.stride(0) - head_stride = x.stride(1) - time1_stride = x.stride(2) - n_stride = x.stride(3) - return x.as_strided( - (batch_size, num_heads, time1, time2), - (batch_stride, head_stride, time1_stride - n_stride, n_stride), - storage_offset=n_stride * (time1 - 1), - ) - - def multi_head_attention_forward( - self, - query: Tensor, - key: Tensor, - value: Tensor, - pos_emb: Tensor, - embed_dim_to_check: int, - num_heads: int, - in_proj_weight: Tensor, - in_proj_bias: Tensor, - dropout_p: float, - out_proj_weight: Tensor, - out_proj_bias: Tensor, - training: bool = True, - key_padding_mask: Optional[Tensor] = None, - need_weights: bool = False, - attn_mask: Optional[Tensor] = None, - left_context: int = 0, - ) -> Tuple[Tensor, Optional[Tensor]]: - r""" - Args: - query, key, value: map a query and a set of key-value pairs to an output. - pos_emb: Positional embedding tensor - embed_dim_to_check: total dimension of the model. - num_heads: parallel attention heads. - in_proj_weight, in_proj_bias: input projection weight and bias. - dropout_p: probability of an element to be zeroed. - out_proj_weight, out_proj_bias: the output projection weight and bias. - training: apply dropout if is ``True``. - key_padding_mask: if provided, specified padding elements in the key will - be ignored by the attention. This is an binary mask. When the value is True, - the corresponding value on the attention layer will be filled with -inf. - need_weights: output attn_output_weights. - attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all - the batches while a 3D mask allows to specify a different mask for the entries of each batch. - left_context (int): left context (in frames) used during streaming decoding. - this is used only in real streaming decoding, in other circumstances, - it MUST be 0. - - Shape: - Inputs: - - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is - the embedding dimension. - - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is - the embedding dimension. - - pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence - length, N is the batch size, E is the embedding dimension. - - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length. - If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions - will be unchanged. If a BoolTensor is provided, the positions with the - value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged. - - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length. - 3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length, - S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked - positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend - while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True`` - are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor - is provided, it will be added to the attention weight. - - Outputs: - - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, - E is the embedding dimension. - - attn_output_weights: :math:`(N, L, S)` where N is the batch size, - L is the target sequence length, S is the source sequence length. - """ - - tgt_len, bsz, embed_dim = query.size() - if not is_jit_tracing(): - assert embed_dim == embed_dim_to_check - assert key.size(0) == value.size(0) and key.size(1) == value.size(1) - - head_dim = embed_dim // num_heads - if not is_jit_tracing(): - assert ( - head_dim * num_heads == embed_dim - ), "embed_dim must be divisible by num_heads" - - scaling = float(head_dim) ** -0.5 - - if torch.equal(query, key) and torch.equal(key, value): - # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) - - elif torch.equal(key, value): - # encoder-decoder attention - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1) - - else: - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = 0 - _end = embed_dim - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - q = nn.functional.linear(query, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim - _end = embed_dim * 2 - _w = in_proj_weight[_start:_end, :] - if _b is not None: - _b = _b[_start:_end] - k = nn.functional.linear(key, _w, _b) - - # This is inline in_proj function with in_proj_weight and in_proj_bias - _b = in_proj_bias - _start = embed_dim * 2 - _end = None - _w = in_proj_weight[_start:, :] - if _b is not None: - _b = _b[_start:] - v = nn.functional.linear(value, _w, _b) - - if attn_mask is not None: - assert ( - attn_mask.dtype == torch.float32 - or attn_mask.dtype == torch.float64 - or attn_mask.dtype == torch.float16 - or attn_mask.dtype == torch.uint8 - or attn_mask.dtype == torch.bool - ), "Only float, byte, and bool types are supported for attn_mask, not {}".format( - attn_mask.dtype - ) - if attn_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for attn_mask is deprecated. Use bool tensor instead." - ) - attn_mask = attn_mask.to(torch.bool) - - if attn_mask.dim() == 2: - attn_mask = attn_mask.unsqueeze(0) - if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") - elif attn_mask.dim() == 3: - if list(attn_mask.size()) != [ - bsz * num_heads, - query.size(0), - key.size(0), - ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") - else: - raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) - ) - # attn_mask's dim is 3 now. - - # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: - warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." - ) - key_padding_mask = key_padding_mask.to(torch.bool) - - q = (q * scaling).contiguous().view(tgt_len, bsz, num_heads, head_dim) - k = k.contiguous().view(-1, bsz, num_heads, head_dim) - v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1) - - src_len = k.size(0) - - if key_padding_mask is not None and not is_jit_tracing(): - assert key_padding_mask.size(0) == bsz, "{} == {}".format( - key_padding_mask.size(0), bsz - ) - assert key_padding_mask.size(1) == src_len, "{} == {}".format( - key_padding_mask.size(1), src_len - ) - - q = q.transpose(0, 1) # (batch, time1, head, d_k) - - pos_emb_bsz = pos_emb.size(0) - if not is_jit_tracing(): - assert pos_emb_bsz in (1, bsz) # actually it is 1 - - p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim) - # (batch, 2*time1, head, d_k) --> (batch, head, d_k, 2*time -1) - p = p.permute(0, 2, 3, 1) - - q_with_bias_u = (q + self._pos_bias_u()).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - q_with_bias_v = (q + self._pos_bias_v()).transpose( - 1, 2 - ) # (batch, head, time1, d_k) - - # compute attention score - # first compute matrix a and matrix c - # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 - k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) - - # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd, left_context) - - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) - - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) - - if not is_jit_tracing(): - assert list(attn_output_weights.size()) == [ - bsz * num_heads, - tgt_len, - src_len, - ] - - if attn_mask is not None: - if attn_mask.dtype == torch.bool: - attn_output_weights.masked_fill_(attn_mask, float("-inf")) - else: - attn_output_weights += attn_mask - - if key_padding_mask is not None: - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill( - key_padding_mask.unsqueeze(1).unsqueeze(2), - float("-inf"), - ) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1) - - # If we are using dynamic_chunk_training and setting a limited - # num_left_chunks, the attention may only see the padding values which - # will also be masked out by `key_padding_mask`, at this circumstances, - # the whole column of `attn_output_weights` will be `-inf` - # (i.e. be `nan` after softmax), so, we fill `0.0` at the masking - # positions to avoid invalid loss value below. - if ( - attn_mask is not None - and attn_mask.dtype == torch.bool - and key_padding_mask is not None - ): - if attn_mask.size(0) != 1: - attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( - 1 - ).unsqueeze(2) - - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) - attn_output_weights = attn_output_weights.view( - bsz * num_heads, tgt_len, src_len - ) - - attn_output_weights = nn.functional.dropout( - attn_output_weights, p=dropout_p, training=training - ) - - attn_output = torch.bmm(attn_output_weights, v) - - if not is_jit_tracing(): - assert list(attn_output.size()) == [ - bsz * num_heads, - tgt_len, - head_dim, - ] - - attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) - ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) - - if need_weights: - # average attention weights over heads - attn_output_weights = attn_output_weights.view( - bsz, num_heads, tgt_len, src_len - ) - return attn_output, attn_output_weights.sum(dim=1) / num_heads - else: - return attn_output, None - - -class ConvolutionModule(nn.Module): - """ConvolutionModule in Conformer model. - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py - - Args: - channels (int): The number of channels of conv layers. - kernel_size (int): Kernerl size of conv layers. - bias (bool): Whether to use bias in conv layers (default=True). - causal (bool): Whether to use causal convolution. - """ - - def __init__( - self, - channels: int, - kernel_size: int, - bias: bool = True, - causal: bool = False, - ) -> None: - """Construct an ConvolutionModule object.""" - super(ConvolutionModule, self).__init__() - # kernerl_size should be a odd number for 'SAME' padding - assert (kernel_size - 1) % 2 == 0 - self.causal = causal - - self.pointwise_conv1 = ScaledConv1d( - channels, - 2 * channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - ) - - # after pointwise_conv1 we put x through a gated linear unit (nn.functional.glu). - # For most layers the normal rms value of channels of x seems to be in the range 1 to 4, - # but sometimes, for some reason, for layer 0 the rms ends up being very large, - # between 50 and 100 for different channels. This will cause very peaky and - # sparse derivatives for the sigmoid gating function, which will tend to make - # the loss function not learn effectively. (for most layers the average absolute values - # are in the range 0.5..9.0, and the average p(x>0), i.e. positive proportion, - # at the output of pointwise_conv1.output is around 0.35 to 0.45 for different - # layers, which likely breaks down as 0.5 for the "linear" half and - # 0.2 to 0.3 for the part that goes into the sigmoid. The idea is that if we - # constrain the rms values to a reasonable range via a constraint of max_abs=10.0, - # it will be in a better position to start learning something, i.e. to latch onto - # the correct range. - self.deriv_balancer1 = ActivationBalancer( - channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 - ) - - self.lorder = kernel_size - 1 - padding = (kernel_size - 1) // 2 - if self.causal: - padding = 0 - - self.depthwise_conv = ScaledConv1d( - channels, - channels, - kernel_size, - stride=1, - padding=padding, - groups=channels, - bias=bias, - ) - - self.deriv_balancer2 = ActivationBalancer( - channel_dim=1, min_positive=0.05, max_positive=1.0 - ) - - self.activation = DoubleSwish() - - self.pointwise_conv2 = ScaledConv1d( - channels, - channels, - kernel_size=1, - stride=1, - padding=0, - bias=bias, - initial_scale=0.25, - ) - - def forward( - self, - x: Tensor, - cache: Optional[Tensor] = None, - right_context: int = 0, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tuple[Tensor, Tensor]: - """Compute convolution module. - - Args: - x: Input tensor (#time, batch, channels). - cache: The cache of depthwise_conv, only used in real streaming - decoding. - right_context: - How many future frames the attention can see in current chunk. - Note: It's not that each individual frame has `right_context` frames - src_key_padding_mask: the mask for the src keys per batch (optional). - of right context, some have more. - - Returns: - If cache is None return the output tensor (#time, batch, channels). - If cache is not None, return a tuple of Tensor, the first one is - the output tensor (#time, batch, channels), the second one is the - new cache for next chunk (#kernel_size - 1, batch, channels). - - """ - # exchange the temporal dimension and the feature dimension - x = x.permute(1, 2, 0) # (#batch, channels, time). - - # GLU mechanism - x = self.pointwise_conv1(x) # (batch, 2*channels, time) - - x = self.deriv_balancer1(x) - x = nn.functional.glu(x, dim=1) # (batch, channels, time) - - # 1D Depthwise Conv - if src_key_padding_mask is not None: - x.masked_fill_(src_key_padding_mask.unsqueeze(1).expand_as(x), 0.0) - if self.causal and self.lorder > 0: - if cache is None: - # Make depthwise_conv causal by - # manualy padding self.lorder zeros to the left - x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) - else: - assert not self.training, "Cache should be None in training time" - assert cache.size(0) == self.lorder - x = torch.cat([cache.permute(1, 2, 0), x], dim=2) - if right_context > 0: - cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : (-right_context), # noqa - ..., - ] - else: - cache = x.permute(2, 0, 1)[-self.lorder :, ...] # noqa - x = self.depthwise_conv(x) - - x = self.deriv_balancer2(x) - x = self.activation(x) - - x = self.pointwise_conv2(x) # (batch, channel, time) - - # torch.jit.script requires return types be the same as annotated above - if cache is None: - cache = torch.empty(0) - - return x.permute(2, 0, 1), cache - - -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - - Convert an input of shape (N, T, idim) to an output - with shape (N, T', odim), where - T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 - - It is based on - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - layer1_channels: int = 8, - layer2_channels: int = 32, - layer3_channels: int = 128, - ) -> None: - """ - Args: - in_channels: - Number of channels in. The input shape is (N, T, in_channels). - Caution: It requires: T >=7, in_channels >=7 - out_channels - Output dim. The output shape is (N, ((T-1)//2 - 1)//2, out_channels) - layer1_channels: - Number of channels in layer1 - layer1_channels: - Number of channels in layer2 - """ - assert in_channels >= 7 - super().__init__() - - self.conv = nn.Sequential( - ScaledConv2d( - in_channels=1, - out_channels=layer1_channels, - kernel_size=3, - padding=1, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( - in_channels=layer1_channels, - out_channels=layer2_channels, - kernel_size=3, - stride=2, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ScaledConv2d( - in_channels=layer2_channels, - out_channels=layer3_channels, - kernel_size=3, - stride=2, - ), - ActivationBalancer(channel_dim=1), - DoubleSwish(), - ) - self.out = ScaledLinear( - layer3_channels * (((in_channels - 1) // 2 - 1) // 2), out_channels - ) - # set learn_eps=False because out_norm is preceded by `out`, and `out` - # itself has learned scale, so the extra degree of freedom is not - # needed. - self.out_norm = BasicNorm(out_channels, learn_eps=False) - # constrain median of output to be close to zero. - self.out_balancer = ActivationBalancer( - channel_dim=-1, min_positive=0.45, max_positive=0.55 - ) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """Subsample x. - - Args: - x: - Its shape is (N, T, idim). - - Returns: - Return a tensor of shape (N, ((T-1)//2 - 1)//2, odim) - """ - # On entry, x is (N, T, idim) - x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) - x = self.conv(x) - # Now x is of shape (N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - # Now x is of shape (N, ((T-1)//2 - 1))//2, odim) - x = self.out_norm(x) - x = self.out_balancer(x) - return x - - -if __name__ == "__main__": - torch.set_num_threads(1) - torch.set_num_interop_threads(1) - feature_dim = 50 - c = Conformer(num_features=feature_dim, d_model=128, nhead=4) - batch_size = 5 - seq_len = 20 - # Just make sure the forward pass runs. - f = c( - torch.randn(batch_size, seq_len, feature_dim), - torch.full((batch_size,), seq_len, dtype=torch.int64), - warmup=0.5, - ) diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/decode.py deleted file mode 100755 index ca4d860c9..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/decode.py +++ /dev/null @@ -1,690 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -When training with the L subset, usage: -(1) greedy search -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 100 \ - --decoding-method greedy_search - -(2) modified beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 100 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(3) fast beam search (1best) -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 1500 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 - -(4) fast beam search (nbest) -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(5) fast beam search (nbest oracle WER) -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --nbest-scale 0.5 - -(6) fast beam search (with LG) -./pruned_transducer_stateless2/decode.py \ - --epoch 10 \ - --avg 2 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang-dir data/lang_char \ - --max-duration 600 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from train import get_params, get_transducer_model -from tokenizer import Tokenizer - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="It specifies the checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--batch", - type=int, - default=None, - help="It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0.", - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. ", - ) - - parser.add_argument( - "--avg-last-n", - type=int, - default=0, - help="""If positive, --epoch and --avg are ignored and it - will use the last n checkpoints exp_dir/checkpoint-xxx.pt - where xxx is the number of processed batches while - saving that checkpoint. - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to - specify `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.35, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - batch: dict, - word_talbe: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = model.device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_talbe[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - beam=params.beam_size, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.text2word(sp.decode(hyp))) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 100 - else: - log_interval = 2 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_talbe=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = sp.text2word(ref_text) - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if params.decoding_method == "fast_beam_search_nbest_LG": - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - if ( - params.decoding_method == "fast_beam_search_nbest" - or params.decoding_method == "fast_beam_search_nbest_oracle" - ): - params.suffix += f"-nbest-scale-{params.nbest_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - sp = Tokenizer.load(params.lang_dir, params.lang_type) - - # and are defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.avg_last_n > 0: - filenames = find_checkpoints(params.exp_dir)[: params.avg_last_n] - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - elif params.batch is not None: - filenames = f"{params.exp_dir}/checkpoint-{params.batch}.pt" - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints([filenames], device=device)) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to(device) - model.eval() - model.device = device - - if "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lg_filename = params.lang_dir + "/LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - reazonspeech = ReazonSpeechAsrDataModule(args) - - dev_cuts = reazonspeech.valid_cuts() - dev_dl = reazonspeech.valid_dataloaders(dev_cuts) - - test_cuts = reazonspeech.test_cuts() - test_dl = reazonspeech.test_dataloaders(test_cuts) - - test_sets = ["dev", "test"] - test_dl = [dev_dl, test_dl] - - for test_set, test_dl in zip(test_sets, test_dl): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/decode_aishell.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/decode_aishell.py deleted file mode 100755 index 2e644ec2f..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/decode_aishell.py +++ /dev/null @@ -1,547 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./pruned_transducer_stateless2/decode.py \ - --epoch 84 \ - --avg 25 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless2/decode.py \ - --epoch 84 \ - --avg 25 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method beam_search \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 84 \ - --avg 25 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method modified_beam_search \ - --beam-size 4 - -(4) fast beam search -./pruned_transducer_stateless2/decode.py \ - --epoch 84 \ - --avg 25 \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --max-duration 600 \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 -""" - - -import argparse -import logging -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from aishell import AishellAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from finetune import get_params, get_transducer_model - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - write_error_stats, -) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - batch: dict, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - token_table: - It maps token ID to a string. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - else: - hyp_tokens = [] - batch_size = encoder_out.size(0) - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyp_tokens.append(hyp) - - hyps = [[token_table[t] for t in tokens] for tokens in hyp_tokens] - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif params.decoding_method == "fast_beam_search": - return { - ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ): hyps - } - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - token_table: k2.SymbolTable, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - token_table: - It maps a token ID to a string. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - token_table=token_table, - decoding_graph=decoding_graph, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - # we compute CER for aishell dataset. - results_char = [] - for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results_char, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - -@torch.no_grad() -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - args.lang_dir = Path(args.lang_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "modified_beam_search", - ) - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict( - average_checkpoints(filenames, device=device), strict=False - ) - - model.to(device) - model.eval() - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - aishell = AishellAsrDataModule(args) - test_cuts = aishell.test_cuts() - dev_cuts = aishell.valid_cuts() - test_dl = aishell.test_dataloaders(test_cuts) - dev_dl = aishell.test_dataloaders(dev_cuts) - - test_sets = ["test", "dev"] - test_dls = [test_dl, dev_dl] - - for test_set, test_dl in zip(test_sets, test_dls): - results_dict = decode_dataset( - dl=test_dl, - params=params, - model=model, - token_table=lexicon.token_table, - decoding_graph=decoding_graph, - ) - - save_results( - params=params, - test_set_name=test_set, - results_dict=results_dict, - ) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/decode_stream.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/decode_stream.py deleted file mode 120000 index 3931e9a33..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -/var/data/share20/qc/k2/Github/icefall/egs/librispeech/ASR/pruned_transducer_stateless2/decode_stream.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/decoder.py deleted file mode 100644 index d44ed6f81..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/decoder.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -import torch.nn.functional as F -from scaling import ScaledConv1d, ScaledEmbedding - -from icefall.utils import is_jit_tracing - - -class Decoder(nn.Module): - """This class modifies the stateless decoder from the following paper: - - RNN-transducer with stateless prediction network - https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419 - - It removes the recurrent connection from the decoder, i.e., the prediction - network. Different from the above paper, it adds an extra Conv1d - right after the embedding layer. - - TODO: Implement https://arxiv.org/pdf/2109.07513.pdf - """ - - def __init__( - self, - vocab_size: int, - decoder_dim: int, - blank_id: int, - context_size: int, - ): - """ - Args: - vocab_size: - Number of tokens of the modeling unit including blank. - decoder_dim: - Dimension of the input embedding, and of the decoder output. - blank_id: - The ID of the blank symbol. - context_size: - Number of previous words to use to predict the next word. - 1 means bigram; 2 means trigram. n means (n+1)-gram. - """ - super().__init__() - - self.embedding = ScaledEmbedding( - num_embeddings=vocab_size, - embedding_dim=decoder_dim, - ) - self.blank_id = blank_id - - assert context_size >= 1, context_size - self.context_size = context_size - self.vocab_size = vocab_size - if context_size > 1: - self.conv = ScaledConv1d( - in_channels=decoder_dim, - out_channels=decoder_dim, - kernel_size=context_size, - padding=0, - groups=decoder_dim, - bias=False, - ) - else: - # It is to support torch script - self.conv = nn.Identity() - - def forward( - self, - y: torch.Tensor, - need_pad: bool = True # Annotation should be Union[bool, torch.Tensor] - # but, torch.jit.script does not support Union. - ) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, U). - need_pad: - True to left pad the input. Should be True during training. - False to not pad the input. Should be False during inference. - Returns: - Return a tensor of shape (N, U, decoder_dim). - """ - if isinstance(need_pad, torch.Tensor): - # This is for torch.jit.trace(), which cannot handle the case - # when the input argument is not a tensor. - need_pad = bool(need_pad) - - y = y.to(torch.int64) - # this stuff about clamp() is a temporary fix for a mismatch - # at utterance start, we use negative ids in beam_search.py - if torch.jit.is_tracing(): - # This is for exporting to PNNX via ONNX - embedding_out = self.embedding(y) - else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) - if self.context_size > 1: - embedding_out = embedding_out.permute(0, 2, 1) - if need_pad: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) - else: - # During inference time, there is no need to do extra padding - # as we only need one output - if not is_jit_tracing(): - assert embedding_out.size(-1) == self.context_size - embedding_out = self.conv(embedding_out) - embedding_out = embedding_out.permute(0, 2, 1) - embedding_out = F.relu(embedding_out) - return embedding_out diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/encoder_interface.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/encoder_interface.py deleted file mode 100644 index 257facce4..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/encoder_interface.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Tuple - -import torch -import torch.nn as nn - - -class EncoderInterface(nn.Module): - def forward( - self, x: torch.Tensor, x_lens: torch.Tensor - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A tensor of shape (batch_size, input_seq_len, num_features) - containing the input features. - x_lens: - A tensor of shape (batch_size,) containing the number of frames - in `x` before padding. - Returns: - Return a tuple containing two tensors: - - encoder_out, a tensor of (batch_size, out_seq_len, output_dim) - containing unnormalized probabilities, i.e., the output of a - linear layer. - - encoder_out_lens, a tensor of shape (batch_size,) containing - the number of frames in `encoder_out` before padding. - """ - raise NotImplementedError("Please implement it in a subclass") diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/export-onnx.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/export-onnx.py deleted file mode 100755 index 140b1d37f..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/export-onnx.py +++ /dev/null @@ -1,518 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang) - -""" -This script exports a transducer model from PyTorch to ONNX. - -We use the pre-trained model from -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2 -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/wenetspeech/ASR - -repo_url=icefall_asr_wenetspeech_pruned_transducer_stateless2 -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_char/Linv.pt" -git lfs pull --include "exp/pretrained_epoch_10_avg_2.pt" - -cd exp -ln -s pretrained_epoch_10_avg_2.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless2/export-onnx.py \ - --lang-dir $repo/data/lang_char \ - --epoch 99 \ - --avg 1 \ - --exp-dir $repo/exp - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -See ./onnx_pretrained.py for how to -use the exported ONNX models. -""" - -import argparse -import logging -from pathlib import Path -from typing import Dict, Tuple - -import onnx -import torch -import torch.nn as nn -from conformer import Conformer -from decoder import Decoder -from onnxruntime.quantization import QuantType, quantize_dynamic -from scaling_converter import convert_scaled_to_non_scaled -from train import get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import setup_logger, str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless5/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - return parser - - -def add_meta_data(filename: str, meta_data: Dict[str, str]): - """Add meta data to an ONNX model. It is changed in-place. - - Args: - filename: - Filename of the ONNX model to be changed. - meta_data: - Key-value pairs. - """ - model = onnx.load(filename) - for key, value in meta_data.items(): - meta = model.metadata_props.add() - meta.key = key - meta.value = value - - onnx.save(model, filename) - - -class OnnxEncoder(nn.Module): - """A wrapper for Conformer and the encoder_proj from the joiner""" - - def __init__(self, encoder: Conformer, encoder_proj: nn.Linear): - """ - Args: - encoder: - A Conformer encoder. - encoder_proj: - The projection layer for encoder from the joiner. - """ - super().__init__() - self.encoder = encoder - self.encoder_proj = encoder_proj - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Please see the help information of Conformer.forward - - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 1-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, A 3-D tensor of shape (N, T', joiner_dim) - - encoder_out_lens, A 1-D tensor of shape (N,) - """ - encoder_out, encoder_out_lens = self.encoder(x, x_lens) - - encoder_out = self.encoder_proj(encoder_out) - # Now encoder_out is of shape (N, T, joiner_dim) - - return encoder_out, encoder_out_lens - - -class OnnxDecoder(nn.Module): - """A wrapper for Decoder and the decoder_proj from the joiner""" - - def __init__(self, decoder: Decoder, decoder_proj: nn.Linear): - super().__init__() - self.decoder = decoder - self.decoder_proj = decoder_proj - - def forward(self, y: torch.Tensor) -> torch.Tensor: - """ - Args: - y: - A 2-D tensor of shape (N, context_size). - Returns - Return a 2-D tensor of shape (N, joiner_dim) - """ - need_pad = False - decoder_output = self.decoder(y, need_pad=need_pad) - decoder_output = decoder_output.squeeze(1) - output = self.decoder_proj(decoder_output) - - return output - - -class OnnxJoiner(nn.Module): - """A wrapper for the joiner""" - - def __init__(self, output_linear: nn.Linear): - super().__init__() - self.output_linear = output_linear - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - logit = encoder_out + decoder_out - logit = self.output_linear(torch.tanh(logit)) - return logit - - -def export_encoder_model_onnx( - encoder_model: OnnxEncoder, - encoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the given encoder model to ONNX format. - The exported model has two inputs: - - - x, a tensor of shape (N, T, C); dtype is torch.float32 - - x_lens, a tensor of shape (N,); dtype is torch.int64 - - and it has two outputs: - - - encoder_out, a tensor of shape (N, T', joiner_dim) - - encoder_out_lens, a tensor of shape (N,) - - Args: - encoder_model: - The input encoder model - encoder_filename: - The filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - x = torch.zeros(1, 100, 80, dtype=torch.float32) - x_lens = torch.tensor([100], dtype=torch.int64) - - torch.onnx.export( - encoder_model, - (x, x_lens), - encoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["x", "x_lens"], - output_names=["encoder_out", "encoder_out_lens"], - dynamic_axes={ - "x": {0: "N", 1: "T"}, - "x_lens": {0: "N"}, - "encoder_out": {0: "N", 1: "T"}, - "encoder_out_lens": {0: "N"}, - }, - ) - - meta_data = { - "model_type": "conformer", - "version": "1", - "model_author": "k2-fsa", - "comment": "stateless5", - } - logging.info(f"meta_data: {meta_data}") - - add_meta_data(filename=encoder_filename, meta_data=meta_data) - - -def export_decoder_model_onnx( - decoder_model: OnnxDecoder, - decoder_filename: str, - opset_version: int = 11, -) -> None: - """Export the decoder model to ONNX format. - - The exported model has one input: - - - y: a torch.int64 tensor of shape (N, decoder_model.context_size) - - and has one output: - - - decoder_out: a torch.float32 tensor of shape (N, joiner_dim) - - Args: - decoder_model: - The decoder model to be exported. - decoder_filename: - Filename to save the exported ONNX model. - opset_version: - The opset version to use. - """ - context_size = decoder_model.decoder.context_size - vocab_size = decoder_model.decoder.vocab_size - - y = torch.zeros(10, context_size, dtype=torch.int64) - decoder_model = torch.jit.script(decoder_model) - torch.onnx.export( - decoder_model, - y, - decoder_filename, - verbose=False, - opset_version=opset_version, - input_names=["y"], - output_names=["decoder_out"], - dynamic_axes={ - "y": {0: "N"}, - "decoder_out": {0: "N"}, - }, - ) - - meta_data = { - "context_size": str(context_size), - "vocab_size": str(vocab_size), - } - add_meta_data(filename=decoder_filename, meta_data=meta_data) - - -def export_joiner_model_onnx( - joiner_model: nn.Module, - joiner_filename: str, - opset_version: int = 11, -) -> None: - """Export the joiner model to ONNX format. - The exported joiner model has two inputs: - - - encoder_out: a tensor of shape (N, joiner_dim) - - decoder_out: a tensor of shape (N, joiner_dim) - - and produces one output: - - - logit: a tensor of shape (N, vocab_size) - """ - joiner_dim = joiner_model.output_linear.weight.shape[1] - logging.info(f"joiner dim: {joiner_dim}") - - projected_encoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - projected_decoder_out = torch.rand(11, joiner_dim, dtype=torch.float32) - - torch.onnx.export( - joiner_model, - (projected_encoder_out, projected_decoder_out), - joiner_filename, - verbose=False, - opset_version=opset_version, - input_names=[ - "encoder_out", - "decoder_out", - ], - output_names=["logit"], - dynamic_axes={ - "encoder_out": {0: "N"}, - "decoder_out": {0: "N"}, - "logit": {0: "N"}, - }, - ) - meta_data = { - "joiner_dim": str(joiner_dim), - } - add_meta_data(filename=joiner_filename, meta_data=meta_data) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - setup_logger(f"{params.exp_dir}/log-export/log-export-onnx") - - logging.info(f"device: {device}") - - lexicon = Lexicon(params.lang_dir) - params.blank_id = 0 - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to("cpu") - model.eval() - - convert_scaled_to_non_scaled(model, inplace=True) - - encoder = OnnxEncoder( - encoder=model.encoder, - encoder_proj=model.joiner.encoder_proj, - ) - - decoder = OnnxDecoder( - decoder=model.decoder, - decoder_proj=model.joiner.decoder_proj, - ) - - joiner = OnnxJoiner(output_linear=model.joiner.output_linear) - - encoder_num_param = sum([p.numel() for p in encoder.parameters()]) - decoder_num_param = sum([p.numel() for p in decoder.parameters()]) - joiner_num_param = sum([p.numel() for p in joiner.parameters()]) - total_num_param = encoder_num_param + decoder_num_param + joiner_num_param - logging.info(f"encoder parameters: {encoder_num_param}") - logging.info(f"decoder parameters: {decoder_num_param}") - logging.info(f"joiner parameters: {joiner_num_param}") - logging.info(f"total parameters: {total_num_param}") - - if params.iter > 0: - suffix = f"iter-{params.iter}" - else: - suffix = f"epoch-{params.epoch}" - - suffix += f"-avg-{params.avg}" - - opset_version = 13 - - logging.info("Exporting encoder") - encoder_filename = params.exp_dir / f"encoder-{suffix}.onnx" - export_encoder_model_onnx( - encoder, - encoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported encoder to {encoder_filename}") - - logging.info("Exporting decoder") - decoder_filename = params.exp_dir / f"decoder-{suffix}.onnx" - export_decoder_model_onnx( - decoder, - decoder_filename, - opset_version=opset_version, - ) - logging.info(f"Exported decoder to {decoder_filename}") - - logging.info("Exporting joiner") - joiner_filename = params.exp_dir / f"joiner-{suffix}.onnx" - export_joiner_model_onnx( - joiner, - joiner_filename, - opset_version=opset_version, - ) - logging.info(f"Exported joiner to {joiner_filename}") - - # Generate int8 quantization models - # See https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html#data-type-selection - - logging.info("Generate int8 quantization models") - - encoder_filename_int8 = params.exp_dir / f"encoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=encoder_filename, - model_output=encoder_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - decoder_filename_int8 = params.exp_dir / f"decoder-{suffix}.int8.onnx" - quantize_dynamic( - model_input=decoder_filename, - model_output=decoder_filename_int8, - op_types_to_quantize=["MatMul", "Gather"], - weight_type=QuantType.QInt8, - ) - - joiner_filename_int8 = params.exp_dir / f"joiner-{suffix}.int8.onnx" - quantize_dynamic( - model_input=joiner_filename, - model_output=joiner_filename_int8, - op_types_to_quantize=["MatMul"], - weight_type=QuantType.QInt8, - ) - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/export.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/export.py deleted file mode 100755 index 78a20a5b1..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/export.py +++ /dev/null @@ -1,238 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. -""" -Usage: - -(1) Export to torchscript model using torch.jit.script() - -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang data/lang_char \ - --epoch 26 \ - --avg 5 \ - --jit true - -It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later -load it by `torch.jit.load("cpu_jit.pt")`. - -Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python -are on CPU. You can use `to("cuda")` to move them to a CUDA device. - -Please refer to -https://k2-fsa.github.io/sherpa/python/offline_asr/conformer/index.html -for how to use `cpu_jit.pt` for speech recognition. - -(2) Export `model.state_dict()` - -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --lang data/lang_char \ - --epoch 26 \ - --avg 5 - -It will generate a file `pretrained.pt` in the given `exp_dir`. You can later -load it by `icefall.checkpoint.load_checkpoint()`. - -To use the generated file with `pruned_transducer_stateless2/decode.py`, -you can do: - - cd /path/to/exp_dir - ln -s pretrained.pt epoch-9999.pt - - cd /path/to/egs/reazonspeech/ASR - ./pruned_transducer_stateless2/decode.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --epoch 9999 \ - --avg 1 \ - --max-duration 180 \ - --decoding-method greedy_search \ - --lang data/lang_char - -You can find pretrained models at -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp -""" - -import argparse -import logging -from pathlib import Path - -import torch -from train import get_params, get_transducer_model -from tokenizer import Tokenizer - -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint -from icefall.lexicon import Lexicon -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=29, - help="""It specifies the checkpoint to use for averaging. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="The lang dir", - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - # add_model_arguments(parser) - - return parser - - -def main(): - parser = get_parser() - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = Tokenizer.load(params.lang_dir, params.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - - model.to("cpu") - model.eval() - - if params.jit: - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / f"cpu_jit-epoch-{params.epoch}-avg-{params.avg}.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torch.jit.script") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = ( - params.exp_dir / f"pretrained-epoch-{params.epoch}-avg-{params.avg}.pt" - ) - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/finetune.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/finetune.py deleted file mode 100755 index c34f1593d..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/finetune.py +++ /dev/null @@ -1,1054 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Xiaoyu Yang, -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless2/finetune.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless2/exp \ - --full-libri 1 \ - --do-finetune 1 \ - --max-duration 100 - -""" - - -import argparse -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, List, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from aishell import AishellAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def add_finetune_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--do-finetune", type=str2bool, default=False) - - parser.add_argument( - "--init-modules", - type=str, - default=None, - help=""" - Modules to be initialized. It matches all parameters starting with - a specific key. The keys are given with Comma separated. If None, - all modules will be initialised. For example, if you only want to - initialise all parameters staring with "encoder", use "encoder"; - if you want to initialise parameters starting with encoder or decoder, - use "encoder,joiner". - """, - ) - - parser.add_argument( - "--finetune-ckpt", - type=str, - default=None, - help="Fine-tuning from which checkpoint (a path to a .pt file)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.0001, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=100000, - help="""Number of steps that affects how rapidly the learning rate - decreases. During fine-tuning, we set this very large so that the - learning rate slowly decays with number of batches. You may tune - its value by yourself. - """, - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=100, - help="""Number of epochs that affects how rapidly the learning rate - decreases. During fine-tuning, we set this very large so that the - learning rate slowly decays with number of batches. You may tune - its value by yourself. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=1, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=20, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--valid-interval", - type=int, - default=3000, - help="""When training_subset is L, set the valid_interval to 3000. - When training_subset is M, set the valid_interval to 1000. - When training_subset is S, set the valid_interval to 400. - """, - ) - - parser.add_argument( - "--model-warm-step", - type=int, - default=3000, - help="""When training_subset is L, set the model_warm_step to 3000. - When training_subset is M, set the model_warm_step to 500. - When training_subset is S, set the model_warm_step to 100. - """, - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - Explanation of options saved in `params`: - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - best_train_epoch: It is the epoch that has the best training loss. - - best_valid_epoch: It is the epoch that has the best validation loss. - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - feature_dim: The model input dim. It has to match the one used - in computing features. - - subsampling_factor: The subsampling factor for the model. - - encoder_dim: Hidden dim for multi-head attention model. - - num_decoder_layers: Number of decoder layer of transformer decoder. - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "encoder_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 0: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def load_model_params( - ckpt: str, model: nn.Module, init_modules: List[str] = None, strict: bool = True -): - """Load model params from checkpoint - - Args: - ckpt (str): Path to the checkpoint - model (nn.Module): model to be loaded - - """ - logging.info(f"Loading checkpoint from {ckpt}") - checkpoint = torch.load(ckpt, map_location="cpu") - - # if module list is empty, load the whole model from ckpt - if not init_modules: - if next(iter(checkpoint["model"])).startswith("module."): - logging.info("Loading checkpoint saved by DDP") - - dst_state_dict = model.state_dict() - src_state_dict = checkpoint["model"] - for key in dst_state_dict.keys(): - src_key = "{}.{}".format("module", key) - dst_state_dict[key] = src_state_dict.pop(src_key) - assert len(src_state_dict) == 0 - model.load_state_dict(dst_state_dict, strict=strict) - else: - model.load_state_dict(checkpoint["model"], strict=strict) - else: - src_state_dict = checkpoint["model"] - dst_state_dict = model.state_dict() - for module in init_modules: - logging.info(f"Loading parameters starting with prefix {module}") - src_keys = [ - k for k in src_state_dict.keys() if k.startswith(module.strip() + ".") - ] - dst_keys = [ - k for k in dst_state_dict.keys() if k.startswith(module.strip() + ".") - ] - assert set(src_keys) == set(dst_keys) # two sets should match exactly - for key in src_keys: - dst_state_dict[key] = src_state_dict.pop(key) - - model.load_state_dict(dst_state_dict, strict=strict) - - return None - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - - y = graph_compiler.texts_to_ids(texts) - if isinstance(y, list): - y = k2.RaggedTensor(y).to(device) - else: - y = y.to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - graph_compiler: CharCtcTrainingGraphCompiler, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - graph_compiler: CharCtcTrainingGraphCompiler, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - lexicon = Lexicon(params.lang_dir) - graph_compiler = CharCtcTrainingGraphCompiler( - lexicon=lexicon, - device=device, - ) - - params.blank_id = lexicon.token_table[""] - params.vocab_size = max(lexicon.tokens) + 1 - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # load model parameters for model fine-tuning - if params.do_finetune: - modules = params.init_modules.split(",") if params.init_modules else None - checkpoints = load_model_params( - ckpt=params.finetune_ckpt, model=model, init_modules=modules - ) - else: - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - model.device = device - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - aishell = AishellAsrDataModule(args) - train_dl = aishell.train_dataloaders(aishell.train_cuts()) - valid_dl = aishell.valid_dataloaders(aishell.valid_cuts()) - - if not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - graph_compiler=graph_compiler, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - graph_compiler=graph_compiler, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - graph_compiler: CharCtcTrainingGraphCompiler, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = graph_compiler.texts_to_ids(supervisions["text"]) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - graph_compiler: CharCtcTrainingGraphCompiler, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - graph_compiler=graph_compiler, - batch=batch, - is_training=True, - warmup=0.0 if params.start_epoch == 1 else 1.0, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) - raise - - -def main(): - parser = get_parser() - AishellAsrDataModule.add_arguments( - parser - ) # you may replace this with your own dataset - add_finetune_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py deleted file mode 100755 index f90dd2b43..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py +++ /dev/null @@ -1,336 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads torchscript models, either exported by `torch.jit.trace()` -or by `torch.jit.script()`, and uses them to decode waves. -You can use the following command to get the exported models: - -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --tokens data/lang_char/tokens.txt \ - --epoch 10 \ - --avg 2 \ - --jit-trace 1 - -or - -./pruned_transducer_stateless2/export.py \ - --exp-dir ./pruned_transducer_stateless2/exp \ - --tokens data/lang_char/tokens.txt \ - --epoch 10 \ - --avg 2 \ - --jit 1 - -Usage of this script: - -./pruned_transducer_stateless2/jit_pretrained.py \ - --encoder-model-filename ./pruned_transducer_stateless2/exp/encoder_jit_trace.pt \ - --decoder-model-filename ./pruned_transducer_stateless2/exp/decoder_jit_trace.pt \ - --joiner-model-filename ./pruned_transducer_stateless2/exp/joiner_jit_trace.pt \ - --tokens data/lang_char/tokens.txt \ - /path/to/foo.wav \ - /path/to/bar.wav - -or - -./pruned_transducer_stateless2/jit_pretrained.py \ - --encoder-model-filename ./pruned_transducer_stateless2/exp/encoder_jit_script.pt \ - --decoder-model-filename ./pruned_transducer_stateless2/exp/decoder_jit_script.pt \ - --joiner-model-filename ./pruned_transducer_stateless2/exp/joiner_jit_script.pt \ - --tokens data/lang_char/tokens.txt \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can find pretrained models at -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp -""" - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder torchscript model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder torchscript model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner torchscript model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="Context size of the decoder model", - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - decoder: torch.jit.ScriptModule, - joiner: torch.jit.ScriptModule, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, - context_size: int, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - decoder: - The decoder model. - joiner: - The joiner model. - encoder_out: - A 3-D tensor of shape (N, T, C) - encoder_out_lens: - A 1-D tensor of shape (N,). - context_size: - The context size of the decoder model. - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3 - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - device = encoder_out.device - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - device=device, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = decoder( - decoder_input, - need_pad=torch.tensor([False]), - ).squeeze(1) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - current_encoder_out = current_encoder_out - # current_encoder_out's shape: (batch_size, encoder_out_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - - logits = joiner( - current_encoder_out, - decoder_out, - ) - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - device=device, - dtype=torch.int64, - ) - decoder_out = decoder( - decoder_input, - need_pad=torch.tensor([False]), - ) - decoder_out = decoder_out.squeeze(1) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - encoder = torch.jit.load(args.encoder_model_filename) - decoder = torch.jit.load(args.decoder_model_filename) - joiner = torch.jit.load(args.joiner_model_filename) - - encoder.eval() - decoder.eval() - joiner.eval() - - encoder.to(device) - decoder.to(device) - joiner.to(device) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = encoder( - x=features, - x_lens=feature_lengths, - ) - - hyps = greedy_search( - decoder=decoder, - joiner=joiner, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - context_size=args.context_size, - ) - symbol_table = k2.SymbolTable.from_file(args.tokens) - s = "\n" - for filename, hyp in zip(args.sound_files, hyps): - words = "".join([symbol_table[i] for i in hyp]) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/joiner.py deleted file mode 100644 index 9f88bd029..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/joiner.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -import torch.nn as nn -from scaling import ScaledLinear - -from icefall.utils import is_jit_tracing - - -class Joiner(nn.Module): - def __init__( - self, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - super().__init__() - - self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim) - self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim) - self.output_linear = ScaledLinear(joiner_dim, vocab_size) - - def forward( - self, - encoder_out: torch.Tensor, - decoder_out: torch.Tensor, - project_input: bool = True, - ) -> torch.Tensor: - """ - Args: - encoder_out: - Output from the encoder. Its shape is (N, T, s_range, C). - decoder_out: - Output from the decoder. Its shape is (N, T, s_range, C). - project_input: - If true, apply input projections encoder_proj and decoder_proj. - If this is false, it is the user's responsibility to do this - manually. - Returns: - Return a tensor of shape (N, T, s_range, C). - """ - if not is_jit_tracing(): - assert encoder_out.ndim == decoder_out.ndim - - if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) - else: - logit = encoder_out + decoder_out - - logit = self.output_linear(torch.tanh(logit)) - - return logit diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/lstmp.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/lstmp.py deleted file mode 100644 index dba6eb520..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/lstmp.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Optional, Tuple - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class LSTMP(nn.Module): - """LSTM with projection. - - PyTorch does not support exporting LSTM with projection to ONNX. - This class reimplements LSTM with projection using basic matrix-matrix - and matrix-vector operations. It is not intended for training. - """ - - def __init__(self, lstm: nn.LSTM): - """ - Args: - lstm: - LSTM with proj_size. We support only uni-directional, - 1-layer LSTM with projection at present. - """ - super().__init__() - assert lstm.bidirectional is False, lstm.bidirectional - assert lstm.num_layers == 1, lstm.num_layers - assert 0 < lstm.proj_size < lstm.hidden_size, ( - lstm.proj_size, - lstm.hidden_size, - ) - - assert lstm.batch_first is False, lstm.batch_first - - state_dict = lstm.state_dict() - - w_ih = state_dict["weight_ih_l0"] - w_hh = state_dict["weight_hh_l0"] - - b_ih = state_dict["bias_ih_l0"] - b_hh = state_dict["bias_hh_l0"] - - w_hr = state_dict["weight_hr_l0"] - self.input_size = lstm.input_size - self.proj_size = lstm.proj_size - self.hidden_size = lstm.hidden_size - - self.w_ih = w_ih - self.w_hh = w_hh - self.b = b_ih + b_hh - self.w_hr = w_hr - - def forward( - self, - input: torch.Tensor, - hx: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - """ - Args: - input: - A tensor of shape [T, N, hidden_size] - hx: - A tuple containing: - - h0: a tensor of shape (1, N, proj_size) - - c0: a tensor of shape (1, N, hidden_size) - Returns: - Return a tuple containing: - - output: a tensor of shape (T, N, proj_size). - - A tuple containing: - - h: a tensor of shape (1, N, proj_size) - - c: a tensor of shape (1, N, hidden_size) - - """ - x_list = input.unbind(dim=0) # We use batch_first=False - - if hx is not None: - h0, c0 = hx - else: - h0 = torch.zeros(1, input.size(1), self.proj_size) - c0 = torch.zeros(1, input.size(1), self.hidden_size) - h0 = h0.squeeze(0) - c0 = c0.squeeze(0) - y_list = [] - for x in x_list: - gates = F.linear(x, self.w_ih, self.b) + F.linear(h0, self.w_hh) - i, f, g, o = gates.chunk(4, dim=1) - - i = i.sigmoid() - f = f.sigmoid() - g = g.tanh() - o = o.sigmoid() - - c = f * c0 + i * g - h = o * c.tanh() - - h = F.linear(h, self.w_hr) - y_list.append(h) - - c0 = c - h0 = h - - y = torch.stack(y_list, dim=0) - - return y, (h0.unsqueeze(0), c0.unsqueeze(0)) diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/model.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/model.py deleted file mode 100644 index 272d06c37..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/model.py +++ /dev/null @@ -1,207 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import Tuple - -import k2 -import torch -import torch.nn as nn -from encoder_interface import EncoderInterface -from scaling import ScaledLinear - -from icefall.utils import add_sos - - -class Transducer(nn.Module): - """It implements https://arxiv.org/pdf/1211.3711.pdf - "Sequence Transduction with Recurrent Neural Networks" - """ - - def __init__( - self, - encoder: EncoderInterface, - decoder: nn.Module, - joiner: nn.Module, - encoder_dim: int, - decoder_dim: int, - joiner_dim: int, - vocab_size: int, - ): - """ - Args: - encoder: - It is the transcription network in the paper. Its accepts - two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). - It returns two tensors: `logits` of shape (N, T, encoder_dm) and - `logit_lens` of shape (N,). - decoder: - It is the prediction network in the paper. Its input shape - is (N, U) and its output shape is (N, U, decoder_dim). - It should contain one attribute: `blank_id`. - joiner: - It has two inputs with shapes: (N, T, encoder_dim) and - (N, U, decoder_dim). - Its output shape is (N, T, U, vocab_size). Note that its output - contains unnormalized probs, i.e., not processed by log-softmax. - """ - super().__init__() - assert isinstance(encoder, EncoderInterface), type(encoder) - assert hasattr(decoder, "blank_id") - - self.encoder = encoder - self.decoder = decoder - self.joiner = joiner - - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) - self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) - - def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - y: k2.RaggedTensor, - prune_range: int = 5, - am_scale: float = 0.0, - lm_scale: float = 0.0, - warmup: float = 1.0, - reduction: str = "sum", - delay_penalty: float = 0.0, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C). - x_lens: - A 1-D tensor of shape (N,). It contains the number of frames in `x` - before padding. - y: - A ragged tensor with 2 axes [utt][label]. It contains labels of each - utterance. - prune_range: - The prune range for rnnt loss, it means how many symbols(context) - we are considering for each frame to compute the loss. - am_scale: - The scale to smooth the loss with am (output of encoder network) - part - lm_scale: - The scale to smooth the loss with lm (output of predictor network) - part - warmup: - A value warmup >= 0 that determines which modules are active, values - warmup > 1 "are fully warmed up" and all modules will be active. - reduction: - "sum" to sum the losses over all utterances in the batch. - "none" to return the loss in a 1-D tensor for each utterance - in the batch. - delay_penalty: - A constant value used to penalize symbol delay, to encourage - streaming models to emit symbols earlier. - See https://github.com/k2-fsa/k2/issues/955 and - https://arxiv.org/pdf/2211.00490.pdf for more details. - Returns: - Returns: - Return the transducer loss. - - Note: - Regarding am_scale & lm_scale, it will make the loss-function one of - the form: - lm_scale * lm_probs + am_scale * am_probs + - (1-lm_scale-am_scale) * combined_probs - """ - assert reduction in ("sum", "none"), reduction - assert x.ndim == 3, x.shape - assert x_lens.ndim == 1, x_lens.shape - assert y.num_axes == 2, y.num_axes - - assert x.size(0) == x_lens.size(0) == y.dim0 - - encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup) - assert torch.all(x_lens > 0) - - # Now for the decoder, i.e., the prediction network - row_splits = y.shape.row_splits(1) - y_lens = row_splits[1:] - row_splits[:-1] - - blank_id = self.decoder.blank_id - sos_y = add_sos(y, sos_id=blank_id) - - # sos_y_padded: [B, S + 1], start with SOS. - sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) - - # decoder_out: [B, S + 1, decoder_dim] - decoder_out = self.decoder(sos_y_padded) - - # Note: y does not start with SOS - # y_padded : [B, S] - y_padded = y.pad(mode="constant", padding_value=0) - - y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) - boundary[:, 2] = y_lens - boundary[:, 3] = x_lens - - lm = self.simple_lm_proj(decoder_out) - am = self.simple_am_proj(encoder_out) - - with torch.cuda.amp.autocast(enabled=False): - simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( - lm=lm.float(), - am=am.float(), - symbols=y_padded, - termination_symbol=blank_id, - lm_only_scale=lm_scale, - am_only_scale=am_scale, - boundary=boundary, - reduction=reduction, - delay_penalty=delay_penalty, - return_grad=True, - ) - - # ranges : [B, T, prune_range] - ranges = k2.get_rnnt_prune_ranges( - px_grad=px_grad, - py_grad=py_grad, - boundary=boundary, - s_range=prune_range, - ) - - # am_pruned : [B, T, prune_range, encoder_dim] - # lm_pruned : [B, T, prune_range, decoder_dim] - am_pruned, lm_pruned = k2.do_rnnt_pruning( - am=self.joiner.encoder_proj(encoder_out), - lm=self.joiner.decoder_proj(decoder_out), - ranges=ranges, - ) - - # logits : [B, T, prune_range, vocab_size] - - # project_input=False since we applied the decoder's input projections - # prior to do_rnnt_pruning (this is an optimization for speed). - logits = self.joiner(am_pruned, lm_pruned, project_input=False) - - with torch.cuda.amp.autocast(enabled=False): - pruned_loss = k2.rnnt_loss_pruned( - logits=logits.float(), - symbols=y_padded, - ranges=ranges, - termination_symbol=blank_id, - boundary=boundary, - delay_penalty=delay_penalty, - reduction=reduction, - ) - - return (simple_loss, pruned_loss) diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/onnx_check.py deleted file mode 100755 index 2d46eede1..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/onnx_check.py +++ /dev/null @@ -1,303 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2022 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This script checks that exported onnx models produce the same output -with the given torchscript model for the same input. - -Usage: - -./pruned_transducer_stateless2/onnx_check.py \ - --jit-filename ./t/cpu_jit.pt \ - --onnx-encoder-filename ./t/encoder.onnx \ - --onnx-decoder-filename ./t/decoder.onnx \ - --onnx-joiner-filename ./t/joiner.onnx \ - --onnx-joiner-encoder-proj-filename ./t/joiner_encoder_proj.onnx \ - --onnx-joiner-decoder-proj-filename ./t/joiner_decoder_proj.onnx - -You can generate cpu_jit.pt, encoder.onnx, decoder.onnx, and other -xxx.onnx files using ./export.py - -We provide pretrained models at: -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2/tree/main/exp -""" - -import argparse -import logging - -from icefall import is_module_available - -if not is_module_available("onnxruntime"): - raise ValueError("Please 'pip install onnxruntime' first.") - -import onnxruntime as ort -import torch - -ort.set_default_logger_severity(3) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--jit-filename", - required=True, - type=str, - help="Path to the torchscript model exported by torch.jit.script", - ) - - parser.add_argument( - "--onnx-encoder-filename", - required=True, - type=str, - help="Path to the onnx encoder model", - ) - - parser.add_argument( - "--onnx-decoder-filename", - required=True, - type=str, - help="Path to the onnx decoder model", - ) - - parser.add_argument( - "--onnx-joiner-filename", - required=True, - type=str, - help="Path to the onnx joiner model", - ) - - parser.add_argument( - "--onnx-joiner-encoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner encoder projection model", - ) - - parser.add_argument( - "--onnx-joiner-decoder-proj-filename", - required=True, - type=str, - help="Path to the onnx joiner decoder projection model", - ) - - return parser - - -def test_encoder( - model: torch.jit.ScriptModule, - encoder_session: ort.InferenceSession, -): - inputs = encoder_session.get_inputs() - outputs = encoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] - - assert inputs[0].shape == ["N", "T", 80] - assert inputs[1].shape == ["N"] - - for N in [1, 5]: - for T in [12, 25]: - print("N, T", N, T) - x = torch.rand(N, T, 80, dtype=torch.float32) - x_lens = torch.randint(low=10, high=T + 1, size=(N,)) - x_lens[0] = T - - encoder_inputs = { - input_names[0]: x.numpy(), - input_names[1]: x_lens.numpy(), - } - encoder_out, encoder_out_lens = encoder_session.run( - output_names, - encoder_inputs, - ) - - torch_encoder_out, torch_encoder_out_lens = model.encoder(x, x_lens) - - encoder_out = torch.from_numpy(encoder_out) - assert torch.allclose(encoder_out, torch_encoder_out, atol=1e-05), ( - (encoder_out - torch_encoder_out).abs().max(), - encoder_out.shape, - torch_encoder_out.shape, - ) - - -def test_decoder( - model: torch.jit.ScriptModule, - decoder_session: ort.InferenceSession, -): - inputs = decoder_session.get_inputs() - outputs = decoder_session.get_outputs() - input_names = [n.name for n in inputs] - output_names = [n.name for n in outputs] - - assert inputs[0].shape == ["N", 2] - for N in [1, 5, 10]: - y = torch.randint(low=1, high=500, size=(10, 2)) - - decoder_inputs = {input_names[0]: y.numpy()} - decoder_out = decoder_session.run( - output_names, - decoder_inputs, - )[0] - decoder_out = torch.from_numpy(decoder_out) - - torch_decoder_out = model.decoder(y, need_pad=False) - assert torch.allclose(decoder_out, torch_decoder_out, atol=1e-5), ( - (decoder_out - torch_decoder_out).abs().max() - ) - - -def test_joiner( - model: torch.jit.ScriptModule, - joiner_session: ort.InferenceSession, - joiner_encoder_proj_session: ort.InferenceSession, - joiner_decoder_proj_session: ort.InferenceSession, -): - joiner_inputs = joiner_session.get_inputs() - joiner_outputs = joiner_session.get_outputs() - joiner_input_names = [n.name for n in joiner_inputs] - joiner_output_names = [n.name for n in joiner_outputs] - - assert joiner_inputs[0].shape == ["N", 512] - assert joiner_inputs[1].shape == ["N", 512] - - joiner_encoder_proj_inputs = joiner_encoder_proj_session.get_inputs() - encoder_proj_input_name = joiner_encoder_proj_inputs[0].name - - assert joiner_encoder_proj_inputs[0].shape == ["N", 512] - - joiner_encoder_proj_outputs = joiner_encoder_proj_session.get_outputs() - encoder_proj_output_name = joiner_encoder_proj_outputs[0].name - - joiner_decoder_proj_inputs = joiner_decoder_proj_session.get_inputs() - decoder_proj_input_name = joiner_decoder_proj_inputs[0].name - - assert joiner_decoder_proj_inputs[0].shape == ["N", 512] - - joiner_decoder_proj_outputs = joiner_decoder_proj_session.get_outputs() - decoder_proj_output_name = joiner_decoder_proj_outputs[0].name - - for N in [1, 5, 10]: - encoder_out = torch.rand(N, 512) - decoder_out = torch.rand(N, 512) - - projected_encoder_out = torch.rand(N, 512) - projected_decoder_out = torch.rand(N, 512) - - joiner_inputs = { - joiner_input_names[0]: projected_encoder_out.numpy(), - joiner_input_names[1]: projected_decoder_out.numpy(), - } - joiner_out = joiner_session.run(joiner_output_names, joiner_inputs)[0] - joiner_out = torch.from_numpy(joiner_out) - - torch_joiner_out = model.joiner( - projected_encoder_out, - projected_decoder_out, - project_input=False, - ) - assert torch.allclose(joiner_out, torch_joiner_out, atol=1e-5), ( - (joiner_out - torch_joiner_out).abs().max() - ) - - # Now test encoder_proj - joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} - joiner_encoder_proj_out = joiner_encoder_proj_session.run( - [encoder_proj_output_name], joiner_encoder_proj_inputs - )[0] - joiner_encoder_proj_out = torch.from_numpy(joiner_encoder_proj_out) - - torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) - assert torch.allclose( - joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) - - # Now test decoder_proj - joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} - joiner_decoder_proj_out = joiner_decoder_proj_session.run( - [decoder_proj_output_name], joiner_decoder_proj_inputs - )[0] - joiner_decoder_proj_out = torch.from_numpy(joiner_decoder_proj_out) - - torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) - assert torch.allclose( - joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) - - -@torch.no_grad() -def main(): - args = get_parser().parse_args() - logging.info(vars(args)) - - model = torch.jit.load(args.jit_filename) - - options = ort.SessionOptions() - options.inter_op_num_threads = 1 - options.intra_op_num_threads = 1 - - logging.info("Test encoder") - encoder_session = ort.InferenceSession( - args.onnx_encoder_filename, - sess_options=options, - providers=["CPUExecutionProvider"], - ) - test_encoder(model, encoder_session) - - logging.info("Test decoder") - decoder_session = ort.InferenceSession( - args.onnx_decoder_filename, - sess_options=options, - providers=["CPUExecutionProvider"], - ) - test_decoder(model, decoder_session) - - logging.info("Test joiner") - joiner_session = ort.InferenceSession( - args.onnx_joiner_filename, - sess_options=options, - providers=["CPUExecutionProvider"], - ) - joiner_encoder_proj_session = ort.InferenceSession( - args.onnx_joiner_encoder_proj_filename, - sess_options=options, - providers=["CPUExecutionProvider"], - ) - joiner_decoder_proj_session = ort.InferenceSession( - args.onnx_joiner_decoder_proj_filename, - sess_options=options, - providers=["CPUExecutionProvider"], - ) - test_joiner( - model, - joiner_session, - joiner_encoder_proj_session, - joiner_decoder_proj_session, - ) - logging.info("Finished checking ONNX models") - - -if __name__ == "__main__": - torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py deleted file mode 100755 index c784853ee..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py +++ /dev/null @@ -1,428 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads ONNX models and uses them to decode waves. -You can use the following command to get the exported models: - -We use the pre-trained model from -https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_offline/ -as an example to show how to use this file. - -1. Download the pre-trained model - -cd egs/wenetspeech/ASR - -repo_url=https://huggingface.co/luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless5_offline/ -GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url -repo=$(basename $repo_url) - -pushd $repo -git lfs pull --include "data/lang_char/Linv.pt" -git lfs pull --include "exp/pretrained_epoch_4_avg_1.pt" -git lfs pull --include "exp/cpu_jit_epoch_4_avg_1_torch.1.7.1.pt" - -cd exp -ln -s pretrained_epoch_9_avg_1_torch.1.7.1.pt epoch-99.pt -popd - -2. Export the model to ONNX - -./pruned_transducer_stateless5/export-onnx.py \ - --lang-dir $repo/data/lang_char \ - --epoch 99 \ - --avg 1 \ - --use-averaged-model 0 \ - --exp-dir $repo/exp \ - --num-encoder-layers 24 \ - --dim-feedforward 1536 \ - --nhead 8 \ - --encoder-dim 384 \ - --decoder-dim 512 \ - --joiner-dim 512 - -It will generate the following 3 files inside $repo/exp: - - - encoder-epoch-99-avg-1.onnx - - decoder-epoch-99-avg-1.onnx - - joiner-epoch-99-avg-1.onnx - -3. Run this file - -./pruned_transducer_stateless5/onnx_pretrained.py \ - --encoder-model-filename $repo/exp/encoder-epoch-99-avg-1.onnx \ - --decoder-model-filename $repo/exp/decoder-epoch-99-avg-1.onnx \ - --joiner-model-filename $repo/exp/joiner-epoch-99-avg-1.onnx \ - --tokens $repo/data/lang_char/tokens.txt \ - $repo/test_wavs/DEV_T0000000000.wav \ - $repo/test_wavs/DEV_T0000000001.wav \ - $repo/test_wavs/DEV_T0000000002.wav -""" - -import argparse -import logging -import math -from typing import List, Tuple - -import k2 -import kaldifeat -import onnxruntime as ort -import torch -import torchaudio -from torch.nn.utils.rnn import pad_sequence - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--encoder-model-filename", - type=str, - required=True, - help="Path to the encoder onnx model. ", - ) - - parser.add_argument( - "--decoder-model-filename", - type=str, - required=True, - help="Path to the decoder onnx model. ", - ) - - parser.add_argument( - "--joiner-model-filename", - type=str, - required=True, - help="Path to the joiner onnx model. ", - ) - - parser.add_argument( - "--tokens", - type=str, - help="""Path to tokens.txt.""", - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - return parser - - -class OnnxModel: - def __init__( - self, - encoder_model_filename: str, - decoder_model_filename: str, - joiner_model_filename: str, - ): - session_opts = ort.SessionOptions() - session_opts.inter_op_num_threads = 1 - session_opts.intra_op_num_threads = 4 - - self.session_opts = session_opts - - self.init_encoder(encoder_model_filename) - self.init_decoder(decoder_model_filename) - self.init_joiner(joiner_model_filename) - - def init_encoder(self, encoder_model_filename: str): - self.encoder = ort.InferenceSession( - encoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - def init_decoder(self, decoder_model_filename: str): - self.decoder = ort.InferenceSession( - decoder_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - decoder_meta = self.decoder.get_modelmeta().custom_metadata_map - self.context_size = int(decoder_meta["context_size"]) - self.vocab_size = int(decoder_meta["vocab_size"]) - - logging.info(f"context_size: {self.context_size}") - logging.info(f"vocab_size: {self.vocab_size}") - - def init_joiner(self, joiner_model_filename: str): - self.joiner = ort.InferenceSession( - joiner_model_filename, - sess_options=self.session_opts, - providers=["CPUExecutionProvider"], - ) - - joiner_meta = self.joiner.get_modelmeta().custom_metadata_map - self.joiner_dim = int(joiner_meta["joiner_dim"]) - - logging.info(f"joiner_dim: {self.joiner_dim}") - - def run_encoder( - self, - x: torch.Tensor, - x_lens: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Args: - x: - A 3-D tensor of shape (N, T, C) - x_lens: - A 2-D tensor of shape (N,). Its dtype is torch.int64 - Returns: - Return a tuple containing: - - encoder_out, its shape is (N, T', joiner_dim) - - encoder_out_lens, its shape is (N,) - """ - out = self.encoder.run( - [ - self.encoder.get_outputs()[0].name, - self.encoder.get_outputs()[1].name, - ], - { - self.encoder.get_inputs()[0].name: x.numpy(), - self.encoder.get_inputs()[1].name: x_lens.numpy(), - }, - ) - return torch.from_numpy(out[0]), torch.from_numpy(out[1]) - - def run_decoder(self, decoder_input: torch.Tensor) -> torch.Tensor: - """ - Args: - decoder_input: - A 2-D tensor of shape (N, context_size) - Returns: - Return a 2-D tensor of shape (N, joiner_dim) - """ - out = self.decoder.run( - [self.decoder.get_outputs()[0].name], - {self.decoder.get_inputs()[0].name: decoder_input.numpy()}, - )[0] - - return torch.from_numpy(out) - - def run_joiner( - self, encoder_out: torch.Tensor, decoder_out: torch.Tensor - ) -> torch.Tensor: - """ - Args: - encoder_out: - A 2-D tensor of shape (N, joiner_dim) - decoder_out: - A 2-D tensor of shape (N, joiner_dim) - Returns: - Return a 2-D tensor of shape (N, vocab_size) - """ - out = self.joiner.run( - [self.joiner.get_outputs()[0].name], - { - self.joiner.get_inputs()[0].name: encoder_out.numpy(), - self.joiner.get_inputs()[1].name: decoder_out.numpy(), - }, - )[0] - - return torch.from_numpy(out) - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -def greedy_search( - model: OnnxModel, - encoder_out: torch.Tensor, - encoder_out_lens: torch.Tensor, -) -> List[List[int]]: - """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. - Args: - model: - The transducer model. - encoder_out: - A 3-D tensor of shape (N, T, joiner_dim) - encoder_out_lens: - A 1-D tensor of shape (N,). - Returns: - Return the decoded results for each utterance. - """ - assert encoder_out.ndim == 3, encoder_out.shape - assert encoder_out.size(0) >= 1, encoder_out.size(0) - - packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( - input=encoder_out, - lengths=encoder_out_lens.cpu(), - batch_first=True, - enforce_sorted=False, - ) - - blank_id = 0 # hard-code to 0 - - batch_size_list = packed_encoder_out.batch_sizes.tolist() - N = encoder_out.size(0) - - assert torch.all(encoder_out_lens > 0), encoder_out_lens - assert N == batch_size_list[0], (N, batch_size_list) - - context_size = model.context_size - hyps = [[blank_id] * context_size for _ in range(N)] - - decoder_input = torch.tensor( - hyps, - dtype=torch.int64, - ) # (N, context_size) - - decoder_out = model.run_decoder(decoder_input) - - offset = 0 - for batch_size in batch_size_list: - start = offset - end = offset + batch_size - current_encoder_out = packed_encoder_out.data[start:end] - # current_encoder_out's shape: (batch_size, joiner_dim) - offset = end - - decoder_out = decoder_out[:batch_size] - logits = model.run_joiner(current_encoder_out, decoder_out) - - # logits'shape (batch_size, vocab_size) - - assert logits.ndim == 2, logits.shape - y = logits.argmax(dim=1).tolist() - emitted = False - for i, v in enumerate(y): - if v != blank_id: - hyps[i].append(v) - emitted = True - if emitted: - # update decoder output - decoder_input = [h[-context_size:] for h in hyps[:batch_size]] - decoder_input = torch.tensor( - decoder_input, - dtype=torch.int64, - ) - decoder_out = model.run_decoder(decoder_input) - - sorted_ans = [h[context_size:] for h in hyps] - ans = [] - unsorted_indices = packed_encoder_out.unsorted_indices.tolist() - for i in range(N): - ans.append(sorted_ans[unsorted_indices[i]]) - - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - args = parser.parse_args() - logging.info(vars(args)) - model = OnnxModel( - encoder_model_filename=args.encoder_model_filename, - decoder_model_filename=args.decoder_model_filename, - joiner_model_filename=args.joiner_model_filename, - ) - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = "cpu" - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = args.sample_rate - opts.mel_opts.num_bins = 80 - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {args.sound_files}") - waves = read_sound_files( - filenames=args.sound_files, - expected_sample_rate=args.sample_rate, - ) - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence( - features, - batch_first=True, - padding_value=math.log(1e-10), - ) - - feature_lengths = torch.tensor(feature_lengths, dtype=torch.int64) - encoder_out, encoder_out_lens = model.run_encoder(features, feature_lengths) - - hyps = greedy_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - s = "\n" - - symbol_table = k2.SymbolTable.from_file(args.tokens) - - def token_ids_to_words(token_ids: List[int]) -> str: - text = "" - for i in token_ids: - text += symbol_table[i] - return text.replace("▁", " ").strip() - - for filename, hyp in zip(args.sound_files, hyps): - words = token_ids_to_words(hyp) - s += f"{filename}:\n{words}\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/optim.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/optim.py deleted file mode 100644 index f54bc2709..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/optim.py +++ /dev/null @@ -1,319 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey) -# -# See ../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from typing import List, Optional, Union - -import torch -from torch.optim import Optimizer - - -class Eve(Optimizer): - r""" - Implements Eve algorithm. This is a modified version of AdamW with a special - way of setting the weight-decay / shrinkage-factor, which is designed to make the - rms of the parameters approach a particular target_rms (default: 0.1). This is - for use with networks with 'scaled' versions of modules (see scaling.py), which - will be close to invariant to the absolute scale on the parameter matrix. - - The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. - The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. - Eve is unpublished so far. - - Arguments: - params (iterable): iterable of parameters to optimize or dicts defining - parameter groups - lr (float, optional): learning rate (default: 1e-3) - betas (Tuple[float, float], optional): coefficients used for computing - running averages of gradient and its square (default: (0.9, 0.999)) - eps (float, optional): term added to the denominator to improve - numerical stability (default: 1e-8) - weight_decay (float, optional): weight decay coefficient (default: 3e-4; - this value means that the weight would decay significantly after - about 3k minibatches. Is not multiplied by learning rate, but - is conditional on RMS-value of parameter being > target_rms. - target_rms (float, optional): target root-mean-square value of - parameters, if they fall below this we will stop applying weight decay. - - - .. _Adam\: A Method for Stochastic Optimization: - https://arxiv.org/abs/1412.6980 - .. _Decoupled Weight Decay Regularization: - https://arxiv.org/abs/1711.05101 - .. _On the Convergence of Adam and Beyond: - https://openreview.net/forum?id=ryQu7f-RZ - """ - - def __init__( - self, - params, - lr=1e-3, - betas=(0.9, 0.98), - eps=1e-8, - weight_decay=1e-3, - target_rms=0.1, - ): - if not 0.0 <= lr: - raise ValueError("Invalid learning rate: {}".format(lr)) - if not 0.0 <= eps: - raise ValueError("Invalid epsilon value: {}".format(eps)) - if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) - if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) - if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) - if not 0 < target_rms <= 10.0: - raise ValueError("Invalid target_rms value: {}".format(target_rms)) - defaults = dict( - lr=lr, - betas=betas, - eps=eps, - weight_decay=weight_decay, - target_rms=target_rms, - ) - super(Eve, self).__init__(params, defaults) - - def __setstate__(self, state): - super(Eve, self).__setstate__(state) - - @torch.no_grad() - def step(self, closure=None): - """Performs a single optimization step. - - Arguments: - closure (callable, optional): A closure that reevaluates the model - and returns the loss. - """ - loss = None - if closure is not None: - with torch.enable_grad(): - loss = closure() - - for group in self.param_groups: - for p in group["params"]: - if p.grad is None: - continue - - # Perform optimization step - grad = p.grad - if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") - - state = self.state[p] - - # State initialization - if len(state) == 0: - state["step"] = 0 - # Exponential moving average of gradient values - state["exp_avg"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - # Exponential moving average of squared gradient values - state["exp_avg_sq"] = torch.zeros_like( - p, memory_format=torch.preserve_format - ) - - exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] - - beta1, beta2 = group["betas"] - - state["step"] += 1 - bias_correction1 = 1 - beta1 ** state["step"] - bias_correction2 = 1 - beta2 ** state["step"] - - # Decay the first and second moment running average coefficient - exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( - group["eps"] - ) - - step_size = group["lr"] / bias_correction1 - target_rms = group["target_rms"] - weight_decay = group["weight_decay"] - - if p.numel() > 1: - # avoid applying this weight-decay on "scaling factors" - # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) - p.mul_(1 - (weight_decay * is_above_target_rms)) - p.addcdiv_(exp_avg, denom, value=-step_size) - - # Constrain the range of scalar weights - if p.numel() == 1: - p.clamp_(min=-10, max=2) - - return loss - - -class LRScheduler(object): - """ - Base-class for learning rate schedulers where the learning-rate depends on both the - batch and the epoch. - """ - - def __init__(self, optimizer: Optimizer, verbose: bool = False): - # Attach optimizer - if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) - self.optimizer = optimizer - self.verbose = verbose - - for group in optimizer.param_groups: - group.setdefault("initial_lr", group["lr"]) - - self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] - - self.epoch = 0 - self.batch = 0 - - def state_dict(self): - """Returns the state of the scheduler as a :class:`dict`. - - It contains an entry for every variable in self.__dict__ which - is not the optimizer. - """ - return { - "base_lrs": self.base_lrs, - "epoch": self.epoch, - "batch": self.batch, - } - - def load_state_dict(self, state_dict): - """Loads the schedulers state. - - Args: - state_dict (dict): scheduler state. Should be an object returned - from a call to :meth:`state_dict`. - """ - self.__dict__.update(state_dict) - - def get_last_lr(self) -> List[float]: - """Return last computed learning rate by current scheduler. Will be a list of float.""" - return self._last_lr - - def get_lr(self): - # Compute list of learning rates from self.epoch and self.batch and - # self.base_lrs; this must be overloaded by the user. - # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ] - raise NotImplementedError - - def step_batch(self, batch: Optional[int] = None) -> None: - # Step the batch index, or just set it. If `batch` is specified, it - # must be the batch index from the start of training, i.e. summed over - # all epochs. - # You can call this in any order; if you don't provide 'batch', it should - # of course be called once per batch. - if batch is not None: - self.batch = batch - else: - self.batch = self.batch + 1 - self._set_lrs() - - def step_epoch(self, epoch: Optional[int] = None): - # Step the epoch index, or just set it. If you provide the 'epoch' arg, - # you should call this at the start of the epoch; if you don't provide the 'epoch' - # arg, you should call it at the end of the epoch. - if epoch is not None: - self.epoch = epoch - else: - self.epoch = self.epoch + 1 - self._set_lrs() - - def _set_lrs(self): - values = self.get_lr() - assert len(values) == len(self.optimizer.param_groups) - - for i, data in enumerate(zip(self.optimizer.param_groups, values)): - param_group, lr = data - param_group["lr"] = lr - self.print_lr(self.verbose, i, lr) - self._last_lr = [group["lr"] for group in self.optimizer.param_groups] - - def print_lr(self, is_verbose, group, lr): - """Display the current learning rate.""" - if is_verbose: - print( - f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate" - f" of group {group} to {lr:.4e}." - ) - - -class Eden(LRScheduler): - """ - Eden scheduler. - lr = initial_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 * - (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) - - E.g. suggest initial-lr = 0.003 (passed to optimizer). - - Args: - optimizer: the optimizer to change the learning rates on - lr_batches: the number of batches after which we start significantly - decreasing the learning rate, suggest 5000. - lr_epochs: the number of epochs after which we start significantly - decreasing the learning rate, suggest 6 if you plan to do e.g. - 20 to 40 epochs, but may need smaller number if dataset is huge - and you will do few epochs. - """ - - def __init__( - self, - optimizer: Optimizer, - lr_batches: Union[int, float], - lr_epochs: Union[int, float], - verbose: bool = False, - ): - super(Eden, self).__init__(optimizer, verbose) - self.lr_batches = lr_batches - self.lr_epochs = lr_epochs - - def get_lr(self): - factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 - ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - return [x * factor for x in self.base_lrs] - - -def _test_eden(): - m = torch.nn.Linear(100, 100) - optim = Eve(m.parameters(), lr=0.003) - - scheduler = Eden(optim, lr_batches=30, lr_epochs=2, verbose=True) - - for epoch in range(10): - scheduler.step_epoch(epoch) # sets epoch to `epoch` - - for step in range(20): - x = torch.randn(200, 100).detach() - x.requires_grad = True - y = m(x) - dy = torch.randn(200, 100).detach() - f = (y * dy).sum() - f.backward() - - optim.step() - scheduler.step_batch() - optim.zero_grad() - print("last lr = ", scheduler.get_last_lr()) - print("state dict = ", scheduler.state_dict()) - - -if __name__ == "__main__": - _test_eden() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/pretrained.py deleted file mode 100755 index 07a470693..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/pretrained.py +++ /dev/null @@ -1,339 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# 2022 Xiaomi Crop. (authors: Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --decoding-method greedy_search \ - --max-sym-per-frame 1 \ - /path/to/foo.wav \ - /path/to/bar.wav -(2) modified beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --decoding-method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav -(3) fast beam search -./pruned_transducer_stateless2/pretrained.py \ - --checkpoint ./pruned_transducer_stateless/exp/pretrained.pt \ - --lang-dir ./data/lang_char \ - --decoding-method fast_beam_search \ - --beam 4 \ - --max-contexts 4 \ - --max-states 8 \ - /path/to/foo.wav \ - /path/to/bar.wav -You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`. -Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by -./pruned_transducer_stateless2/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from tokenizer import Tokenizer -from torch.nn.utils.rnn import pad_sequence -from train import get_params, get_transducer_model - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--lang-dir", - type=str, - help="""Path to lang. - """, - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""Used only when --decoding-method is beam_search - and modified_beam_search """, - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --decoding-method is greedy_search. - """, - ) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - Tokenizer.add_arguments(parser) - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = Tokenizer.load(params.lang_dir, params.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - else: - decoding_graph = None - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.decoding_method}" - if params.decoding_method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.decoding_method}") - - hyps.append(sp.decode(hyp).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = "".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/scaling.py deleted file mode 100644 index 91d64c1df..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/scaling.py +++ /dev/null @@ -1,1014 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey, Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import collections -import random -from itertools import repeat -from typing import Optional, Tuple - -import torch -import torch.backends.cudnn.rnn as rnn -import torch.nn as nn -from torch import _VF, Tensor - -from icefall.utils import is_jit_tracing - - -def _ntuple(n): - def parse(x): - if isinstance(x, collections.Iterable): - return x - return tuple(repeat(x, n)) - - return parse - - -_single = _ntuple(1) -_pair = _ntuple(2) - - -class ActivationBalancerFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - ) -> Tensor: - if x.requires_grad: - if channel_dim < 0: - channel_dim += x.ndim - - # sum_dims = [d for d in range(x.ndim) if d != channel_dim] - # The above line is not torch scriptable for torch 1.6.0 - # torch.jit.frontend.NotSupportedError: comprehension ifs not supported yet: # noqa - sum_dims = [] - for d in range(x.ndim): - if d != channel_dim: - sum_dims.append(d) - - xgt0 = x > 0 - proportion_positive = torch.mean( - xgt0.to(x.dtype), dim=sum_dims, keepdim=True - ) - factor1 = ( - (min_positive - proportion_positive).relu() - * (max_factor / min_positive) - if min_positive != 0.0 - else 0.0 - ) - factor2 = ( - (proportion_positive - max_positive).relu() - * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 - else 0.0 - ) - factor = factor1 + factor2 - if isinstance(factor, float): - factor = torch.zeros_like(proportion_positive) - - mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = mean_abs < min_abs - above_threshold = mean_abs > max_abs - - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) - ctx.max_factor = max_factor - ctx.sum_dims = sum_dims - return x - - @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None, None]: - factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors - dtype = x_grad.dtype - scale_factor = ( - (below_threshold.to(dtype) - above_threshold.to(dtype)) - * (xgt0.to(dtype) - 0.5) - * (ctx.max_factor * 2.0) - ) - - neg_delta_grad = x_grad.abs() * (factor + scale_factor) - return x_grad - neg_delta_grad, None, None, None, None, None, None - - -class GradientFilterFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x: Tensor, - batch_dim: int, # e.g., 1 - threshold: float, # e.g., 10.0 - *params: Tensor, # module parameters - ) -> Tuple[Tensor, ...]: - if x.requires_grad: - if batch_dim < 0: - batch_dim += x.ndim - ctx.batch_dim = batch_dim - ctx.threshold = threshold - return (x,) + params - - @staticmethod - def backward( - ctx, - x_grad: Tensor, - *param_grads: Tensor, - ) -> Tuple[Tensor, ...]: - eps = 1.0e-20 - dim = ctx.batch_dim - norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() - median_norm = norm_of_batch.median() - - cutoff = median_norm * ctx.threshold - inv_mask = (cutoff + norm_of_batch) / (cutoff + eps) - mask = 1.0 / (inv_mask + eps) - x_grad = x_grad * mask - - avg_mask = 1.0 / (inv_mask.mean() + eps) - param_grads = [avg_mask * g for g in param_grads] - - return (x_grad, None, None) + tuple(param_grads) - - -class GradientFilter(torch.nn.Module): - """This is used to filter out elements that have extremely large gradients - in batch and the module parameters with soft masks. - - Args: - batch_dim (int): - The batch dimension. - threshold (float): - For each element in batch, its gradient will be - filtered out if the gradient norm is larger than - `grad_norm_threshold * median`, where `median` is the median - value of gradient norms of all elememts in batch. - """ - - def __init__(self, batch_dim: int = 1, threshold: float = 10.0): - super(GradientFilter, self).__init__() - self.batch_dim = batch_dim - self.threshold = threshold - - def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor, ...]: - if torch.jit.is_scripting() or is_jit_tracing(): - return (x,) + params - else: - return GradientFilterFunction.apply( - x, - self.batch_dim, - self.threshold, - *params, - ) - - -class BasicNorm(torch.nn.Module): - """ - This is intended to be a simpler, and hopefully cheaper, replacement for - LayerNorm. The observation this is based on, is that Transformer-type - networks, especially with pre-norm, sometimes seem to set one of the - feature dimensions to a large constant value (e.g. 50), which "defeats" - the LayerNorm because the output magnitude is then not strongly dependent - on the other (useful) features. Presumably the weight and bias of the - LayerNorm are required to allow it to do this. - - So the idea is to introduce this large constant value as an explicit - parameter, that takes the role of the "eps" in LayerNorm, so the network - doesn't have to do this trick. We make the "eps" learnable. - - Args: - num_channels: the number of channels, e.g. 512. - channel_dim: the axis/dimension corresponding to the channel, - interprted as an offset from the input's ndim if negative. - shis is NOT the num_channels; it should typically be one of - {-2, -1, 0, 1, 2, 3}. - eps: the initial "epsilon" that we add as ballast in: - scale = ((input_vec**2).mean() + epsilon)**-0.5 - Note: our epsilon is actually large, but we keep the name - to indicate the connection with conventional LayerNorm. - learn_eps: if true, we learn epsilon; if false, we keep it - at the initial value. - """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - ) -> None: - super(BasicNorm, self).__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - if learn_eps: - self.eps = nn.Parameter(torch.tensor(eps).log().detach()) - else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) - - def forward(self, x: Tensor) -> Tensor: - if not is_jit_tracing(): - assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() - ) ** -0.5 - return x * scales - - -class ScaledLinear(nn.Linear): - """ - A modified version of nn.Linear where the parameters are scaled before - use, via: - weight = self.weight * self.weight_scale.exp() - bias = self.bias * self.bias_scale.exp() - - Args: - Accepts the standard args and kwargs that nn.Linear accepts - e.g. in_features, out_features, bias=False. - - initial_scale: you can override this if you want to increase - or decrease the initial magnitude of the module's output - (affects the initialization of weight_scale and bias_scale). - Another option, if you want to do something like this, is - to re-initialize the parameters. - initial_speed: this affects how fast the parameter will - learn near the start of training; you can set it to a - value less than one if you suspect that a module - is contributing to instability near the start of training. - Nnote: regardless of the use of this option, it's best to - use schedulers like Noam that have a warm-up period. - Alternatively you can set it to more than 1 if you want it to - initially train faster. Must be greater than 0. - """ - - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs, - ): - super(ScaledLinear, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter("bias_scale", None) - - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in nn.Linear - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3**0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += torch.tensor(scale / std).log() - - def get_weight(self): - return self.weight * self.weight_scale.exp() - - def get_bias(self): - if self.bias is None or self.bias_scale is None: - return None - else: - return self.bias * self.bias_scale.exp() - - def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) - - -class ScaledConv1d(nn.Conv1d): - # See docs for ScaledLinear - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs, - ): - super(ScaledConv1d, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - - self.bias_scale: Optional[nn.Parameter] # for torchscript - - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter("bias_scale", None) - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in base class - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3**0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += torch.tensor(scale / std).log() - - def get_weight(self): - return self.weight * self.weight_scale.exp() - - def get_bias(self): - bias = self.bias - bias_scale = self.bias_scale - if bias is None or bias_scale is None: - return None - else: - return bias * bias_scale.exp() - - def forward(self, input: Tensor) -> Tensor: - F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv1d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - self.get_weight(), - self.get_bias(), - self.stride, - (0,), - self.dilation, - self.groups, - ) - return F.conv1d( - input, - self.get_weight(), - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) - - -class ScaledConv2d(nn.Conv2d): - # See docs for ScaledLinear - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - **kwargs, - ): - super(ScaledConv2d, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self.weight_scale = nn.Parameter(initial_scale.clone().detach()) - if self.bias is not None: - self.bias_scale = nn.Parameter(initial_scale.clone().detach()) - else: - self.register_parameter("bias_scale", None) - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in base class - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3**0.5) * std - nn.init.uniform_(self.weight, -a, a) - if self.bias is not None: - nn.init.constant_(self.bias, 0.0) - fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) - with torch.no_grad(): - self.weight_scale += torch.tensor(scale / std).log() - - def get_weight(self): - return self.weight * self.weight_scale.exp() - - def get_bias(self): - # see https://github.com/pytorch/pytorch/issues/24135 - bias = self.bias - bias_scale = self.bias_scale - if bias is None or bias_scale is None: - return None - else: - return bias * bias_scale.exp() - - def _conv_forward(self, input, weight): - F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv2d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - weight, - self.get_bias(), - self.stride, - (0, 0), - self.dilation, - self.groups, - ) - return F.conv2d( - input, - weight, - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) - - def forward(self, input: Tensor) -> Tensor: - return self._conv_forward(input, self.get_weight()) - - -class ScaledLSTM(nn.LSTM): - # See docs for ScaledLinear. - # This class implements LSTM with scaling mechanism, using `torch._VF.lstm` - # Please refer to https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py - def __init__( - self, - *args, - initial_scale: float = 1.0, - initial_speed: float = 1.0, - grad_norm_threshold: float = 10.0, - **kwargs, - ): - if "bidirectional" in kwargs: - assert kwargs["bidirectional"] is False - super(ScaledLSTM, self).__init__(*args, **kwargs) - initial_scale = torch.tensor(initial_scale).log() - self._scales_names = [] - self._scales = [] - for name in self._flat_weights_names: - scale_name = name + "_scale" - self._scales_names.append(scale_name) - param = nn.Parameter(initial_scale.clone().detach()) - setattr(self, scale_name, param) - self._scales.append(param) - - self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) - - self._reset_parameters( - initial_speed - ) # Overrides the reset_parameters in base class - - def _reset_parameters(self, initial_speed: float): - std = 0.1 / initial_speed - a = (3**0.5) * std - scale = self.hidden_size**-0.5 - v = scale / std - for idx, name in enumerate(self._flat_weights_names): - if "weight" in name: - nn.init.uniform_(self._flat_weights[idx], -a, a) - with torch.no_grad(): - self._scales[idx] += torch.tensor(v).log() - elif "bias" in name: - nn.init.constant_(self._flat_weights[idx], 0.0) - - def _flatten_parameters(self, flat_weights) -> None: - """Resets parameter data pointer so that they can use faster code paths. - - Right now, this works only if the module is on the GPU and cuDNN is enabled. - Otherwise, it's a no-op. - - This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa - """ - # Short-circuits if _flat_weights is only partially instantiated - if len(flat_weights) != len(self._flat_weights_names): - return - - for w in flat_weights: - if not isinstance(w, Tensor): - return - # Short-circuits if any tensor in flat_weights is not acceptable to cuDNN - # or the tensors in flat_weights are of different dtypes - - first_fw = flat_weights[0] - dtype = first_fw.dtype - for fw in flat_weights: - if ( - not isinstance(fw.data, Tensor) - or not (fw.data.dtype == dtype) - or not fw.data.is_cuda - or not torch.backends.cudnn.is_acceptable(fw.data) - ): - return - - # If any parameters alias, we fall back to the slower, copying code path. This is - # a sufficient check, because overlapping parameter buffers that don't completely - # alias would break the assumptions of the uniqueness check in - # Module.named_parameters(). - unique_data_ptrs = set(p.data_ptr() for p in flat_weights) - if len(unique_data_ptrs) != len(flat_weights): - return - - with torch.cuda.device_of(first_fw): - # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is - # an inplace operation on self._flat_weights - with torch.no_grad(): - if torch._use_cudnn_rnn_flatten_weight(): - num_weights = 4 if self.bias else 2 - if self.proj_size > 0: - num_weights += 1 - torch._cudnn_rnn_flatten_weight( - flat_weights, - num_weights, - self.input_size, - rnn.get_cudnn_mode(self.mode), - self.hidden_size, - self.proj_size, - self.num_layers, - self.batch_first, - bool(self.bidirectional), - ) - - def _get_flat_weights(self): - """Get scaled weights, and resets their data pointer.""" - flat_weights = [] - for idx in range(len(self._flat_weights_names)): - flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) - self._flatten_parameters(flat_weights) - return flat_weights - - def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): - # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa - # The change for calling `_VF.lstm()` is: - # self._flat_weights -> self._get_flat_weights() - if hx is None: - h_zeros = torch.zeros( - self.num_layers, - input.size(1), - self.proj_size if self.proj_size > 0 else self.hidden_size, - dtype=input.dtype, - device=input.device, - ) - c_zeros = torch.zeros( - self.num_layers, - input.size(1), - self.hidden_size, - dtype=input.dtype, - device=input.device, - ) - hx = (h_zeros, c_zeros) - - self.check_forward_args(input, hx, None) - - flat_weights = self._get_flat_weights() - input, *flat_weights = self.grad_filter(input, *flat_weights) - - result = _VF.lstm( - input, - hx, - flat_weights, - self.bias, - self.num_layers, - self.dropout, - self.training, - self.bidirectional, - self.batch_first, - ) - - output = result[0] - hidden = result[1:] - return output, hidden - - -class ActivationBalancer(torch.nn.Module): - """ - Modifies the backpropped derivatives of a function to try to encourage, for - each channel, that it is positive at least a proportion `threshold` of the - time. It does this by multiplying negative derivative values by up to - (1+max_factor), and positive derivative values by up to (1-max_factor), - interpolated from 1 at the threshold to those extremal values when none - of the inputs are positive. - - - Args: - channel_dim: the dimension/axis corresponding to the channel, e.g. - -1, 0, 1, 2; will be interpreted as an offset from x.ndim if negative. - min_positive: the minimum, per channel, of the proportion of the time - that (x > 0), below which we start to modify the derivatives. - max_positive: the maximum, per channel, of the proportion of the time - that (x > 0), above which we start to modify the derivatives. - max_factor: the maximum factor by which we modify the derivatives for - either the sign constraint or the magnitude constraint; - e.g. with max_factor=0.02, the the derivatives would be multiplied by - values in the range [0.98..1.02]. - min_abs: the minimum average-absolute-value per channel, which - we allow, before we start to modify the derivatives to prevent - this. - max_abs: the maximum average-absolute-value per channel, which - we allow, before we start to modify the derivatives to prevent - this. - balance_prob: the probability to apply the ActivationBalancer. - """ - - def __init__( - self, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0, - balance_prob: float = 0.25, - ): - super(ActivationBalancer, self).__init__() - self.channel_dim = channel_dim - self.min_positive = min_positive - self.max_positive = max_positive - self.max_factor = max_factor - self.min_abs = min_abs - self.max_abs = max_abs - assert 0 < balance_prob <= 1, balance_prob - self.balance_prob = balance_prob - - def forward(self, x: Tensor) -> Tensor: - if random.random() >= self.balance_prob: - return x - - return ActivationBalancerFunction.apply( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - self.max_factor / self.balance_prob, - self.min_abs, - self.max_abs, - ) - - -class DoubleSwishFunction(torch.autograd.Function): - """ - double_swish(x) = x * torch.sigmoid(x-1) - This is a definition, originally motivated by its close numerical - similarity to swish(swish(x)), where swish(x) = x * sigmoid(x). - - Memory-efficient derivative computation: - double_swish(x) = x * s, where s(x) = torch.sigmoid(x-1) - double_swish'(x) = d/dx double_swish(x) = x * s'(x) + x' * s(x) = x * s'(x) + s(x). - Now, s'(x) = s(x) * (1-s(x)). - double_swish'(x) = x * s'(x) + s(x). - = x * s(x) * (1-s(x)) + s(x). - = double_swish(x) * (1-s(x)) + s(x) - ... so we just need to remember s(x) but not x itself. - """ - - @staticmethod - def forward(ctx, x: Tensor) -> Tensor: - x = x.detach() - s = torch.sigmoid(x - 1.0) - y = x * s - ctx.save_for_backward(s, y) - return y - - @staticmethod - def backward(ctx, y_grad: Tensor) -> Tensor: - s, y = ctx.saved_tensors - return (y * (1 - s) + s) * y_grad - - -class DoubleSwish(torch.nn.Module): - def forward(self, x: Tensor) -> Tensor: - """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). - """ - if torch.jit.is_scripting() or is_jit_tracing(): - return x * torch.sigmoid(x - 1.0) - else: - return DoubleSwishFunction.apply(x) - - -class ScaledEmbedding(nn.Module): - r"""This is a modified version of nn.Embedding that introduces a learnable scale - on the parameters. Note: due to how we initialize it, it's best used with - schedulers like Noam that have a warmup period. - - It is a simple lookup table that stores embeddings of a fixed dictionary and size. - - This module is often used to store word embeddings and retrieve them using indices. - The input to the module is a list of indices, and the output is the corresponding - word embeddings. - - Args: - num_embeddings (int): size of the dictionary of embeddings - embedding_dim (int): the size of each embedding vector - padding_idx (int, optional): If given, pads the output with the embedding vector at :attr:`padding_idx` - (initialized to zeros) whenever it encounters the index. - scale_grad_by_freq (boolean, optional): If given, this will scale gradients by the inverse of frequency of - the words in the mini-batch. Default ``False``. - sparse (bool, optional): If ``True``, gradient w.r.t. :attr:`weight` matrix will be a sparse tensor. - See Notes for more details regarding sparse gradients. - - initial_speed (float, optional): This affects how fast the parameter will - learn near the start of training; you can set it to a value less than - one if you suspect that a module is contributing to instability near - the start of training. Note: regardless of the use of this option, - it's best to use schedulers like Noam that have a warm-up period. - Alternatively you can set it to more than 1 if you want it to - initially train faster. Must be greater than 0. - - - Attributes: - weight (Tensor): the learnable weights of the module of shape (num_embeddings, embedding_dim) - initialized from :math:`\mathcal{N}(0, 1)` - - Shape: - - Input: :math:`(*)`, LongTensor of arbitrary shape containing the indices to extract - - Output: :math:`(*, H)`, where `*` is the input shape and :math:`H=\text{embedding\_dim}` - - .. note:: - Keep in mind that only a limited number of optimizers support - sparse gradients: currently it's :class:`optim.SGD` (`CUDA` and `CPU`), - :class:`optim.SparseAdam` (`CUDA` and `CPU`) and :class:`optim.Adagrad` (`CPU`) - - .. note:: - With :attr:`padding_idx` set, the embedding vector at - :attr:`padding_idx` is initialized to all zeros. However, note that this - vector can be modified afterwards, e.g., using a customized - initialization method, and thus changing the vector used to pad the - output. The gradient for this vector from :class:`~torch.nn.Embedding` - is always zero. - - Examples:: - - >>> # an Embedding module containing 10 tensors of size 3 - >>> embedding = nn.Embedding(10, 3) - >>> # a batch of 2 samples of 4 indices each - >>> input = torch.LongTensor([[1,2,4,5],[4,3,2,9]]) - >>> embedding(input) - tensor([[[-0.0251, -1.6902, 0.7172], - [-0.6431, 0.0748, 0.6969], - [ 1.4970, 1.3448, -0.9685], - [-0.3677, -2.7265, -0.1685]], - - [[ 1.4970, 1.3448, -0.9685], - [ 0.4362, -0.4004, 0.9400], - [-0.6431, 0.0748, 0.6969], - [ 0.9124, -2.3616, 1.1151]]]) - - - >>> # example with padding_idx - >>> embedding = nn.Embedding(10, 3, padding_idx=0) - >>> input = torch.LongTensor([[0,2,0,5]]) - >>> embedding(input) - tensor([[[ 0.0000, 0.0000, 0.0000], - [ 0.1535, -2.0309, 0.9315], - [ 0.0000, 0.0000, 0.0000], - [-0.1655, 0.9897, 0.0635]]]) - - """ - __constants__ = [ - "num_embeddings", - "embedding_dim", - "padding_idx", - "scale_grad_by_freq", - "sparse", - ] - - num_embeddings: int - embedding_dim: int - padding_idx: int - scale_grad_by_freq: bool - weight: Tensor - sparse: bool - - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - initial_speed: float = 1.0, - ) -> None: - super(ScaledEmbedding, self).__init__() - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if padding_idx is not None: - if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" - elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" - padding_idx = self.num_embeddings + padding_idx - self.padding_idx = padding_idx - self.scale_grad_by_freq = scale_grad_by_freq - - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() - self.sparse = sparse - - self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) - self.reset_parameters(initial_speed) - - def reset_parameters(self, initial_speed: float = 1.0) -> None: - std = 0.1 / initial_speed - nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) - - if self.padding_idx is not None: - with torch.no_grad(): - self.weight[self.padding_idx].fill_(0) - - def forward(self, input: Tensor) -> Tensor: - F = torch.nn.functional - scale = self.scale.exp() - if input.numel() < self.num_embeddings: - return ( - F.embedding( - input, - self.weight, - self.padding_idx, - None, - 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, - self.sparse, - ) - * scale - ) - else: - return F.embedding( - input, - self.weight * scale, - self.padding_idx, - None, - 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, - self.sparse, - ) - - def extra_repr(self) -> str: - # s = "{num_embeddings}, {embedding_dim}, scale={scale}" - s = "{num_embeddings}, {embedding_dim}" - if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" - if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" - if self.sparse is not False: - s += ", sparse=True" - return s.format(**self.__dict__) - - -def _test_activation_balancer_sign(): - probs = torch.arange(0, 1, 0.01) - N = 1000 - x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - max_factor=0.2, - min_abs=0.0, - ) - - y_grad = torch.sign(torch.randn(probs.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_sign: x = ", x) - print("_test_activation_balancer_sign: y grad = ", y_grad) - print("_test_activation_balancer_sign: x grad = ", x.grad) - - -def _test_activation_balancer_magnitude(): - magnitudes = torch.arange(0, 1, 0.01) - N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) - x = x.detach() - x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - max_factor=0.2, - min_abs=0.2, - max_abs=0.8, - ) - - y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) - - y = m(x) - y.backward(gradient=y_grad) - print("_test_activation_balancer_magnitude: x = ", x) - print("_test_activation_balancer_magnitude: y grad = ", y_grad) - print("_test_activation_balancer_magnitude: x grad = ", x.grad) - - -def _test_basic_norm(): - num_channels = 128 - m = BasicNorm(num_channels=num_channels, channel_dim=1) - - x = torch.randn(500, num_channels) - - y = m(x) - - assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() - print("x rms = ", x_rms) - print("y rms = ", y_rms) - assert y_rms < x_rms - assert y_rms > 0.5 * x_rms - - -def _test_double_swish_deriv(): - x = torch.randn(10, 12, dtype=torch.double) * 0.5 - x.requires_grad = True - m = DoubleSwish() - torch.autograd.gradcheck(m, x) - - -def _test_scaled_lstm(): - N, L = 2, 30 - dim_in, dim_hidden = 10, 20 - m = ScaledLSTM(input_size=dim_in, hidden_size=dim_hidden, bias=True) - x = torch.randn(L, N, dim_in) - h0 = torch.randn(1, N, dim_hidden) - c0 = torch.randn(1, N, dim_hidden) - y, (h, c) = m(x, (h0, c0)) - assert y.shape == (L, N, dim_hidden) - assert h.shape == (1, N, dim_hidden) - assert c.shape == (1, N, dim_hidden) - - -def _test_grad_filter(): - threshold = 50.0 - time, batch, channel = 200, 5, 128 - grad_filter = GradientFilter(batch_dim=1, threshold=threshold) - - for i in range(2): - x = torch.randn(time, batch, channel, requires_grad=True) - w = nn.Parameter(torch.ones(5)) - b = nn.Parameter(torch.zeros(5)) - - x_out, w_out, b_out = grad_filter(x, w, b) - - w_out_grad = torch.randn_like(w) - b_out_grad = torch.randn_like(b) - x_out_grad = torch.rand_like(x) - if i % 2 == 1: - # The gradient norm of the first element must be larger than - # `threshold * median`, where `median` is the median value - # of gradient norms of all elements in batch. - x_out_grad[:, 0, :] = torch.full((time, channel), threshold) - - torch.autograd.backward( - [x_out, w_out, b_out], [x_out_grad, w_out_grad, b_out_grad] - ) - - print( - "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa - i % 2 == 1, - ) - - print( - "_test_grad_filter: x_out_grad norm = ", - (x_out_grad**2).mean(dim=(0, 2)).sqrt(), - ) - print( - "_test_grad_filter: x.grad norm = ", - (x.grad**2).mean(dim=(0, 2)).sqrt(), - ) - print("_test_grad_filter: w_out_grad = ", w_out_grad) - print("_test_grad_filter: w.grad = ", w.grad) - print("_test_grad_filter: b_out_grad = ", b_out_grad) - print("_test_grad_filter: b.grad = ", b.grad) - - -if __name__ == "__main__": - _test_activation_balancer_sign() - _test_activation_balancer_magnitude() - _test_basic_norm() - _test_double_swish_deriv() - _test_scaled_lstm() - _test_grad_filter() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/scaling_converter.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/scaling_converter.py deleted file mode 100644 index a6540c584..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/scaling_converter.py +++ /dev/null @@ -1,320 +0,0 @@ -# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -This file provides functions to convert `ScaledLinear`, `ScaledConv1d`, -`ScaledConv2d`, and `ScaledEmbedding` to their non-scaled counterparts: -`nn.Linear`, `nn.Conv1d`, `nn.Conv2d`, and `nn.Embedding`. - -The scaled version are required only in the training time. It simplifies our -life by converting them to their non-scaled version during inference. -""" - -import copy -import re -from typing import List - -import torch -import torch.nn as nn -from lstmp import LSTMP -from scaling import ( - ActivationBalancer, - BasicNorm, - ScaledConv1d, - ScaledConv2d, - ScaledEmbedding, - ScaledLinear, - ScaledLSTM, -) - - -class NonScaledNorm(nn.Module): - """See BasicNorm for doc""" - - def __init__( - self, - num_channels: int, - eps_exp: float, - channel_dim: int = -1, # CAUTION: see documentation. - ): - super().__init__() - self.num_channels = num_channels - self.channel_dim = channel_dim - self.eps_exp = eps_exp - - def forward(self, x: torch.Tensor) -> torch.Tensor: - if not torch.jit.is_tracing(): - assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x * x, dim=self.channel_dim, keepdim=True) + self.eps_exp - ).pow(-0.5) - return x * scales - - -def scaled_linear_to_linear(scaled_linear: ScaledLinear) -> nn.Linear: - """Convert an instance of ScaledLinear to nn.Linear. - - Args: - scaled_linear: - The layer to be converted. - Returns: - Return a linear layer. It satisfies: - - scaled_linear(x) == linear(x) - - for any given input tensor `x`. - """ - assert isinstance(scaled_linear, ScaledLinear), type(scaled_linear) - - weight = scaled_linear.get_weight() - bias = scaled_linear.get_bias() - has_bias = bias is not None - - linear = torch.nn.Linear( - in_features=scaled_linear.in_features, - out_features=scaled_linear.out_features, - bias=True, # otherwise, it throws errors when converting to PNNX format - # device=weight.device, # Pytorch version before v1.9.0 does not have - # this argument. Comment out for now, we will - # see if it will raise error for versions - # after v1.9.0 - ) - linear.weight.data.copy_(weight) - - if has_bias: - linear.bias.data.copy_(bias) - else: - linear.bias.data.zero_() - - return linear - - -def scaled_conv1d_to_conv1d(scaled_conv1d: ScaledConv1d) -> nn.Conv1d: - """Convert an instance of ScaledConv1d to nn.Conv1d. - - Args: - scaled_conv1d: - The layer to be converted. - Returns: - Return an instance of nn.Conv1d that has the same `forward()` behavior - of the given `scaled_conv1d`. - """ - assert isinstance(scaled_conv1d, ScaledConv1d), type(scaled_conv1d) - - weight = scaled_conv1d.get_weight() - bias = scaled_conv1d.get_bias() - has_bias = bias is not None - - conv1d = nn.Conv1d( - in_channels=scaled_conv1d.in_channels, - out_channels=scaled_conv1d.out_channels, - kernel_size=scaled_conv1d.kernel_size, - stride=scaled_conv1d.stride, - padding=scaled_conv1d.padding, - dilation=scaled_conv1d.dilation, - groups=scaled_conv1d.groups, - bias=scaled_conv1d.bias is not None, - padding_mode=scaled_conv1d.padding_mode, - ) - - conv1d.weight.data.copy_(weight) - if has_bias: - conv1d.bias.data.copy_(bias) - - return conv1d - - -def scaled_conv2d_to_conv2d(scaled_conv2d: ScaledConv2d) -> nn.Conv2d: - """Convert an instance of ScaledConv2d to nn.Conv2d. - - Args: - scaled_conv2d: - The layer to be converted. - Returns: - Return an instance of nn.Conv2d that has the same `forward()` behavior - of the given `scaled_conv2d`. - """ - assert isinstance(scaled_conv2d, ScaledConv2d), type(scaled_conv2d) - - weight = scaled_conv2d.get_weight() - bias = scaled_conv2d.get_bias() - has_bias = bias is not None - - conv2d = nn.Conv2d( - in_channels=scaled_conv2d.in_channels, - out_channels=scaled_conv2d.out_channels, - kernel_size=scaled_conv2d.kernel_size, - stride=scaled_conv2d.stride, - padding=scaled_conv2d.padding, - dilation=scaled_conv2d.dilation, - groups=scaled_conv2d.groups, - bias=scaled_conv2d.bias is not None, - padding_mode=scaled_conv2d.padding_mode, - ) - - conv2d.weight.data.copy_(weight) - if has_bias: - conv2d.bias.data.copy_(bias) - - return conv2d - - -def scaled_embedding_to_embedding( - scaled_embedding: ScaledEmbedding, -) -> nn.Embedding: - """Convert an instance of ScaledEmbedding to nn.Embedding. - - Args: - scaled_embedding: - The layer to be converted. - Returns: - Return an instance of nn.Embedding that has the same `forward()` behavior - of the given `scaled_embedding`. - """ - assert isinstance(scaled_embedding, ScaledEmbedding), type(scaled_embedding) - embedding = nn.Embedding( - num_embeddings=scaled_embedding.num_embeddings, - embedding_dim=scaled_embedding.embedding_dim, - padding_idx=scaled_embedding.padding_idx, - scale_grad_by_freq=scaled_embedding.scale_grad_by_freq, - sparse=scaled_embedding.sparse, - ) - weight = scaled_embedding.weight - scale = scaled_embedding.scale - - embedding.weight.data.copy_(weight * scale.exp()) - - return embedding - - -def convert_basic_norm(basic_norm: BasicNorm) -> NonScaledNorm: - assert isinstance(basic_norm, BasicNorm), type(BasicNorm) - norm = NonScaledNorm( - num_channels=basic_norm.num_channels, - eps_exp=basic_norm.eps.data.exp().item(), - channel_dim=basic_norm.channel_dim, - ) - return norm - - -def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: - """Convert an instance of ScaledLSTM to nn.LSTM. - - Args: - scaled_lstm: - The layer to be converted. - Returns: - Return an instance of nn.LSTM that has the same `forward()` behavior - of the given `scaled_lstm`. - """ - assert isinstance(scaled_lstm, ScaledLSTM), type(scaled_lstm) - lstm = nn.LSTM( - input_size=scaled_lstm.input_size, - hidden_size=scaled_lstm.hidden_size, - num_layers=scaled_lstm.num_layers, - bias=scaled_lstm.bias, - batch_first=scaled_lstm.batch_first, - dropout=scaled_lstm.dropout, - bidirectional=scaled_lstm.bidirectional, - proj_size=scaled_lstm.proj_size, - ) - - assert lstm._flat_weights_names == scaled_lstm._flat_weights_names - for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() - lstm._flat_weights[idx].data.copy_(scaled_weight) - - return lstm - - -# Copied from https://pytorch.org/docs/1.9.0/_modules/torch/nn/modules/module.html#Module.get_submodule # noqa -# get_submodule was added to nn.Module at v1.9.0 -def get_submodule(model, target): - if target == "": - return model - atoms: List[str] = target.split(".") - mod: torch.nn.Module = model - for item in atoms: - if not hasattr(mod, item): - raise AttributeError( - mod._get_name() + " has no " "attribute `" + item + "`" - ) - mod = getattr(mod, item) - if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not " "an nn.Module") - return mod - - -def convert_scaled_to_non_scaled( - model: nn.Module, - inplace: bool = False, - is_onnx: bool = False, -): - """Convert `ScaledLinear`, `ScaledConv1d`, and `ScaledConv2d` - in the given modle to their unscaled version `nn.Linear`, `nn.Conv1d`, - and `nn.Conv2d`. - - Args: - model: - The model to be converted. - inplace: - If True, the input model is modified inplace. - If False, the input model is copied and we modify the copied version. - is_onnx: - If True, we are going to export the model to ONNX. In this case, - we will convert nn.LSTM with proj_size to LSTMP. - Return: - Return a model without scaled layers. - """ - if not inplace: - model = copy.deepcopy(model) - - excluded_patterns = r"(self|src)_attn\.(in|out)_proj" - p = re.compile(excluded_patterns) - - d = {} - for name, m in model.named_modules(): - if isinstance(m, ScaledLinear): - if p.search(name) is not None: - continue - d[name] = scaled_linear_to_linear(m) - elif isinstance(m, ScaledConv1d): - d[name] = scaled_conv1d_to_conv1d(m) - elif isinstance(m, ScaledConv2d): - d[name] = scaled_conv2d_to_conv2d(m) - elif isinstance(m, ScaledEmbedding): - d[name] = scaled_embedding_to_embedding(m) - elif isinstance(m, BasicNorm): - d[name] = convert_basic_norm(m) - elif isinstance(m, ScaledLSTM): - if is_onnx: - d[name] = LSTMP(scaled_lstm_to_lstm(m)) - # See - # https://github.com/pytorch/pytorch/issues/47887 - # d[name] = torch.jit.script(LSTMP(scaled_lstm_to_lstm(m))) - else: - d[name] = scaled_lstm_to_lstm(m) - elif isinstance(m, ActivationBalancer): - d[name] = nn.Identity() - - for k, v in d.items(): - if "." in k: - parent, child = k.rsplit(".", maxsplit=1) - setattr(get_submodule(model, parent), child, v) - else: - setattr(model, k, v) - - return model diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py deleted file mode 120000 index 6dcf53ac2..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -/var/data/share20/qc/k2/Github/icefall/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/streaming_decode.py deleted file mode 120000 index bf5d78edd..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ /dev/null @@ -1 +0,0 @@ -/var/data/share20/qc/k2/Github/icefall/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/test_model.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/test_model.py deleted file mode 120000 index 8d25a9eea..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/test_model.py +++ /dev/null @@ -1 +0,0 @@ -/var/data/share20/qc/k2/Github/icefall/egs/librispeech/ASR/pruned_transducer_stateless2/test_model.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/tokenizer.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/tokenizer.py deleted file mode 120000 index 958c99e85..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/tokenizer.py +++ /dev/null @@ -1 +0,0 @@ -../local/utils/tokenizer.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless2/train.py b/egs/reazonspeech/ASR/pruned_transducer_stateless2/train.py deleted file mode 100755 index 4d0776dde..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless2/train.py +++ /dev/null @@ -1,1061 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang -# Mingshuang Luo) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -For training with the L subset: - -export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" - -./pruned_transducer_stateless2/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless2/exp \ - --world-size 8 \ - --num-epochs 15 \ - --start-epoch 0 \ - --max-duration 180 \ - --valid-interval 3000 \ - --model-warm-step 3000 \ - --save-every-n 8000 \ - --training-subset L - -# For mix precision training: - -./pruned_transducer_stateless2/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless2/exp \ - --world-size 8 \ - --num-epochs 10 \ - --start-epoch 0 \ - --max-duration 180 \ - --valid-interval 3000 \ - --model-warm-step 3000 \ - --save-every-n 8000 \ - --use-fp16 True \ - --training-subset L - -For training with the M subset: - -./pruned_transducer_stateless2/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless2/exp \ - --world-size 8 \ - --num-epochs 29 \ - --start-epoch 0 \ - --max-duration 180 \ - --valid-interval 1000 \ - --model-warm-step 500 \ - --save-every-n 1000 \ - --training-subset M - -For training with the S subset: - -./pruned_transducer_stateless2/train.py \ - --lang-dir data/lang_char \ - --exp-dir pruned_transducer_stateless2/exp \ - --world-size 8 \ - --num-epochs 29 \ - --start-epoch 0 \ - --max-duration 180 \ - --valid-interval 400 \ - --model-warm-step 100 \ - --save-every-n 1000 \ - --training-subset S -""" - -import argparse -import logging -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule -from conformer import Conformer -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, Eve -from tokenizer import Tokenizer -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter - -from icefall import diagnostics -# from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import save_checkpoint_with_global_batch_idx -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.lexicon import Lexicon -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=0, - help="""Resume training from from this epoch. - If it is positive, it will load checkpoint from - pruned_transducer_stateless2/exp/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--lang-dir", - type=str, - default="data/lang_char", - help="""The lang dir - It contains language related input files such as - "lexicon.txt" - """, - ) - - parser.add_argument( - "--initial-lr", - type=float, - default=0.003, - help="The initial learning rate. This value should not need to be changed.", - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate decreases. - We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=6, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=8000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=20, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--valid-interval", - type=int, - default=3000, - help="""When training_subset is L, set the valid_interval to 3000. - When training_subset is M, set the valid_interval to 1000. - When training_subset is S, set the valid_interval to 400. - """, - ) - - parser.add_argument( - "--model-warm-step", - type=int, - default=3000, - help="""When training_subset is L, set the model_warm_step to 3000. - When training_subset is M, set the model_warm_step to 500. - When training_subset is S, set the model_warm_step to 100. - """, - ) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - Explanation of options saved in `params`: - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - best_train_epoch: It is the epoch that has the best training loss. - - best_valid_epoch: It is the epoch that has the best validation loss. - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - feature_dim: The model input dim. It has to match the one used - in computing features. - - subsampling_factor: The subsampling factor for the model. - - encoder_dim: Hidden dim for multi-head attention model. - - num_decoder_layers: Number of decoder layer of transformer decoder. - - warm_step: The warm_step for Noam optimizer. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - # parameters for conformer - "feature_dim": 80, - "subsampling_factor": 4, - "encoder_dim": 512, - "nhead": 8, - "dim_feedforward": 2048, - "num_encoder_layers": 12, - # parameters for decoder - "decoder_dim": 512, - # parameters for joiner - "joiner_dim": 512, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Conformer and Transformer - encoder = Conformer( - num_features=params.feature_dim, - subsampling_factor=params.subsampling_factor, - d_model=params.encoder_dim, - nhead=params.nhead, - dim_feedforward=params.dim_feedforward, - num_encoder_layers=params.num_encoder_layers, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=params.encoder_dim, - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is positive, it will load the checkpoint from - `params.start_epoch - 1`. - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 0: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: nn.Module, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - batch: dict, - is_training: bool, - warmup: float = 1.0, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute RNN-T loss given the model and its inputs. - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Conformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - warmup=warmup, - ) - # after the main warmup step, we keep pruned_loss_scale small - # for the same amount of time (model_warm_step), to avoid - # overwhelming the simple_loss and causing it to diverge, - # in case it had not fully learned the alignment yet. - pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) - ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: nn.Module, - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: Tokenizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=(params.batch_idx_train / params.model_warm_step), - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - scheduler.step_batch(params.batch_idx_train) - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except: # noqa - display_and_save_batch(batch, params=params) - raise - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}" - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - - if batch_idx > 0 and batch_idx % params.valid_interval == 0: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = Tokenizer.load(args.lang_dir, args.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoints = load_checkpoint_if_available(params=params, model=model) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank]) - model.device = device - - optimizer = Eve(model.parameters(), lr=params.initial_lr) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - reazonspeech = ReazonSpeechAsrDataModule(args) - - train_cuts = reazonspeech.train_cuts() - valid_cuts = reazonspeech.valid_cuts() - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 25 seconds - # - # Caution: There is a reason to select 30.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 1.0 or c.duration > 30.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./conformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 1) // 2 - 1) // 2 - tokens = c.supervisions[0].text.replace(" ", "") - - if T < len(tokens): - logging.warning( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - valid_dl = reazonspeech.valid_dataloaders(valid_cuts) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = reazonspeech.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - if not params.print_diagnostics and params.start_batch == 0: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs): - scheduler.step_epoch(epoch) - fix_random_seed(params.seed + epoch) - train_dl.sampler.set_epoch(epoch) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - texts = batch["supervisions"]["text"] - num_tokens = sum(len(i) for i in texts) - - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: nn.Module, - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: Tokenizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 0 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - # warmup = 0.0 is so that the derivs for the pruned loss stay zero - # (i.e. are not remembered by the decaying-average in adam), because - # we want to avoid these params being subject to shrinkage in adam. - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - warmup=0.0, - ) - loss.backward() - optimizer.step() - optimizer.zero_grad() - except RuntimeError as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params) - raise - - -def main(): - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.lang_dir = Path(args.lang_dir) - args.exp_dir = Path(args.exp_dir) - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py deleted file mode 120000 index a48591198..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/asr_datamodule.py +++ /dev/null @@ -1 +0,0 @@ -../local/utils/asr_datamodule.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py deleted file mode 120000 index d7349b0a3..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/beam_search.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/decode.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/decode.py deleted file mode 100755 index 446890daa..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/decode.py +++ /dev/null @@ -1,846 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: -(1) greedy search -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --lang data/lang_char \ - --decoding-method greedy_search - -(2) beam search (not recommended) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method beam_search \ - --lang data/lang_char \ - --beam-size 4 - -(3) modified beam search -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method modified_beam_search \ - --lang data/lang_char \ - --beam-size 4 - -(4) fast beam search (one best) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search \ - --beam 20.0 \ - --max-contexts 8 \ - --lang data/lang_char \ - --max-states 64 - -(5) fast beam search (nbest) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search_nbest \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --lang data/lang_char \ - --nbest-scale 0.5 - -(6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search_nbest_oracle \ - --beam 20.0 \ - --max-contexts 8 \ - --max-states 64 \ - --num-paths 200 \ - --lang data/lang_char \ - --nbest-scale 0.5 - -(7) fast beam search (with LG) -./pruned_transducer_stateless7_streaming/decode.py \ - --epoch 28 \ - --avg 15 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --max-duration 600 \ - --decode-chunk-len 32 \ - --decoding-method fast_beam_search_nbest_LG \ - --beam 20.0 \ - --max-contexts 8 \ - --lang data/lang_char \ - --max-states 64 -""" - - -import argparse -import logging -import math -from collections import defaultdict -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import torch -import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule -from beam_search import ( - beam_search, - fast_beam_search_nbest, - fast_beam_search_nbest_LG, - fast_beam_search_nbest_oracle, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from tokenizer import Tokenizer -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.lexicon import Lexicon -from icefall.utils import ( - AttributeDict, - setup_logger, - store_transcripts, - str2bool, - write_error_stats, -) - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--gpu", - type=int, - default=0, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--res-dir", - type=Path, - default=None, - help="The path to save results.", - ) - - parser.add_argument( - "--lang-dir", - type=Path, - default="data/lang_char", - help="The lang dir. It should contain at least a word table.", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - - fast_beam_search_nbest - - fast_beam_search_nbest_oracle - - fast_beam_search_nbest_LG - If you use fast_beam_search_nbest_LG, you have to specify - `--lang-dir`, which should contain `LG.pt`. - """, - ) - - parser.add_argument( - "--decoding-graph", - type=str, - default="", - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --decoding-method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=20.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search, - fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle - """, - ) - - parser.add_argument( - "--ngram-lm-scale", - type=float, - default=0.01, - help=""" - Used only when --decoding_method is fast_beam_search_nbest_LG. - It specifies the scale for n-gram LM scores. - """, - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=8, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=64, - help="""Used only when --decoding-method is - fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, - and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. - Used only when --decoding_method is greedy_search""", - ) - - parser.add_argument( - "--num-paths", - type=int, - default=200, - help="""Number of paths for nbest decoding. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--nbest-scale", - type=float, - default=0.5, - help="""Scale applied to lattice scores when computing nbest paths. - Used only when the decoding method is fast_beam_search_nbest, - fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", - ) - - parser.add_argument( - "--pad-feature", - type=int, - default=30, - help=""" - Number of frames to pad at the end. - """, - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_batch( - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - batch: dict, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[List[str]]]: - """Decode one batch and return the result in a dict. The dict has the - following format: - - - key: It indicates the setting used for decoding. For example, - if greedy_search is used, it would be "greedy_search" - If beam search with a beam size of 7 is used, it would be - "beam_7" - - value: It contains the decoding result. `len(value)` equals to - batch size. `value[i]` is the decoding result for the i-th - utterance in the given batch. - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - batch: - It is the return value from iterating - `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation - for the format of the `batch`. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return the decoding result. See above description for the format of - the returned dict. - """ - device = next(model.parameters()).device - feature = batch["inputs"] - assert feature.ndim == 3 - - feature = feature.to(device) - # at entry, feature is (N, T, C) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.pad_feature: - feature_lens += params.pad_feature - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.pad_feature), - value=LOG_EPS, - ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) - - hyps = [] - - if params.decoding_method == "fast_beam_search": - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "fast_beam_search_nbest_LG": - hyp_tokens = fast_beam_search_nbest_LG( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in hyp_tokens: - hyps.append([word_table[i] for i in hyp]) - elif params.decoding_method == "fast_beam_search_nbest": - hyp_tokens = fast_beam_search_nbest( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "fast_beam_search_nbest_oracle": - hyp_tokens = fast_beam_search_nbest_oracle( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - num_paths=params.num_paths, - ref_texts=sp.encode(supervisions["text"]), - nbest_scale=params.nbest_scale, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - elif params.decoding_method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(sp.text2word(hyp)) - else: - batch_size = encoder_out.size(0) - - for i in range(batch_size): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.decoding_method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.decoding_method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError( - f"Unsupported decoding method: {params.decoding_method}" - ) - hyps.append(sp.text2word(sp.decode(hyp))) - - if params.decoding_method == "greedy_search": - return {"greedy_search": hyps} - elif "fast_beam_search" in params.decoding_method: - key = f"beam_{params.beam}_" - key += f"max_contexts_{params.max_contexts}_" - key += f"max_states_{params.max_states}" - if "nbest" in params.decoding_method: - key += f"_num_paths_{params.num_paths}_" - key += f"nbest_scale_{params.nbest_scale}" - if "LG" in params.decoding_method: - key += f"_ngram_lm_scale_{params.ngram_lm_scale}" - - return {key: hyps} - else: - return {f"beam_size_{params.beam_size}": hyps} - - -def decode_dataset( - dl: torch.utils.data.DataLoader, - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - word_table: Optional[k2.SymbolTable] = None, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: - """Decode dataset. - - Args: - dl: - PyTorch's dataloader containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - word_table: - The word symbol table. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search, fast_beam_search_nbest, - fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - num_cuts = 0 - - try: - num_batches = len(dl) - except TypeError: - num_batches = "?" - - if params.decoding_method == "greedy_search": - log_interval = 50 - else: - log_interval = 20 - - results = defaultdict(list) - for batch_idx, batch in enumerate(dl): - texts = batch["supervisions"]["text"] - cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] - - hyps_dict = decode_one_batch( - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - word_table=word_table, - batch=batch, - ) - - for name, hyps in hyps_dict.items(): - this_batch = [] - assert len(hyps) == len(texts) - for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): - ref_words = sp.text2word(ref_text) - this_batch.append((cut_id, ref_words, hyp_words)) - - results[name].extend(this_batch) - - num_cuts += len(texts) - - if batch_idx % log_interval == 0: - batch_str = f"{batch_idx}/{num_batches}" - - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") - return results - - -def save_results( - params: AttributeDict, - test_set_name: str, - results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], -): - test_set_wers = dict() - for key, results in results_dict.items(): - recog_path = params.res_dir / f"recogs-{test_set_name}-{params.suffix}.txt" - results = sorted(results) - store_transcripts(filename=recog_path, texts=results) - - logging.info(f"The transcripts are stored in {recog_path}") - - # The following prints out WERs, per-word error statistics and aligned - # ref/hyp pairs. - errs_filename = params.res_dir / f"errs-{test_set_name}-{params.suffix}.txt" - with open(errs_filename, "w") as f: - wer = write_error_stats( - f, f"{test_set_name}-{key}", results, enable_log=True - ) - test_set_wers[key] = wer - - logging.info("Wrote detailed error stats to {}".format(errs_filename)) - - test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) - errs_info = params.res_dir / f"wer-summary-{test_set_name}-{params.suffix}.txt" - with open(errs_info, "w") as f: - print("settings\tWER", file=f) - for key, val in test_set_wers: - print("{}\t{}".format(key, val), file=f) - - s = "\nFor {}, WER of different settings are:\n".format(test_set_name) - note = "\tbest for {}".format(test_set_name) - for key, val in test_set_wers: - s += "{}\t{}{}\n".format(key, val, note) - note = "" - logging.info(s) - - return test_set_wers - - -@torch.no_grad() -def main(): - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - assert params.decoding_method in ( - "greedy_search", - "beam_search", - "fast_beam_search", - "fast_beam_search_nbest", - "fast_beam_search_nbest_LG", - "fast_beam_search_nbest_oracle", - "modified_beam_search", - ) - if not params.res_dir: - params.res_dir = params.exp_dir / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" - - if "fast_beam_search" in params.decoding_method: - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - if "nbest" in params.decoding_method: - params.suffix += f"-nbest-scale-{params.nbest_scale}" - params.suffix += f"-num-paths-{params.num_paths}" - if "LG" in params.decoding_method: - params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" - elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" - else: - params.suffix += f"-context-{params.context_size}" - params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", params.gpu) - - logging.info(f"Device: {device}") - - sp = Tokenizer.load(params.lang, params.lang_type) - - # and are defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - assert model.encoder.decode_chunk_size == params.decode_chunk_len // 2, ( - model.encoder.decode_chunk_size, - params.decode_chunk_len, - ) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - - decoding_graph = None - word_table = None - - if params.decoding_graph: - decoding_graph = k2.Fsa.from_dict( - torch.load(params.decoding_graph, map_location=device) - ) - elif "fast_beam_search" in params.decoding_method: - if params.decoding_method == "fast_beam_search_nbest_LG": - lexicon = Lexicon(params.lang_dir) - word_table = lexicon.word_table - lg_filename = params.lang_dir / "LG.pt" - logging.info(f"Loading {lg_filename}") - decoding_graph = k2.Fsa.from_dict( - torch.load(lg_filename, map_location=device) - ) - decoding_graph.scores *= params.ngram_lm_scale - else: - word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - # we need cut ids to display recognition results. - args.return_cuts = True - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - - for subdir in ["valid"]: - results_dict = decode_dataset( - dl=reazonspeech_corpus.test_dataloaders(getattr(reazonspeech_corpus, f"{subdir}_cuts")()), - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - tot_err = save_results( - params=params, - test_set_name=subdir, - results_dict=results_dict, - ) - with ( - params.res_dir - / ( - f"{subdir}-{params.decode_chunk_len}_{params.beam_size}" - f"_{params.avg}_{params.epoch}.cer" - ) - ).open("w") as fout: - if len(tot_err) == 1: - fout.write(f"{tot_err[0][1]}") - else: - fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py deleted file mode 120000 index ca8fed319..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decode_stream.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py deleted file mode 120000 index 1ce277aa6..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/decoder.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/decoder.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py deleted file mode 100755 index 072679cfc..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/do_not_use_it_directly.py +++ /dev/null @@ -1,1261 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import math -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from tokenizer import Tokenizer -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer_for_ncnn_export_only import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] -LOG_EPS = math.log(1e-10) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="pruned_transducer_stateless7_streaming/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--pad-feature", - type=int, - default=0, - help=""" - Number of frames to pad at the end. - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 1000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - is_pnnx=True, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.pad_feature: - feature_lens += params.pad_feature - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.pad_feature), - value=LOG_EPS, - ) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: Tokenizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except Exception as e: # noqa - logging.error(e, exc_info=True) - display_and_save_batch(batch, params=params, sp=sp) - raise e - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - log_mode = logging.info - log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") - log_mode( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, master_port=params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = Tokenizer.load(args.lang, args.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 0.3 or c.duration > 30.0: - logging.debug( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.info( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - train_cuts = reazonspeech_corpus.train_cuts() - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = reazonspeech_corpus.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = reazonspeech_corpus.valid_cuts() - valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts) - - if params.start_batch <= 0 and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: Tokenizer, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: Tokenizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - raise RuntimeError("Please don't use this file directly!") - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py deleted file mode 120000 index cb673b3eb..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/encoder_interface.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/export.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/export.py deleted file mode 100644 index 666bcc831..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/export.py +++ /dev/null @@ -1,250 +0,0 @@ -#!/usr/bin/env python3 -# -# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# This script converts several saved checkpoints -# to a single one using model averaging. - -import argparse -import logging -from pathlib import Path - -import torch -from scaling_converter import convert_scaled_to_non_scaled -from tokenizer import Tokenizer -from train import add_model_arguments, get_params, get_transducer_model - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import str2bool - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=30, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 1. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--avg", - type=int, - default=9, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless7_streaming/exp", - help="""It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--jit", - type=str2bool, - default=False, - help="""True to save a model after applying torch.jit.script. - It will generate a file named cpu_jit.pt - - Check ./jit_pretrained.py for how to use it. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - add_model_arguments(parser) - - return parser - - -@torch.no_grad() -def main(): - parser = get_parser() - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - sp = Tokenizer.load(params.lang, params.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - model.to(device) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if i >= 1: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to("cpu") - model.eval() - - if params.jit is True: - convert_scaled_to_non_scaled(model, inplace=True) - # We won't use the forward() method of the model in C++, so just ignore - # it here. - # Otherwise, one of its arguments is a ragged tensor and is not - # torch scriptabe. - model.__class__.forward = torch.jit.ignore(model.__class__.forward) - logging.info("Using torch.jit.script") - model = torch.jit.script(model) - filename = params.exp_dir / "cpu_jit.pt" - model.save(str(filename)) - logging.info(f"Saved to {filename}") - else: - logging.info("Not using torchscript. Export model.state_dict()") - # Save it using a format so that it can be loaded - # by :func:`load_checkpoint` - filename = params.exp_dir / "pretrained.pt" - torch.save({"model": model.state_dict()}, str(filename)) - logging.info(f"Saved to {filename}") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py deleted file mode 120000 index 482ebcfef..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/joiner.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/joiner.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/model.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/model.py deleted file mode 120000 index 16c2bf28d..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/model.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/model.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/optim.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/optim.py deleted file mode 120000 index 522bbaff9..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/optim.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/optim.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py deleted file mode 100644 index 932026868..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/pretrained.py +++ /dev/null @@ -1,347 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -This script loads a checkpoint and uses it to decode waves. -You can generate the checkpoint with the following command: - -./pruned_transducer_stateless7_streaming/export.py \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --epoch 20 \ - --avg 10 - -Usage of this script: - -(1) greedy search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --lang data/lang_char \ - --method greedy_search \ - /path/to/foo.wav \ - /path/to/bar.wav - -(2) beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --lang data/lang_char \ - --method beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(3) modified beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --lang data/lang_char \ - --method modified_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -(4) fast beam search -./pruned_transducer_stateless7_streaming/pretrained.py \ - --checkpoint ./pruned_transducer_stateless7_streaming/exp/pretrained.pt \ - --lang data/lang_char \ - --method fast_beam_search \ - --beam-size 4 \ - /path/to/foo.wav \ - /path/to/bar.wav - -You can also use `./pruned_transducer_stateless7_streaming/exp/epoch-xx.pt`. - -Note: ./pruned_transducer_stateless7_streaming/exp/pretrained.pt is generated by -./pruned_transducer_stateless7_streaming/export.py -""" - - -import argparse -import logging -import math -from typing import List - -import k2 -import kaldifeat -import torch -import torchaudio -from beam_search import ( - beam_search, - fast_beam_search_one_best, - greedy_search, - greedy_search_batch, - modified_beam_search, -) -from tokenizer import Tokenizer -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--checkpoint", - type=str, - required=True, - help="Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint().", - ) - - parser.add_argument( - "--method", - type=str, - default="greedy_search", - help="""Possible values are: - - greedy_search - - beam_search - - modified_beam_search - - fast_beam_search - """, - ) - - parser.add_argument( - "sound_files", - type=str, - nargs="+", - help="The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz.", - ) - - parser.add_argument( - "--sample-rate", - type=int, - default=16000, - help="The sample rate of the input sound file", - ) - - parser.add_argument( - "--beam-size", - type=int, - default=4, - help="""An integer indicating how many candidates we will keep for each - frame. Used only when --method is beam_search or - modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=8, - help="""Used only when --method is fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - parser.add_argument( - "--max-sym-per-frame", - type=int, - default=1, - help="""Maximum number of symbols per frame. Used only when - --method is greedy_search. - """, - ) - - add_model_arguments(parser) - - return parser - - -def read_sound_files( - filenames: List[str], expected_sample_rate: float -) -> List[torch.Tensor]: - """Read a list of sound files into a list 1-D float32 torch tensors. - Args: - filenames: - A list of sound filenames. - expected_sample_rate: - The expected sample rate of the sound files. - Returns: - Return a list of 1-D float32 torch tensors. - """ - ans = [] - for f in filenames: - wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" - # We use only the first channel - ans.append(wave[0]) - return ans - - -@torch.no_grad() -def main(): - parser = get_parser() - Tokenizer.add_arguments(parser) - args = parser.parse_args() - - params = get_params() - - params.update(vars(args)) - - sp = Tokenizer.load(params.lang, params.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(f"{params}") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", 0) - - logging.info(f"device: {device}") - - logging.info("Creating model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - checkpoint = torch.load(args.checkpoint, map_location="cpu") - model.load_state_dict(checkpoint["model"], strict=False) - model.to(device) - model.eval() - model.device = device - - logging.info("Constructing Fbank computer") - opts = kaldifeat.FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = params.sample_rate - opts.mel_opts.num_bins = params.feature_dim - - fbank = kaldifeat.Fbank(opts) - - logging.info(f"Reading sound files: {params.sound_files}") - waves = read_sound_files( - filenames=params.sound_files, expected_sample_rate=params.sample_rate - ) - waves = [w.to(device) for w in waves] - - logging.info("Decoding started") - features = fbank(waves) - feature_lengths = [f.size(0) for f in features] - - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) - - feature_lengths = torch.tensor(feature_lengths, device=device) - - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) - - num_waves = encoder_out.size(0) - hyps = [] - msg = f"Using {params.method}" - if params.method == "beam_search": - msg += f" with beam size {params.beam_size}" - logging.info(msg) - - if params.method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - hyp_tokens = fast_beam_search_one_best( - model=model, - decoding_graph=decoding_graph, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam, - max_contexts=params.max_contexts, - max_states=params.max_states, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "modified_beam_search": - hyp_tokens = modified_beam_search( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - beam=params.beam_size, - ) - - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - elif params.method == "greedy_search" and params.max_sym_per_frame == 1: - hyp_tokens = greedy_search_batch( - model=model, - encoder_out=encoder_out, - encoder_out_lens=encoder_out_lens, - ) - for hyp in sp.decode(hyp_tokens): - hyps.append(hyp.split()) - else: - for i in range(num_waves): - # fmt: off - encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] - # fmt: on - if params.method == "greedy_search": - hyp = greedy_search( - model=model, - encoder_out=encoder_out_i, - max_sym_per_frame=params.max_sym_per_frame, - ) - elif params.method == "beam_search": - hyp = beam_search( - model=model, - encoder_out=encoder_out_i, - beam=params.beam_size, - ) - else: - raise ValueError(f"Unsupported method: {params.method}") - - hyps.append(sp.decode(hyp).split()) - - s = "\n" - for filename, hyp in zip(params.sound_files, hyps): - words = " ".join(hyp) - s += f"{filename}:\n{words}\n\n" - logging.info(s) - - logging.info("Decoding Done") - - -if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" - - logging.basicConfig(format=formatter, level=logging.INFO) - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py deleted file mode 120000 index a7ef73bcb..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/scaling.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py deleted file mode 120000 index 566c317ff..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/scaling_converter.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py deleted file mode 120000 index 92c3904af..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming-ncnn-decode.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py deleted file mode 120000 index 2adf271c1..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/streaming_beam_search.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py deleted file mode 100755 index 4c18c7563..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/streaming_decode.py +++ /dev/null @@ -1,597 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Usage: -./pruned_transducer_stateless7_streaming/streaming_decode.py \ - --epoch 28 \ - --avg 15 \ - --decode-chunk-len 32 \ - --exp-dir ./pruned_transducer_stateless7_streaming/exp \ - --decoding_method greedy_search \ - --lang data/lang_char \ - --num-decode-streams 2000 -""" - -import argparse -import logging -import math -from pathlib import Path -from typing import Dict, List, Optional, Tuple - -import k2 -import numpy as np -import torch -import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule -from decode import save_results -from decode_stream import DecodeStream -from kaldifeat import Fbank, FbankOptions -from lhotse import CutSet -from streaming_beam_search import ( - fast_beam_search_one_best, - greedy_search, - modified_beam_search, -) -from tokenizer import Tokenizer -from torch.nn.utils.rnn import pad_sequence -from train import add_model_arguments, get_params, get_transducer_model -from zipformer import stack_states, unstack_states - -from icefall.checkpoint import ( - average_checkpoints, - average_checkpoints_with_averaged_model, - find_checkpoints, - load_checkpoint, -) -from icefall.utils import AttributeDict, setup_logger, str2bool - -LOG_EPS = math.log(1e-10) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--epoch", - type=int, - default=28, - help="""It specifies the checkpoint to use for decoding. - Note: Epoch counts from 0. - You can specify --avg to use more checkpoints for model averaging.""", - ) - - parser.add_argument( - "--iter", - type=int, - default=0, - help="""If positive, --epoch is ignored and it - will use the checkpoint exp_dir/checkpoint-iter.pt. - You can specify --avg to use more checkpoints for model averaging. - """, - ) - - parser.add_argument( - "--gpu", - type=int, - default=0, - ) - - parser.add_argument( - "--avg", - type=int, - default=15, - help="Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'", - ) - - parser.add_argument( - "--use-averaged-model", - type=str2bool, - default=True, - help="Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. ", - ) - - parser.add_argument( - "--exp-dir", - type=str, - default="pruned_transducer_stateless2/exp", - help="The experiment dir", - ) - - parser.add_argument( - "--bpe-model", - type=str, - default="data/lang_bpe_500/bpe.model", - help="Path to the BPE model", - ) - - parser.add_argument( - "--decoding-method", - type=str, - default="greedy_search", - help="""Supported decoding methods are: - greedy_search - modified_beam_search - fast_beam_search - """, - ) - - parser.add_argument( - "--decoding-graph", - type=str, - default="", - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--num_active_paths", - type=int, - default=4, - help="""An interger indicating how many candidates we will keep for each - frame. Used only when --decoding-method is modified_beam_search.""", - ) - - parser.add_argument( - "--beam", - type=float, - default=4.0, - help="""A floating point value to calculate the cutoff score during beam - search (i.e., `cutoff = max-score - beam`), which is the same as the - `beam` in Kaldi. - Used only when --decoding-method is fast_beam_search""", - ) - - parser.add_argument( - "--max-contexts", - type=int, - default=4, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--max-states", - type=int, - default=32, - help="""Used only when --decoding-method is - fast_beam_search""", - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--num-decode-streams", - type=int, - default=2000, - help="The number of streams that can be decoded parallel.", - ) - - parser.add_argument( - "--res-dir", - type=Path, - default=None, - help="The path to save results.", - ) - - add_model_arguments(parser) - - return parser - - -def decode_one_chunk( - params: AttributeDict, - model: nn.Module, - decode_streams: List[DecodeStream], -) -> List[int]: - """Decode one chunk frames of features for each decode_streams and - return the indexes of finished streams in a List. - - Args: - params: - It's the return value of :func:`get_params`. - model: - The neural model. - decode_streams: - A List of DecodeStream, each belonging to a utterance. - Returns: - Return a List containing which DecodeStreams are finished. - """ - device = model.device - - features = [] - feature_lens = [] - states = [] - processed_lens = [] - - for stream in decode_streams: - feat, feat_len = stream.get_feature_frames(params.decode_chunk_len) - features.append(feat) - feature_lens.append(feat_len) - states.append(stream.states) - processed_lens.append(stream.done_frames) - - feature_lens = torch.tensor(feature_lens, device=device) - features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) - - # We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling - # factor in encoders is 8. - # After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8. - tail_length = 23 - if features.size(1) < tail_length: - pad_length = tail_length - features.size(1) - feature_lens += pad_length - features = torch.nn.functional.pad( - features, - (0, 0, 0, pad_length), - mode="constant", - value=LOG_EPS, - ) - - states = stack_states(states) - processed_lens = torch.tensor(processed_lens, device=device) - - encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward( - x=features, - x_lens=feature_lens, - states=states, - ) - - encoder_out = model.joiner.encoder_proj(encoder_out) - - if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) - elif params.decoding_method == "fast_beam_search": - processed_lens = processed_lens + encoder_out_lens - fast_beam_search_one_best( - model=model, - encoder_out=encoder_out, - processed_lens=processed_lens, - streams=decode_streams, - beam=params.beam, - max_states=params.max_states, - max_contexts=params.max_contexts, - ) - elif params.decoding_method == "modified_beam_search": - modified_beam_search( - model=model, - streams=decode_streams, - encoder_out=encoder_out, - num_active_paths=params.num_active_paths, - ) - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - - states = unstack_states(new_states) - - finished_streams = [] - for i in range(len(decode_streams)): - decode_streams[i].states = states[i] - decode_streams[i].done_frames += encoder_out_lens[i] - if decode_streams[i].done: - finished_streams.append(i) - - return finished_streams - - -def decode_dataset( - cuts: CutSet, - params: AttributeDict, - model: nn.Module, - sp: Tokenizer, - decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[List[str], List[str]]]]: - """Decode dataset. - - Args: - cuts: - Lhotse Cutset containing the dataset to decode. - params: - It is returned by :func:`get_params`. - model: - The neural model. - sp: - The BPE model. - decoding_graph: - The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used - only when --decoding_method is fast_beam_search. - Returns: - Return a dict, whose key may be "greedy_search" if greedy search - is used, or it may be "beam_7" if beam size of 7 is used. - Its value is a list of tuples. Each tuple contains two elements: - The first is the reference transcript, and the second is the - predicted result. - """ - device = model.device - - opts = FbankOptions() - opts.device = device - opts.frame_opts.dither = 0 - opts.frame_opts.snip_edges = False - opts.frame_opts.samp_freq = 16000 - opts.mel_opts.num_bins = 80 - - log_interval = 50 - - decode_results = [] - # Contain decode streams currently running. - decode_streams = [] - for num, cut in enumerate(cuts): - # each utterance has a DecodeStream. - initial_states = model.encoder.get_init_state(device=device) - decode_stream = DecodeStream( - params=params, - cut_id=cut.id, - initial_states=initial_states, - decoding_graph=decoding_graph, - device=device, - ) - - audio: np.ndarray = cut.load_audio() - # audio.shape: (1, num_samples) - assert len(audio.shape) == 2 - assert audio.shape[0] == 1, "Should be single channel" - assert audio.dtype == np.float32, audio.dtype - - # The trained model is using normalized samples - assert audio.max() <= 1, "Should be normalized to [-1, 1])" - - samples = torch.from_numpy(audio).squeeze(0) - - fbank = Fbank(opts) - feature = fbank(samples.to(device)) - decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len) - decode_stream.ground_truth = cut.supervisions[0].custom[params.transcript_mode] - - decode_streams.append(decode_stream) - - while len(decode_streams) >= params.num_decode_streams: - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - sp.text2word(decode_streams[i].ground_truth), - sp.text2word(sp.decode(decode_streams[i].decoding_result())), - ) - ) - del decode_streams[i] - - if num % log_interval == 0: - logging.info(f"Cuts processed until now is {num}.") - - # decode final chunks of last sequences - while len(decode_streams): - finished_streams = decode_one_chunk( - params=params, model=model, decode_streams=decode_streams - ) - for i in sorted(finished_streams, reverse=True): - decode_results.append( - ( - decode_streams[i].id, - sp.text2word(decode_streams[i].ground_truth), - sp.text2word(sp.decode(decode_streams[i].decoding_result())), - ) - ) - del decode_streams[i] - - if params.decoding_method == "greedy_search": - key = "greedy_search" - elif params.decoding_method == "fast_beam_search": - key = ( - f"beam_{params.beam}_" - f"max_contexts_{params.max_contexts}_" - f"max_states_{params.max_states}" - ) - elif params.decoding_method == "modified_beam_search": - key = f"num_active_paths_{params.num_active_paths}" - else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") - return {key: decode_results} - - -@torch.no_grad() -def main(): - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - args.exp_dir = Path(args.exp_dir) - - params = get_params() - params.update(vars(args)) - - if not params.res_dir: - params.res_dir = params.exp_dir / "streaming" / params.decoding_method - - if params.iter > 0: - params.suffix = f"iter-{params.iter}-avg-{params.avg}" - else: - params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - - # for streaming - params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}" - - # for fast_beam_search - if params.decoding_method == "fast_beam_search": - params.suffix += f"-beam-{params.beam}" - params.suffix += f"-max-contexts-{params.max_contexts}" - params.suffix += f"-max-states-{params.max_states}" - - if params.use_averaged_model: - params.suffix += "-use-averaged-model" - - setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") - logging.info("Decoding started") - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", params.gpu) - - logging.info(f"Device: {device}") - - sp = Tokenizer.load(params.lang, params.lang_type) - - # and is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.unk_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - if not params.use_averaged_model: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - elif params.avg == 1: - load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) - else: - start = params.epoch - params.avg + 1 - filenames = [] - for i in range(start, params.epoch + 1): - if start >= 0: - filenames.append(f"{params.exp_dir}/epoch-{i}.pt") - logging.info(f"averaging {filenames}") - model.to(device) - model.load_state_dict(average_checkpoints(filenames, device=device)) - else: - if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] - if len(filenames) == 0: - raise ValueError( - f"No checkpoints found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - elif len(filenames) < params.avg + 1: - raise ValueError( - f"Not enough checkpoints ({len(filenames)}) found for" - f" --iter {params.iter}, --avg {params.avg}" - ) - filename_start = filenames[-1] - filename_end = filenames[0] - logging.info( - "Calculating the averaged model over iteration checkpoints" - f" from {filename_start} (excluded) to {filename_end}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - else: - assert params.avg > 0, params.avg - start = params.epoch - params.avg - assert start >= 1, start - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" - ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, - ) - ) - - model.to(device) - model.eval() - model.device = device - - decoding_graph = None - if params.decoding_graph: - decoding_graph = k2.Fsa.from_dict( - torch.load(params.decoding_graph, map_location=device) - ) - elif params.decoding_method == "fast_beam_search": - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - args.return_cuts = True - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - - for subdir in ["valid"]: - results_dict = decode_dataset( - cuts=getattr(reazonspeech_corpus, f"{subdir}_cuts")(), - params=params, - model=model, - sp=sp, - decoding_graph=decoding_graph, - ) - tot_err = save_results( - params=params, test_set_name=subdir, results_dict=results_dict - ) - - with ( - params.res_dir - / ( - f"{subdir}-{params.decode_chunk_len}" - f"_{params.avg}_{params.epoch}.cer" - ) - ).open("w") as fout: - if len(tot_err) == 1: - fout.write(f"{tot_err[0][1]}") - else: - fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) - - logging.info("Done!") - - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/tokenizer.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/tokenizer.py deleted file mode 120000 index 958c99e85..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/tokenizer.py +++ /dev/null @@ -1 +0,0 @@ -../local/utils/tokenizer.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/train.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/train.py deleted file mode 100755 index 32cd4c576..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/train.py +++ /dev/null @@ -1,1259 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, -# Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -Usage: - -export CUDA_VISIBLE_DEVICES="0,1,2,3" - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7_streaming/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7_streaming/exp \ - --lang data/lang_char \ - --max-duration 550 -""" - - -import argparse -import copy -import logging -import math -import warnings -from pathlib import Path -from shutil import copyfile -from typing import Any, Dict, Optional, Tuple, Union - -import k2 -import optim -import torch -import torch.multiprocessing as mp -import torch.nn as nn -from asr_datamodule import ReazonSpeechAsrDataModule -from decoder import Decoder -from joiner import Joiner -from lhotse.cut import Cut -from lhotse.dataset.sampling.base import CutSampler -from lhotse.utils import fix_random_seed -from model import Transducer -from optim import Eden, ScaledAdam -from tokenizer import Tokenizer -from torch import Tensor -from torch.cuda.amp import GradScaler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.tensorboard import SummaryWriter -from zipformer import Zipformer - -from icefall import diagnostics -from icefall.checkpoint import load_checkpoint, remove_checkpoints -from icefall.checkpoint import save_checkpoint as save_checkpoint_impl -from icefall.checkpoint import ( - save_checkpoint_with_global_batch_idx, - update_averaged_model, -) -from icefall.dist import cleanup_dist, setup_dist -from icefall.env import get_env_info -from icefall.hooks import register_inf_check_hooks -from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool - -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] -LOG_EPS = math.log(1e-10) - - -def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: - if isinstance(model, DDP): - # get underlying nn.Module - model = model.module - for module in model.modules(): - if hasattr(module, "batch_count"): - module.batch_count = batch_count - - -def add_model_arguments(parser: argparse.ArgumentParser): - parser.add_argument( - "--num-encoder-layers", - type=str, - default="2,4,3,2,4", - help="Number of zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--feedforward-dims", - type=str, - default="1024,1024,2048,2048,1024", - help="Feedforward dimension of the zipformer encoder layers, comma separated.", - ) - - parser.add_argument( - "--nhead", - type=str, - default="8,8,8,8,8", - help="Number of attention heads in the zipformer encoder layers.", - ) - - parser.add_argument( - "--encoder-dims", - type=str, - default="384,384,384,384,384", - help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", - ) - - parser.add_argument( - "--attention-dims", - type=str, - default="192,192,192,192,192", - help="""Attention dimension in the 2 blocks of zipformer encoder layers, comma separated; - not the same as embedding dimension.""", - ) - - parser.add_argument( - "--encoder-unmasked-dims", - type=str, - default="256,256,256,256,256", - help="Unmasked dimensions in the encoders, relates to augmentation during training. " - "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " - " worse.", - ) - - parser.add_argument( - "--zipformer-downsampling-factors", - type=str, - default="1,2,4,8,2", - help="Downsampling factor for each stack of encoder layers.", - ) - - parser.add_argument( - "--cnn-module-kernels", - type=str, - default="31,31,31,31,31", - help="Sizes of kernels in convolution modules", - ) - - parser.add_argument( - "--decoder-dim", - type=int, - default=512, - help="Embedding dimension in the decoder model.", - ) - - parser.add_argument( - "--joiner-dim", - type=int, - default=512, - help="""Dimension used in the joiner model. - Outputs from the encoder and decoder model are projected - to this dimension before adding. - """, - ) - - parser.add_argument( - "--short-chunk-size", - type=int, - default=50, - help="""Chunk length of dynamic training, the chunk size would be either - max sequence length of current batch or uniformly sampled from (1, short_chunk_size). - """, - ) - - parser.add_argument( - "--num-left-chunks", - type=int, - default=4, - help="How many left context can be seen in chunks when calculating attention.", - ) - - parser.add_argument( - "--decode-chunk-len", - type=int, - default=32, - help="The chunk size for decoding (in frames before subsampling)", - ) - - -def get_parser(): - parser = argparse.ArgumentParser( - formatter_class=argparse.ArgumentDefaultsHelpFormatter - ) - - parser.add_argument( - "--world-size", - type=int, - default=1, - help="Number of GPUs for DDP training.", - ) - - parser.add_argument( - "--master-port", - type=int, - default=12354, - help="Master port to use for DDP training.", - ) - - parser.add_argument( - "--tensorboard", - type=str2bool, - default=True, - help="Should various information be logged in tensorboard.", - ) - - parser.add_argument( - "--num-epochs", - type=int, - default=30, - help="Number of epochs to train.", - ) - - parser.add_argument( - "--start-epoch", - type=int, - default=1, - help="""Resume training from this epoch. It should be positive. - If larger than 1, it will load checkpoint from - exp-dir/epoch-{start_epoch-1}.pt - """, - ) - - parser.add_argument( - "--start-batch", - type=int, - default=0, - help="""If positive, --start-epoch is ignored and - it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt - """, - ) - - parser.add_argument( - "--exp-dir", - type=Path, - default="pruned_transducer_stateless7_streaming/exp", - help="""The experiment dir. - It specifies the directory where all training related - files, e.g., checkpoints, log, etc, are saved - """, - ) - - parser.add_argument( - "--base-lr", type=float, default=0.05, help="The base learning rate." - ) - - parser.add_argument( - "--lr-batches", - type=float, - default=5000, - help="""Number of steps that affects how rapidly the learning rate - decreases. We suggest not to change this.""", - ) - - parser.add_argument( - "--lr-epochs", - type=float, - default=3.5, - help="""Number of epochs that affects how rapidly the learning rate decreases. - """, - ) - - parser.add_argument( - "--context-size", - type=int, - default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", - ) - - parser.add_argument( - "--prune-range", - type=int, - default=5, - help="The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss", - ) - - parser.add_argument( - "--lm-scale", - type=float, - default=0.25, - help="The scale to smooth the loss with lm " - "(output of prediction network) part.", - ) - - parser.add_argument( - "--am-scale", - type=float, - default=0.0, - help="The scale to smooth the loss with am (output of encoder network) part.", - ) - - parser.add_argument( - "--simple-loss-scale", - type=float, - default=0.5, - help="To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss.", - ) - - parser.add_argument( - "--seed", - type=int, - default=42, - help="The seed for random generators intended for reproducibility", - ) - - parser.add_argument( - "--print-diagnostics", - type=str2bool, - default=False, - help="Accumulate stats on activations, print them and exit.", - ) - - parser.add_argument( - "--inf-check", - type=str2bool, - default=False, - help="Add hooks to check for infinite module outputs and gradients.", - ) - - parser.add_argument( - "--save-every-n", - type=int, - default=2000, - help="""Save checkpoint after processing this number of batches" - periodically. We save checkpoint to exp-dir/ whenever - params.batch_idx_train % save_every_n == 0. The checkpoint filename - has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' - Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the - end of each epoch where `xxx` is the epoch number counting from 0. - """, - ) - - parser.add_argument( - "--keep-last-k", - type=int, - default=30, - help="""Only keep this number of checkpoints on disk. - For instance, if it is 3, there are only 3 checkpoints - in the exp-dir with filenames `checkpoint-xxx.pt`. - It does not affect checkpoints with name `epoch-xxx.pt`. - """, - ) - - parser.add_argument( - "--average-period", - type=int, - default=200, - help="""Update the averaged model, namely `model_avg`, after processing - this number of batches. `model_avg` is a separate version of model, - in which each floating-point parameter is the average of all the - parameters from the start of training. Each time we take the average, - we do: `model_avg = model * (average_period / batch_idx_train) + - model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. - """, - ) - - parser.add_argument( - "--use-fp16", - type=str2bool, - default=False, - help="Whether to use half precision training.", - ) - - parser.add_argument( - "--pad-feature", - type=int, - default=0, - help=""" - Number of frames to pad at the end. - """, - ) - - add_model_arguments(parser) - - return parser - - -def get_params() -> AttributeDict: - """Return a dict containing training parameters. - - All training related parameters that are not passed from the commandline - are saved in the variable `params`. - - Commandline options are merged into `params` after they are parsed, so - you can also access them via `params`. - - Explanation of options saved in `params`: - - - best_train_loss: Best training loss so far. It is used to select - the model that has the lowest training loss. It is - updated during the training. - - - best_valid_loss: Best validation loss so far. It is used to select - the model that has the lowest validation loss. It is - updated during the training. - - - best_train_epoch: It is the epoch that has the best training loss. - - - best_valid_epoch: It is the epoch that has the best validation loss. - - - batch_idx_train: Used to writing statistics to tensorboard. It - contains number of batches trained so far across - epochs. - - - log_interval: Print training loss if batch_idx % log_interval` is 0 - - - reset_interval: Reset statistics if batch_idx % reset_interval is 0 - - - valid_interval: Run validation if batch_idx % valid_interval is 0 - - - feature_dim: The model input dim. It has to match the one used - in computing features. - - - subsampling_factor: The subsampling factor for the model. - - - encoder_dim: Hidden dim for multi-head attention model. - - - num_decoder_layers: Number of decoder layer of transformer decoder. - - - warm_step: The warmup period that dictates the decay of the - scale on "simple" (un-pruned) loss. - """ - params = AttributeDict( - { - "best_train_loss": float("inf"), - "best_valid_loss": float("inf"), - "best_train_epoch": -1, - "best_valid_epoch": -1, - "batch_idx_train": 0, - "log_interval": 50, - "reset_interval": 200, - "valid_interval": 1000, # For the 100h subset, use 800 - # parameters for zipformer - "feature_dim": 80, - "subsampling_factor": 4, # not passed in, this is fixed. - "warm_step": 2000, - "env_info": get_env_info(), - } - ) - - return params - - -def get_encoder_model(params: AttributeDict) -> nn.Module: - # TODO: We can add an option to switch between Zipformer and Transformer - def to_int_tuple(s: str): - return tuple(map(int, s.split(","))) - - encoder = Zipformer( - num_features=params.feature_dim, - output_downsampling_factor=2, - zipformer_downsampling_factors=to_int_tuple( - params.zipformer_downsampling_factors - ), - encoder_dims=to_int_tuple(params.encoder_dims), - attention_dim=to_int_tuple(params.attention_dims), - encoder_unmasked_dims=to_int_tuple(params.encoder_unmasked_dims), - nhead=to_int_tuple(params.nhead), - feedforward_dim=to_int_tuple(params.feedforward_dims), - cnn_module_kernels=to_int_tuple(params.cnn_module_kernels), - num_encoder_layers=to_int_tuple(params.num_encoder_layers), - num_left_chunks=params.num_left_chunks, - short_chunk_size=params.short_chunk_size, - decode_chunk_size=params.decode_chunk_len // 2, - ) - return encoder - - -def get_decoder_model(params: AttributeDict) -> nn.Module: - decoder = Decoder( - vocab_size=params.vocab_size, - decoder_dim=params.decoder_dim, - blank_id=params.blank_id, - context_size=params.context_size, - ) - return decoder - - -def get_joiner_model(params: AttributeDict) -> nn.Module: - joiner = Joiner( - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return joiner - - -def get_transducer_model(params: AttributeDict) -> nn.Module: - encoder = get_encoder_model(params) - decoder = get_decoder_model(params) - joiner = get_joiner_model(params) - - model = Transducer( - encoder=encoder, - decoder=decoder, - joiner=joiner, - encoder_dim=int(params.encoder_dims.split(",")[-1]), - decoder_dim=params.decoder_dim, - joiner_dim=params.joiner_dim, - vocab_size=params.vocab_size, - ) - return model - - -def load_checkpoint_if_available( - params: AttributeDict, - model: nn.Module, - model_avg: nn.Module = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, -) -> Optional[Dict[str, Any]]: - """Load checkpoint from file. - - If params.start_batch is positive, it will load the checkpoint from - `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if - params.start_epoch is larger than 1, it will load the checkpoint from - `params.start_epoch - 1`. - - Apart from loading state dict for `model` and `optimizer` it also updates - `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, - and `best_valid_loss` in `params`. - - Args: - params: - The return value of :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer that we are using. - scheduler: - The scheduler that we are using. - Returns: - Return a dict containing previously saved training info. - """ - if params.start_batch > 0: - filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" - elif params.start_epoch > 1: - filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" - else: - return None - - assert filename.is_file(), f"{filename} does not exist!" - - saved_params = load_checkpoint( - filename, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - ) - - keys = [ - "best_train_epoch", - "best_valid_epoch", - "batch_idx_train", - "best_train_loss", - "best_valid_loss", - ] - for k in keys: - params[k] = saved_params[k] - - if params.start_batch > 0: - if "cur_epoch" in saved_params: - params["start_epoch"] = saved_params["cur_epoch"] - - return saved_params - - -def save_checkpoint( - params: AttributeDict, - model: Union[nn.Module, DDP], - model_avg: Optional[nn.Module] = None, - optimizer: Optional[torch.optim.Optimizer] = None, - scheduler: Optional[LRSchedulerType] = None, - sampler: Optional[CutSampler] = None, - scaler: Optional[GradScaler] = None, - rank: int = 0, -) -> None: - """Save model, optimizer, scheduler and training stats to file. - - Args: - params: - It is returned by :func:`get_params`. - model: - The training model. - model_avg: - The stored model averaged from the start of training. - optimizer: - The optimizer used in the training. - sampler: - The sampler for the training dataset. - scaler: - The scaler used for mix precision training. - """ - if rank != 0: - return - filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" - save_checkpoint_impl( - filename=filename, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=sampler, - scaler=scaler, - rank=rank, - ) - - if params.best_train_epoch == params.cur_epoch: - best_train_filename = params.exp_dir / "best-train-loss.pt" - copyfile(src=filename, dst=best_train_filename) - - if params.best_valid_epoch == params.cur_epoch: - best_valid_filename = params.exp_dir / "best-valid-loss.pt" - copyfile(src=filename, dst=best_valid_filename) - - -def compute_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - batch: dict, - is_training: bool, -) -> Tuple[Tensor, MetricsTracker]: - """ - Compute transducer loss given the model and its inputs. - - Args: - params: - Parameters for training. See :func:`get_params`. - model: - The model for training. It is an instance of Zipformer in our case. - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - is_training: - True for training. False for validation. When it is True, this - function enables autograd during computation; when it is False, it - disables autograd. - warmup: a floating point value which increases throughout training; - values >= 1.0 are fully warmed up and have all modules present. - """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device - feature = batch["inputs"] - # at entry, feature is (N, T, C) - assert feature.ndim == 3 - feature = feature.to(device) - - supervisions = batch["supervisions"] - feature_lens = supervisions["num_frames"].to(device) - - if params.pad_feature: - feature_lens += params.pad_feature - feature = torch.nn.functional.pad( - feature, - pad=(0, 0, 0, params.pad_feature), - value=LOG_EPS, - ) - - batch_idx_train = params.batch_idx_train - warm_step = params.warm_step - - texts = batch["supervisions"]["text"] - y = sp.encode(texts, out_type=int) - y = k2.RaggedTensor(y).to(device) - - with torch.set_grad_enabled(is_training): - simple_loss, pruned_loss = model( - x=feature, - x_lens=feature_lens, - y=y, - prune_range=params.prune_range, - am_scale=params.am_scale, - lm_scale=params.lm_scale, - ) - - s = params.simple_loss_scale - # take down the scale on the simple loss from 1.0 at the start - # to params.simple_loss scale by warm_step. - simple_loss_scale = ( - s - if batch_idx_train >= warm_step - else 1.0 - (batch_idx_train / warm_step) * (1.0 - s) - ) - pruned_loss_scale = ( - 1.0 - if batch_idx_train >= warm_step - else 0.1 + 0.9 * (batch_idx_train / warm_step) - ) - - loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss - - assert loss.requires_grad == is_training - - info = MetricsTracker() - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() - - # Note: We use reduction=sum while computing the loss. - info["loss"] = loss.detach().cpu().item() - info["simple_loss"] = simple_loss.detach().cpu().item() - info["pruned_loss"] = pruned_loss.detach().cpu().item() - - return loss, info - - -def compute_validation_loss( - params: AttributeDict, - model: Union[nn.Module, DDP], - sp: Tokenizer, - valid_dl: torch.utils.data.DataLoader, - world_size: int = 1, -) -> MetricsTracker: - """Run the validation process.""" - model.eval() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(valid_dl): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=False, - ) - assert loss.requires_grad is False - tot_loss = tot_loss + loss_info - - if world_size > 1: - tot_loss.reduce(loss.device) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - if loss_value < params.best_valid_loss: - params.best_valid_epoch = params.cur_epoch - params.best_valid_loss = loss_value - - return tot_loss - - -def train_one_epoch( - params: AttributeDict, - model: Union[nn.Module, DDP], - optimizer: torch.optim.Optimizer, - scheduler: LRSchedulerType, - sp: Tokenizer, - train_dl: torch.utils.data.DataLoader, - valid_dl: torch.utils.data.DataLoader, - scaler: GradScaler, - model_avg: Optional[nn.Module] = None, - tb_writer: Optional[SummaryWriter] = None, - world_size: int = 1, - rank: int = 0, -) -> None: - """Train the model for one epoch. - - The training loss from the mean of all frames is saved in - `params.train_loss`. It runs the validation process every - `params.valid_interval` batches. - - Args: - params: - It is returned by :func:`get_params`. - model: - The model for training. - optimizer: - The optimizer we are using. - scheduler: - The learning rate scheduler, we call step() every step. - train_dl: - Dataloader for the training dataset. - valid_dl: - Dataloader for the validation dataset. - scaler: - The scaler used for mix precision training. - model_avg: - The stored model averaged from the start of training. - tb_writer: - Writer to write log messages to tensorboard. - world_size: - Number of nodes in DDP training. If it is 1, DDP is disabled. - rank: - The rank of the node in DDP training. If no DDP is used, it should - be set to 0. - """ - model.train() - - tot_loss = MetricsTracker() - - for batch_idx, batch in enumerate(train_dl): - params.batch_idx_train += 1 - batch_size = len(batch["supervisions"]["text"]) - - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, loss_info = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - # summary stats - tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info - - # NOTE: We use reduction==sum and loss is computed over utterances - # in the batch and there is no normalization to it so far. - scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) - scheduler.step_batch(params.batch_idx_train) - - scaler.step(optimizer) - scaler.update() - optimizer.zero_grad() - except Exception as e: # noqa - logging.error(e, exc_info=True) - display_and_save_batch(batch, params=params, sp=sp) - raise e - - if params.print_diagnostics and batch_idx == 5: - return - - if ( - rank == 0 - and params.batch_idx_train > 0 - and params.batch_idx_train % params.average_period == 0 - ): - update_averaged_model( - params=params, - model_cur=model, - model_avg=model_avg, - ) - - if ( - params.batch_idx_train > 0 - and params.batch_idx_train % params.save_every_n == 0 - ): - save_checkpoint_with_global_batch_idx( - out_dir=params.exp_dir, - global_batch_idx=params.batch_idx_train, - model=model, - model_avg=model_avg, - params=params, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - remove_checkpoints( - out_dir=params.exp_dir, - topk=params.keep_last_k, - rank=rank, - ) - - if batch_idx % 100 == 0 and params.use_fp16: - # If the grad scale was less than 1, try increasing it. The _growth_interval - # of the grad scaler is configurable, but we can't configure it to have different - # behavior depending on the current grad scale. - cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): - scaler.update(cur_grad_scale * 2.0) - if cur_grad_scale < 0.01: - logging.warning(f"Grad scale is small: {cur_grad_scale}") - if cur_grad_scale < 1.0e-05: - raise RuntimeError( - f"grad_scale is too small, exiting: {cur_grad_scale}" - ) - - if batch_idx % params.log_interval == 0: - cur_lr = scheduler.get_last_lr()[0] - cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0 - - logging.info( - f"Epoch {params.cur_epoch}, " - f"batch {batch_idx}, loss[{loss_info}], " - f"tot_loss[{tot_loss}], batch size: {batch_size}, " - f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") - ) - - if tb_writer is not None: - tb_writer.add_scalar( - "train/learning_rate", cur_lr, params.batch_idx_train - ) - - loss_info.write_summary( - tb_writer, "train/current_", params.batch_idx_train - ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - if params.use_fp16: - tb_writer.add_scalar( - "train/grad_scale", - cur_grad_scale, - params.batch_idx_train, - ) - - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: - logging.info("Computing validation loss") - valid_info = compute_validation_loss( - params=params, - model=model, - sp=sp, - valid_dl=valid_dl, - world_size=world_size, - ) - model.train() - log_mode = logging.info - log_mode(f"Epoch {params.cur_epoch}, validation: {valid_info}") - log_mode( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - if tb_writer is not None: - valid_info.write_summary( - tb_writer, "train/valid_", params.batch_idx_train - ) - - loss_value = tot_loss["loss"] / tot_loss["frames"] - params.train_loss = loss_value - if params.train_loss < params.best_train_loss: - params.best_train_epoch = params.cur_epoch - params.best_train_loss = params.train_loss - - -def run(rank, world_size, args): - """ - Args: - rank: - It is a value between 0 and `world_size-1`, which is - passed automatically by `mp.spawn()` in :func:`main`. - The node with rank 0 is responsible for saving checkpoint. - world_size: - Number of GPUs for DDP training. - args: - The return value of get_parser().parse_args() - """ - params = get_params() - params.update(vars(args)) - - fix_random_seed(params.seed) - if world_size > 1: - setup_dist(rank, world_size, master_port=params.master_port) - - setup_logger(f"{params.exp_dir}/log/log-train") - logging.info("Training started") - - if args.tensorboard and rank == 0: - tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") - else: - tb_writer = None - - device = torch.device("cpu") - if torch.cuda.is_available(): - device = torch.device("cuda", rank) - logging.info(f"Device: {device}") - - sp = Tokenizer.load(args.lang, args.lang_type) - - # is defined in local/prepare_lang_char.py - params.blank_id = sp.piece_to_id("") - params.vocab_size = sp.get_piece_size() - - logging.info(params) - - logging.info("About to create model") - model = get_transducer_model(params) - - num_param = sum([p.numel() for p in model.parameters()]) - logging.info(f"Number of model parameters: {num_param}") - - assert params.save_every_n >= params.average_period - model_avg: Optional[nn.Module] = None - if rank == 0: - # model_avg is only used with rank 0 - model_avg = copy.deepcopy(model).to(torch.float64) - - assert params.start_epoch > 0, params.start_epoch - checkpoints = load_checkpoint_if_available( - params=params, model=model, model_avg=model_avg - ) - - model.to(device) - if world_size > 1: - logging.info("Using DDP") - model = DDP(model, device_ids=[rank], find_unused_parameters=True) - - parameters_names = [] - parameters_names.append( - [name_param_pair[0] for name_param_pair in model.named_parameters()] - ) - optimizer = ScaledAdam( - model.parameters(), - lr=params.base_lr, - clipping_scale=2.0, - parameters_names=parameters_names, - ) - - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) - - if checkpoints and "optimizer" in checkpoints: - logging.info("Loading optimizer state dict") - optimizer.load_state_dict(checkpoints["optimizer"]) - - if ( - checkpoints - and "scheduler" in checkpoints - and checkpoints["scheduler"] is not None - ): - logging.info("Loading scheduler state dict") - scheduler.load_state_dict(checkpoints["scheduler"]) - - if params.print_diagnostics: - opts = diagnostics.TensorDiagnosticOptions( - 512 - ) # allow 4 megabytes per sub-module - diagnostic = diagnostics.attach_diagnostics(model, opts) - - if params.inf_check: - register_inf_check_hooks(model) - - def remove_short_and_long_utt(c: Cut): - # Keep only utterances with duration between 1 second and 20 seconds - # - # Caution: There is a reason to select 20.0 here. Please see - # ../local/display_manifest_statistics.py - # - # You should use ../local/display_manifest_statistics.py to get - # an utterance duration distribution for your dataset to select - # the threshold - if c.duration < 0.3 or c.duration > 30.0: - logging.debug( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) - return False - - # In pruned RNN-T, we require that T >= S - # where T is the number of feature frames after subsampling - # and S is the number of tokens in the utterance - - # In ./zipformer.py, the conv module uses the following expression - # for subsampling - T = ((c.num_frames - 7) // 2 + 1) // 2 - tokens = sp.encode(c.supervisions[0].text, out_type=str) - - if T < len(tokens): - logging.info( - f"Exclude cut with ID {c.id} from training. " - f"Number of frames (before subsampling): {c.num_frames}. " - f"Number of frames (after subsampling): {T}. " - f"Text: {c.supervisions[0].text}. " - f"Tokens: {tokens}. " - f"Number of tokens: {len(tokens)}" - ) - return False - - return True - - reazonspeech_corpus = ReazonSpeechAsrDataModule(args) - train_cuts = reazonspeech_corpus.train_cuts() - - train_cuts = train_cuts.filter(remove_short_and_long_utt) - - if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: - # We only load the sampler's state dict when it loads a checkpoint - # saved in the middle of an epoch - sampler_state_dict = checkpoints["sampler"] - else: - sampler_state_dict = None - - train_dl = reazonspeech_corpus.train_dataloaders( - train_cuts, sampler_state_dict=sampler_state_dict - ) - - valid_cuts = reazonspeech_corpus.valid_cuts() - valid_dl = reazonspeech_corpus.valid_dataloaders(valid_cuts) - - if params.start_batch <= 0 and not params.print_diagnostics: - scan_pessimistic_batches_for_oom( - model=model, - train_dl=train_dl, - optimizer=optimizer, - sp=sp, - params=params, - ) - - scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0) - if checkpoints and "grad_scaler" in checkpoints: - logging.info("Loading grad scaler state dict") - scaler.load_state_dict(checkpoints["grad_scaler"]) - - for epoch in range(params.start_epoch, params.num_epochs + 1): - scheduler.step_epoch(epoch - 1) - fix_random_seed(params.seed + epoch - 1) - train_dl.sampler.set_epoch(epoch - 1) - - if tb_writer is not None: - tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) - - params.cur_epoch = epoch - - train_one_epoch( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sp=sp, - train_dl=train_dl, - valid_dl=valid_dl, - scaler=scaler, - tb_writer=tb_writer, - world_size=world_size, - rank=rank, - ) - - if params.print_diagnostics: - diagnostic.print_diagnostics() - break - - save_checkpoint( - params=params, - model=model, - model_avg=model_avg, - optimizer=optimizer, - scheduler=scheduler, - sampler=train_dl.sampler, - scaler=scaler, - rank=rank, - ) - - logging.info("Done!") - - if world_size > 1: - torch.distributed.barrier() - cleanup_dist() - - -def display_and_save_batch( - batch: dict, - params: AttributeDict, - sp: Tokenizer, -) -> None: - """Display the batch statistics and save the batch into disk. - - Args: - batch: - A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` - for the content in it. - params: - Parameters for training. See :func:`get_params`. - sp: - The BPE model. - """ - from lhotse.utils import uuid4 - - filename = f"{params.exp_dir}/batch-{uuid4()}.pt" - logging.info(f"Saving batch to {filename}") - torch.save(batch, filename) - - supervisions = batch["supervisions"] - features = batch["inputs"] - - logging.info(f"features shape: {features.shape}") - - y = sp.encode(supervisions["text"], out_type=int) - num_tokens = sum(len(i) for i in y) - logging.info(f"num tokens: {num_tokens}") - - -def scan_pessimistic_batches_for_oom( - model: Union[nn.Module, DDP], - train_dl: torch.utils.data.DataLoader, - optimizer: torch.optim.Optimizer, - sp: Tokenizer, - params: AttributeDict, -): - from lhotse.dataset import find_pessimistic_batches - - logging.info( - "Sanity check -- see if any of the batches in epoch 1 would cause OOM." - ) - batches, crit_values = find_pessimistic_batches(train_dl.sampler) - for criterion, cuts in batches.items(): - batch = train_dl.dataset[cuts] - try: - with torch.cuda.amp.autocast(enabled=params.use_fp16): - loss, _ = compute_loss( - params=params, - model=model, - sp=sp, - batch=batch, - is_training=True, - ) - loss.backward() - optimizer.zero_grad() - except Exception as e: - if "CUDA out of memory" in str(e): - logging.error( - "Your GPU ran out of memory with the current " - "max_duration setting. We recommend decreasing " - "max_duration and trying again.\n" - f"Failing criterion: {criterion} " - f"(={crit_values[criterion]}) ..." - ) - display_and_save_batch(batch, params=params, sp=sp) - raise - logging.info( - f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" - ) - - -def main(): - parser = get_parser() - ReazonSpeechAsrDataModule.add_arguments(parser) - Tokenizer.add_arguments(parser) - args = parser.parse_args() - - world_size = args.world_size - assert world_size >= 1 - if world_size > 1: - mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) - else: - run(rank=0, world_size=1, args=args) - - -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - -if __name__ == "__main__": - main() diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py deleted file mode 120000 index ec183baa7..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer.py \ No newline at end of file diff --git a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py b/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py deleted file mode 120000 index d301e1f9b..000000000 --- a/egs/reazonspeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py +++ /dev/null @@ -1 +0,0 @@ -../../../librispeech/ASR/pruned_transducer_stateless7_streaming/zipformer_for_ncnn_export_only.py \ No newline at end of file