Zipformer recipe for ReazonSpeech (#1611)

* Add first cut at ReazonSpeech recipe

This recipe is mostly based on egs/csj, but tweaked to the point that it
can be run with the ReazonSpeech corpus.

Signed-off-by: Fujimoto Seiji <fujimoto@ceptord.net>

---------

Signed-off-by: Fujimoto Seiji <fujimoto@ceptord.net>
Co-authored-by: Fujimoto Seiji <fujimoto@ceptord.net>
Co-authored-by: Chen <qc@KDM00.cm.cluster>
Co-authored-by: root <root@KDA01.cm.cluster>
Triplecq 2024-06-13 02:19:03 -04:00 committed by GitHub
parent d5be739639
commit 3b40d9bbb1
37 changed files with 5488 additions and 0 deletions

egs/reazonspeech/ASR/README.md
@@ -0,0 +1,29 @@
# Introduction
**ReazonSpeech** is an open-source dataset that contains a diverse set of natural Japanese speech, collected from terrestrial television streams. It contains more than 35,000 hours of audio.
The dataset is available on Hugging Face. For more details, please visit:
- Dataset: https://huggingface.co/datasets/reazon-research/reazonspeech
- Paper: https://research.reazon.jp/_static/reazonspeech_nlp2023.pdf
[./RESULTS.md](./RESULTS.md) contains the latest results.
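For a quick look at the data before running the full recipe, the corpus can be streamed from Hugging Face. This is only a rough sketch: the configuration name (`tiny`) and the column names (`transcription`, `audio`) are assumptions based on the dataset card, not something this recipe depends on.
```python
# Sketch: stream a few ReazonSpeech samples from Hugging Face.
# The config name ("tiny") and column names are assumptions; check the dataset card.
from datasets import load_dataset

ds = load_dataset(
    "reazon-research/reazonspeech", "tiny", streaming=True, trust_remote_code=True
)
for i, sample in enumerate(ds["train"]):
    print(sample["transcription"], sample["audio"]["sampling_rate"])
    if i == 2:
        break
```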
# Transducers
This folder contains transducer-based recipes. The following table lists them and summarizes their differences.
| | Encoder | Decoder | Comment |
| ---------------------------------------- | -------------------- | ------------------ | ------------------------------------------------- |
| `zipformer` | Upgraded Zipformer | Embedding + Conv1d | The latest recipe |
The decoder in `zipformer` is stateless. It is modified from the paper [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/): we place an additional Conv1d layer right after the input embedding layer.
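The sketch below illustrates the idea of that stateless decoder (an embedding followed by a causal Conv1d over the previous tokens). It is only an illustration; the recipe itself reuses `zipformer/decoder.py` from the LibriSpeech setup.
```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Embedding + Conv1d prediction network (illustrative sketch only)."""

    def __init__(self, vocab_size: int, embed_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Conv1d over the last `context_size` predicted tokens, one group per channel.
        self.conv = nn.Conv1d(
            embed_dim, embed_dim, kernel_size=context_size, groups=embed_dim
        )
        self.context_size = context_size

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (N, U) token IDs -> (N, U, embed_dim)
        emb = self.embedding(y).permute(0, 2, 1)  # (N, C, U)
        # Left-pad so each position only depends on previous tokens (causal).
        emb = nn.functional.pad(emb, (self.context_size - 1, 0))
        return self.conv(emb).permute(0, 2, 1)
```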

egs/reazonspeech/ASR/RESULTS.md
@@ -0,0 +1,49 @@
## Results
### Zipformer
#### Non-streaming
##### large-scale model, number of model parameters: 159337842, i.e., 159.34 M
| decoding method | In-Distribution CER | JSUT | CommonVoice | TEDx | comment |
| :------------------: | :-----------------: | :--: | :---------: | :---: | :----------------: |
| greedy search | 4.2 | 6.7 | 7.84 | 17.9 | --epoch 39 --avg 7 |
| modified beam search | 4.13 | 6.77 | 7.69 | 17.82 | --epoch 39 --avg 7 |
The training command is:
```shell
./zipformer/train.py \
--world-size 8 \
--num-epochs 40 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-large \
--causal 0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--lang data/lang_char \
--max-duration 1600
```
The decoding command is:
```shell
./zipformer/decode.py \
--epoch 40 \
--avg 16 \
--exp-dir zipformer/exp-large \
--max-duration 600 \
--causal 0 \
--decoding-method greedy_search \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--lang data/lang_char \
--blank-penalty 0
```

egs/reazonspeech/ASR/local/compute_fbank_reazonspeech.py
@@ -0,0 +1,146 @@
#!/usr/bin/env python3
# Copyright 2023 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import os
from pathlib import Path
from typing import List, Tuple
import torch
# fmt: off
from lhotse import ( # See the following for why LilcomChunkyWriter is preferred; https://github.com/k2-fsa/icefall/pull/404; https://github.com/lhotse-speech/lhotse/pull/527
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
RecordingSet,
SupervisionSet,
)
# fmt: on
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slows things down.
# Do this outside of main() in case it needs to take effect
# even when we are not invoking the main (e.g. when spawning subprocesses).
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
RNG_SEED = 42
concat_params = {"gap": 1.0, "maxlen": 10.0}
def make_cutset_blueprints(
manifest_dir: Path,
) -> List[Tuple[str, CutSet]]:
cut_sets = []
# Create test dataset
logging.info("Creating test cuts.")
cut_sets.append(
(
"test",
CutSet.from_manifests(
recordings=RecordingSet.from_file(
manifest_dir / "reazonspeech_recordings_test.jsonl.gz"
),
supervisions=SupervisionSet.from_file(
manifest_dir / "reazonspeech_supervisions_test.jsonl.gz"
),
),
)
)
# Create dev dataset
logging.info("Creating dev cuts.")
cut_sets.append(
(
"dev",
CutSet.from_manifests(
recordings=RecordingSet.from_file(
manifest_dir / "reazonspeech_recordings_dev.jsonl.gz"
),
supervisions=SupervisionSet.from_file(
manifest_dir / "reazonspeech_supervisions_dev.jsonl.gz"
),
),
)
)
# Create train dataset
logging.info("Creating train cuts.")
cut_sets.append(
(
"train",
CutSet.from_manifests(
recordings=RecordingSet.from_file(
manifest_dir / "reazonspeech_recordings_train.jsonl.gz"
),
supervisions=SupervisionSet.from_file(
manifest_dir / "reazonspeech_supervisions_train.jsonl.gz"
),
),
)
)
return cut_sets
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("-m", "--manifest-dir", type=Path)
return parser.parse_args()
def main():
args = get_args()
extractor = Fbank(FbankConfig(num_mel_bins=80))
num_jobs = min(16, os.cpu_count())
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
if (args.manifest_dir / ".reazonspeech-fbank.done").exists():
logging.info(
"Previous fbank computed for ReazonSpeech found. "
f"Delete {args.manifest_dir / '.reazonspeech-fbank.done'} to allow recomputing fbank."
)
return
else:
cut_sets = make_cutset_blueprints(args.manifest_dir)
for part, cut_set in cut_sets:
logging.info(f"Processing {part}")
cut_set = cut_set.compute_and_store_features(
extractor=extractor,
num_jobs=num_jobs,
storage_path=(args.manifest_dir / f"feats_{part}").as_posix(),
storage_type=LilcomChunkyWriter,
)
cut_set.to_file(args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz")
logging.info("All fbank computed for ReazonSpeech.")
(args.manifest_dir / ".reazonspeech-fbank.done").touch()
if __name__ == "__main__":
main()
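Once this script has run, the generated manifests can be loaded back to sanity-check the features. A minimal sketch, assuming the default `data/manifests` layout used by `prepare.sh`:
```python
# Sketch: inspect one cut written by compute_fbank_reazonspeech.py.
from lhotse import load_manifest_lazy

cuts = load_manifest_lazy("data/manifests/reazonspeech_cuts_dev.jsonl.gz")
cut = next(iter(cuts))
feats = cut.load_features()  # numpy array of shape (num_frames, 80)
print(cut.id, cut.duration, feats.shape)
```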

egs/reazonspeech/ASR/local/display_manifest_statistics.py
@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2022 The University of Electro-Communications (author: Teo Wen Shen) # noqa
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from pathlib import Path
from lhotse import CutSet, load_manifest
ARGPARSE_DESCRIPTION = """
This file displays duration statistics of utterances in a manifest.
You can use the displayed value to choose minimum/maximum duration
to remove short and long utterances during training.
See the function `remove_short_and_long_utt()` in
zipformer/train.py for usage.
"""
def get_parser():
parser = argparse.ArgumentParser(
description=ARGPARSE_DESCRIPTION,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests")
return parser.parse_args()
def main():
args = get_parser()
for part in ["train", "dev"]:
path = args.manifest_dir / f"reazonspeech_cuts_{part}.jsonl.gz"
cuts: CutSet = load_manifest(path)
print("\n---------------------------------\n")
print(path.name + ":")
cuts.describe()
if __name__ == "__main__":
main()
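The printed statistics are typically used to pick duration cutoffs for training. A hedged sketch of such a filter (the 1 s / 30 s thresholds below are placeholders, not values taken from this recipe):
```python
# Sketch: drop very short and very long utterances based on the displayed statistics.
from lhotse import CutSet, load_manifest_lazy


def remove_short_and_long_utt(c) -> bool:
    # Placeholder thresholds; tune them using the statistics printed above.
    return 1.0 <= c.duration <= 30.0


cuts: CutSet = load_manifest_lazy("data/manifests/reazonspeech_cuts_train.jsonl.gz")
cuts = cuts.filter(remove_short_and_long_utt)
```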

egs/reazonspeech/ASR/local/prepare_lang_char.py
@@ -0,0 +1,75 @@
#!/usr/bin/env python3
# Copyright 2022 The University of Electro-Communications (Author: Teo Wen Shen) # noqa
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from lhotse import CutSet
def get_args():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"train_cut", metavar="train-cut", type=Path, help="Path to the train cut"
)
parser.add_argument(
"--lang-dir",
type=Path,
default=Path("data/lang_char"),
help=(
"Name of lang dir. "
"If not set, this will default to lang_char_{trans-mode}"
),
)
return parser.parse_args()
def main():
args = get_args()
logging.basicConfig(
format=("%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"),
level=logging.INFO,
)
sysdef_string = set(["<blk>", "<unk>", "<sos/eos>", " "])
token_set = set()
logging.info(f"Creating vocabulary from {args.train_cut}.")
train_cut: CutSet = CutSet.from_file(args.train_cut)
for cut in train_cut:
for sup in cut.supervisions:
token_set.update(sup.text)
token_set = ["<blk>"] + sorted(token_set - sysdef_string) + ["<unk>", "<sos/eos>"]
args.lang_dir.mkdir(parents=True, exist_ok=True)
(args.lang_dir / "tokens.txt").write_text(
"\n".join(f"{t}\t{i}" for i, t in enumerate(token_set))
)
(args.lang_dir / "lang_type").write_text("char")
logging.info("Done.")
if __name__ == "__main__":
main()
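The `lang_char` directory written here is consumed by the recipe's `Tokenizer` (`local/utils/tokenizer.py`, included later in this diff). A minimal usage sketch, assuming the default `data/lang_char` path:
```python
# Sketch: load the generated lang dir and round-trip a sentence at character level.
from pathlib import Path

from tokenizer import Tokenizer  # local/utils/tokenizer.py in this recipe

sp = Tokenizer.load(Path("data/lang_char"), "char")
ids = sp.encode("こんにちは", out_type=int)  # list of token IDs
print(ids)
print(sp.decode(ids))  # "こんにちは"
```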

egs/reazonspeech/ASR/local/utils/asr_datamodule.py
@@ -0,0 +1,355 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import inspect
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, List, Optional
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
from lhotse.dataset import (
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class ReazonSpeechAsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/manifests"),
help="Path to directory with train/dev/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=200.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--num-buckets",
type=int,
default=30,
help="The number of buckets for the DynamicBucketingSampler"
"(you might want to increase it for larger datasets).",
)
group.add_argument(
"--concatenate-cuts",
type=str2bool,
default=False,
help="When enabled, utterances (cuts) will be concatenated "
"to minimize the amount of padding.",
)
group.add_argument(
"--duration-factor",
type=float,
default=1.0,
help="Determines the maximum duration of a concatenated cut "
"relative to the duration of the longest cut in a batch.",
)
group.add_argument(
"--gap",
type=float,
default=1.0,
help="The amount of padding (in seconds) inserted between "
"concatenated cuts. This padding is filled with noise when "
"noise augmentation is used.",
)
group.add_argument(
"--on-the-fly-feats",
type=str2bool,
default=False,
help="When enabled, use on-the-fly cut mixing and feature "
"extraction. Will drop existing precomputed feature manifests "
"if available.",
)
group.add_argument(
"--shuffle",
type=str2bool,
default=True,
help="When enabled (=default), the examples will be "
"shuffled for each epoch.",
)
group.add_argument(
"--drop-last",
type=str2bool,
default=True,
help="Whether to drop last batch. Used by sampler.",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=False,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=2,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--enable-spec-aug",
type=str2bool,
default=True,
help="When enabled, use SpecAugment for training dataset.",
)
group.add_argument(
"--spec-aug-time-warp-factor",
type=int,
default=80,
help="Used only when --enable-spec-aug is True. "
"It specifies the factor for time warping in SpecAugment. "
"Larger values mean more warping. "
"A value less than 1 means to disable time warp.",
)
group.add_argument(
"--enable-musan",
type=str2bool,
default=False,
help="When enabled, select noise from MUSAN and mix it"
"with training dataset. ",
)
def train_dataloaders(
self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
) -> DataLoader:
"""
Args:
cuts_train:
CutSet for training.
sampler_state_dict:
The state dict for the training sampler.
"""
transforms = []
input_transforms = []
if self.args.enable_spec_aug:
logging.info("Enable SpecAugment")
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
# Set the value of num_frame_masks according to Lhotse's version.
# In different Lhotse's versions, the default of num_frame_masks is
# different.
num_frame_masks = 10
num_frame_masks_parameter = inspect.signature(
SpecAugment.__init__
).parameters["num_frame_masks"]
if num_frame_masks_parameter.default == 1:
num_frame_masks = 2
logging.info(f"Num frame mask: {num_frame_masks}")
input_transforms.append(
SpecAugment(
time_warp_factor=self.args.spec_aug_time_warp_factor,
num_frame_masks=num_frame_masks,
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
)
)
else:
logging.info("Disable SpecAugment")
logging.info("About to create train dataset")
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.on_the_fly_feats:
# NOTE: the PerturbSpeed transform should be added only if we
# remove it from data prep stage.
# Add on-the-fly speed perturbation; since originally it would
# have increased epoch size by 3, we will apply prob 2/3 and use
# 3x more epochs.
# Speed perturbation probably should come first before
# concatenation, but in principle the transforms order doesn't have
# to be strict (e.g. could be randomized)
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
# Drop feats to be on the safe side.
train = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
input_transforms=input_transforms,
return_cuts=self.args.return_cuts,
)
if self.args.bucketing_sampler:
logging.info("Using DynamicBucketingSampler.")
train_sampler = DynamicBucketingSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
num_buckets=self.args.num_buckets,
drop_last=self.args.drop_last,
)
else:
logging.info("Using SimpleCutSampler.")
train_sampler = SimpleCutSampler(
cuts_train,
max_duration=self.args.max_duration,
shuffle=self.args.shuffle,
)
logging.info("About to create train dataloader")
if sampler_state_dict is not None:
logging.info("Loading sampler state dict")
train_sampler.load_state_dict(sampler_state_dict)
train_dl = DataLoader(
train,
sampler=train_sampler,
batch_size=None,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return train_dl
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
transforms = []
if self.args.concatenate_cuts:
transforms = [
CutConcatenate(
duration_factor=self.args.duration_factor, gap=self.args.gap
)
] + transforms
logging.info("About to create dev dataset")
if self.args.on_the_fly_feats:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
else:
validate = K2SpeechRecognitionDataset(
cut_transforms=transforms,
return_cuts=self.args.return_cuts,
)
valid_sampler = DynamicBucketingSampler(
cuts_valid,
max_duration=self.args.max_duration,
shuffle=False,
)
logging.info("About to create dev dataloader")
valid_dl = DataLoader(
validate,
sampler=valid_sampler,
batch_size=None,
num_workers=2,
persistent_workers=False,
)
return valid_dl
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.info("About to create test dataset")
test = K2SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else PrecomputedFeatures(),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
)
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
)
return test_dl
@lru_cache()
def train_cuts(self) -> CutSet:
logging.info("About to get train cuts")
return load_manifest_lazy(
self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
)
@lru_cache()
def valid_cuts(self) -> CutSet:
logging.info("About to get dev cuts")
return load_manifest_lazy(
self.args.manifest_dir / "reazonspeech_cuts_dev.jsonl.gz"
)
@lru_cache()
def test_cuts(self) -> CutSet:
logging.info("About to get test cuts")
return load_manifest_lazy(
self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
)
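A minimal sketch of how this data module is driven from a training or decoding script (the argument values are illustrative, not the recipe's defaults):
```python
# Sketch: build a training dataloader from ReazonSpeechAsrDataModule.
import argparse

from asr_datamodule import ReazonSpeechAsrDataModule

parser = argparse.ArgumentParser()
ReazonSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args(["--manifest-dir", "data/manifests", "--max-duration", "300"])

datamodule = ReazonSpeechAsrDataModule(args)
train_dl = datamodule.train_dataloaders(datamodule.train_cuts())

for batch in train_dl:
    feats = batch["inputs"]  # (N, T, 80) padded fbank features
    sups = batch["supervisions"]  # texts plus start/num frames per utterance
    break
```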

egs/reazonspeech/ASR/local/utils/tokenizer.py
@@ -0,0 +1,253 @@
import argparse
from pathlib import Path
from typing import Callable, List, Union
import sentencepiece as spm
from k2 import SymbolTable
class Tokenizer:
text2word: Callable[[str], List[str]]
@staticmethod
def add_arguments(parser: argparse.ArgumentParser):
group = parser.add_argument_group(title="Lang related options")
group.add_argument("--lang", type=Path, help="Path to lang directory.")
group.add_argument(
"--lang-type",
type=str,
default=None,
help=(
"Either 'bpe' or 'char'. If not provided, it expects lang_dir/lang_type to exists. "
"Note: 'bpe' directly loads sentencepiece.SentencePieceProcessor"
),
)
@staticmethod
def Load(lang_dir: Path, lang_type="", oov="<unk>"):
if not lang_type:
assert (lang_dir / "lang_type").exists(), "lang_type not specified."
lang_type = (lang_dir / "lang_type").read_text().strip()
tokenizer = None
if lang_type == "bpe":
assert (
lang_dir / "bpe.model"
).exists(), f"No BPE .model could be found in {lang_dir}."
tokenizer = spm.SentencePieceProcessor()
tokenizer.Load(str(lang_dir / "bpe.model"))
elif lang_type == "char":
tokenizer = CharTokenizer(lang_dir, oov=oov)
else:
raise NotImplementedError(f"{lang_type} not supported at the moment.")
return tokenizer
load = Load
def PieceToId(self, piece: str) -> int:
raise NotImplementedError(
"You need to implement this function in the child class."
)
piece_to_id = PieceToId
def IdToPiece(self, id: int) -> str:
raise NotImplementedError(
"You need to implement this function in the child class."
)
id_to_piece = IdToPiece
def GetPieceSize(self) -> int:
raise NotImplementedError(
"You need to implement this function in the child class."
)
get_piece_size = GetPieceSize
def __len__(self) -> int:
return self.get_piece_size()
def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def EncodeAsIds(self, input: str) -> List[int]:
return self.EncodeAsIdsBatch([input])[0]
def EncodeAsPieces(self, input: str) -> List[str]:
return self.EncodeAsPiecesBatch([input])[0]
def Encode(
self, input: Union[str, List[str]], out_type=int
) -> Union[List, List[List]]:
if not input:
return []
if isinstance(input, list):
if out_type is int:
return self.EncodeAsIdsBatch(input)
if out_type is str:
return self.EncodeAsPiecesBatch(input)
if out_type is int:
return self.EncodeAsIds(input)
if out_type is str:
return self.EncodeAsPieces(input)
encode = Encode
def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def DecodeIds(self, input: List[int]) -> str:
return self.DecodeIdsBatch([input])[0]
def DecodePieces(self, input: List[str]) -> str:
return self.DecodePiecesBatch([input])[0]
def Decode(
self,
input: Union[int, List[int], List[str], List[List[int]], List[List[str]]],
) -> Union[List[str], str]:
if not input:
return ""
if isinstance(input, int):
return self.id_to_piece(input)
elif isinstance(input, str):
raise TypeError(
"Unlike spm.SentencePieceProcessor, cannot decode from type str."
)
if isinstance(input[0], list):
if not input[0] or isinstance(input[0][0], int):
return self.DecodeIdsBatch(input)
if isinstance(input[0][0], str):
return self.DecodePiecesBatch(input)
if isinstance(input[0], int):
return self.DecodeIds(input)
if isinstance(input[0], str):
return self.DecodePieces(input)
raise RuntimeError("Unknown input type")
decode = Decode
def SplitBatch(self, input: List[str]) -> List[List[str]]:
raise NotImplementedError(
"You need to implement this function in the child class."
)
def Split(self, input: Union[List[str], str]) -> Union[List[List[str]], List[str]]:
if isinstance(input, list):
return self.SplitBatch(input)
elif isinstance(input, str):
return self.SplitBatch([input])[0]
raise RuntimeError("Unknown input type")
split = Split
class CharTokenizer(Tokenizer):
def __init__(self, lang_dir: Path, oov="<unk>", sep=""):
assert (
lang_dir / "tokens.txt"
).exists(), f"tokens.txt could not be found in {lang_dir}."
token_table = SymbolTable.from_file(lang_dir / "tokens.txt")
assert (
"#0" not in token_table
), "This tokenizer does not support disambig symbols."
self._id2sym = token_table._id2sym
self._sym2id = token_table._sym2id
self.oov = oov
self.oov_id = self._sym2id[oov]
self.sep = sep
if self.sep:
self.text2word = lambda x: x.split(self.sep)
else:
self.text2word = lambda x: list(x.replace(" ", ""))
def piece_to_id(self, piece: str) -> int:
try:
return self._sym2id[piece]
except KeyError:
return self.oov_id
def id_to_piece(self, id: int) -> str:
return self._id2sym[id]
def get_piece_size(self) -> int:
return len(self._sym2id)
def EncodeAsIdsBatch(self, input: List[str]) -> List[List[int]]:
return [[self.piece_to_id(i) for i in self.text2word(text)] for text in input]
def EncodeAsPiecesBatch(self, input: List[str]) -> List[List[str]]:
return [
[i if i in self._sym2id else self.oov for i in self.text2word(text)]
for text in input
]
def DecodeIdsBatch(self, input: List[List[int]]) -> List[str]:
return [self.sep.join(self.id_to_piece(i) for i in text) for text in input]
def DecodePiecesBatch(self, input: List[List[str]]) -> List[str]:
return [self.sep.join(text) for text in input]
def SplitBatch(self, input: List[str]) -> List[List[str]]:
return [self.text2word(text) for text in input]
def test_CharTokenizer():
test_single_string = "こんにちは"
test_multiple_string = [
"今日はいい天気ですよね",
"諏訪湖は綺麗でしょう",
"这在词表外",
"分かち 書き に し た 文章 です",
"",
]
test_empty_string = ""
sp = Tokenizer.load(Path("lang_char"), "char", oov="<unk>")
splitter = sp.split
print(sp.encode(test_single_string, out_type=str))
print(sp.encode(test_single_string, out_type=int))
print(sp.encode(test_multiple_string, out_type=str))
print(sp.encode(test_multiple_string, out_type=int))
print(sp.encode(test_empty_string, out_type=str))
print(sp.encode(test_empty_string, out_type=int))
print(sp.decode(sp.encode(test_single_string, out_type=str)))
print(sp.decode(sp.encode(test_single_string, out_type=int)))
print(sp.decode(sp.encode(test_multiple_string, out_type=str)))
print(sp.decode(sp.encode(test_multiple_string, out_type=int)))
print(sp.decode(sp.encode(test_empty_string, out_type=str)))
print(sp.decode(sp.encode(test_empty_string, out_type=int)))
print(splitter(test_single_string))
print(splitter(test_multiple_string))
print(splitter(test_empty_string))
if __name__ == "__main__":
test_CharTokenizer()

egs/reazonspeech/ASR/local/validate_manifest.py
@@ -0,0 +1,96 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script checks the following assumptions of the generated manifest:
- Single supervision per cut
- Supervision time bounds are within cut time bounds
We will add more checks later if needed.
Usage example:
python3 ./local/validate_manifest.py \
--manifest data/manifests/reazonspeech_cuts_train.jsonl.gz
"""
import argparse
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifest",
type=Path,
help="Path to the manifest file",
)
return parser.parse_args()
def validate_one_supervision_per_cut(c: Cut):
if len(c.supervisions) != 1:
raise ValueError(f"{c.id} has {len(c.supervisions)} supervisions")
def validate_supervision_and_cut_time_bounds(c: Cut):
s = c.supervisions[0]
# Removed because when the cuts were trimmed from supervisions,
# the start time of the supervision can be less than the cut start time.
# https://github.com/lhotse-speech/lhotse/issues/813
# if s.start < c.start:
# raise ValueError(
# f"{c.id}: Supervision start time {s.start} is less "
# f"than cut start time {c.start}"
# )
if s.end > c.end:
raise ValueError(
f"{c.id}: Supervision end time {s.end} is larger "
f"than cut end time {c.end}"
)
def main():
args = get_args()
manifest = Path(args.manifest)
logging.info(f"Validating {manifest}")
assert manifest.is_file(), f"{manifest} does not exist"
cut_set = load_manifest(manifest)
assert isinstance(cut_set, CutSet)
for c in cut_set:
validate_one_supervision_per_cut(c)
validate_supervision_and_cut_time_bounds(c)
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

egs/reazonspeech/ASR/prepare.sh Executable file
@@ -0,0 +1,86 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
set -eou pipefail
nj=15
stage=-1
stop_stage=100
# We assume dl_dir (download dir) contains the following
# directories and files. If not, they will be downloaded
# by this script automatically.
#
# - $dl_dir/ReazonSpeech
# You can find FLAC files in this directory.
# You can download them from https://huggingface.co/datasets/reazon-research/reazonspeech
#
# - $dl_dir/dataset.json
# The metadata of the ReazonSpeech dataset.
dl_dir=$PWD/download
. shared/parse_options.sh || exit 1
# All files generated by this script are saved in "data".
# You can safely remove "data" and rerun this script to regenerate it.
mkdir -p data
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
log "Running prepare.sh"
log "dl_dir: $dl_dir"
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
log "Stage 0: Download data"
# If you have pre-downloaded it to /path/to/ReazonSpeech,
# you can create a symlink
#
# ln -sfv /path/to/ReazonSpeech $dl_dir/ReazonSpeech
#
if [ ! -d $dl_dir/ReazonSpeech/downloads ]; then
# Download small-v1 by default.
lhotse download reazonspeech --subset small-v1 $dl_dir
fi
fi
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
log "Stage 1: Prepare ReazonSpeech manifest"
# We assume that you have downloaded the ReazonSpeech corpus
# to $dl_dir/ReazonSpeech
mkdir -p data/manifests
if [ ! -e data/manifests/.reazonspeech.done ]; then
lhotse prepare reazonspeech -j $nj $dl_dir/ReazonSpeech data/manifests
touch data/manifests/.reazonspeech.done
fi
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
log "Stage 2: Compute ReazonSpeech fbank"
if [ ! -e data/manifests/.reazonspeech-validated.done ]; then
python local/compute_fbank_reazonspeech.py --manifest-dir data/manifests
python local/validate_manifest.py --manifest data/manifests/reazonspeech_cuts_train.jsonl.gz
python local/validate_manifest.py --manifest data/manifests/reazonspeech_cuts_dev.jsonl.gz
python local/validate_manifest.py --manifest data/manifests/reazonspeech_cuts_test.jsonl.gz
touch data/manifests/.reazonspeech-validated.done
fi
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
log "Stage 3: Prepare ReazonSpeech lang_char"
python local/prepare_lang_char.py data/manifests/reazonspeech_cuts_train.jsonl.gz
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
log "Stage 4: Show manifest statistics"
python local/display_manifest_statistics.py --manifest-dir data/manifests > data/manifests/manifest_statistics.txt
cat data/manifests/manifest_statistics.txt
fi

egs/reazonspeech/ASR/shared Symbolic link
@@ -0,0 +1 @@
../../../icefall/shared/

egs/reazonspeech/ASR/zipformer/asr_datamodule.py Symbolic link
@@ -0,0 +1 @@
../local/utils/asr_datamodule.py

egs/reazonspeech/ASR/zipformer/beam_search.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/beam_search.py

egs/reazonspeech/ASR/zipformer/ctc_decode.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/ctc_decode.py

File diff suppressed because it is too large.

egs/reazonspeech/ASR/zipformer/decode_stream.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decode_stream.py

egs/reazonspeech/ASR/zipformer/decoder.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/decoder.py

File diff suppressed because it is too large.

egs/reazonspeech/ASR/zipformer/encoder_interface.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/encoder_interface.py

egs/reazonspeech/ASR/zipformer/export-onnx.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export-onnx.py

egs/reazonspeech/ASR/zipformer/export.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/export.py

egs/reazonspeech/ASR/zipformer/generate_averaged_model.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/generate_averaged_model.py

egs/reazonspeech/ASR/zipformer/joiner.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/joiner.py

egs/reazonspeech/ASR/zipformer/model.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/model.py

egs/reazonspeech/ASR/zipformer/my_profile.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/my_profile.py

egs/reazonspeech/ASR/zipformer/onnx_pretrained.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/onnx_pretrained.py

egs/reazonspeech/ASR/zipformer/optim.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/optim.py

egs/reazonspeech/ASR/zipformer/pretrained.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/pretrained.py

egs/reazonspeech/ASR/zipformer/scaling.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling.py

egs/reazonspeech/ASR/zipformer/scaling_converter.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/scaling_converter.py

egs/reazonspeech/ASR/zipformer/streaming_beam_search.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/streaming_beam_search.py

egs/reazonspeech/ASR/zipformer/streaming_decode.py
@@ -0,0 +1,597 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corporation (Authors: Wei Kang, Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
./zipformer/streaming_decode.py \
--epoch 28 \
--avg 15 \
--decode-chunk-len 32 \
--exp-dir ./zipformer/exp \
--decoding-method greedy_search \
--lang data/lang_char \
--num-decode-streams 2000
"""
import argparse
import logging
import math
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import numpy as np
import torch
import torch.nn as nn
from asr_datamodule import ReazonSpeechAsrDataModule
from decode import save_results
from decode_stream import DecodeStream
from kaldifeat import Fbank, FbankOptions
from lhotse import CutSet
from streaming_beam_search import (
fast_beam_search_one_best,
greedy_search,
modified_beam_search,
)
from tokenizer import Tokenizer
from torch.nn.utils.rnn import pad_sequence
from train import add_model_arguments, get_params, get_transducer_model
from zipformer import stack_states, unstack_states
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import AttributeDict, setup_logger, str2bool
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--gpu",
type=int,
default=0,
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`. "
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless2/exp",
help="The experiment dir",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Supported decoding methods are:
greedy_search
modified_beam_search
fast_beam_search
""",
)
parser.add_argument(
"--decoding-graph",
type=str,
default="",
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--num_active_paths",
type=int,
default=4,
help="""An interger indicating how many candidates we will keep for each
frame. Used only when --decoding-method is modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4.0,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=32,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
)
parser.add_argument(
"--num-decode-streams",
type=int,
default=2000,
help="The number of streams that can be decoded parallel.",
)
parser.add_argument(
"--res-dir",
type=Path,
default=None,
help="The path to save results.",
)
add_model_arguments(parser)
return parser
def decode_one_chunk(
params: AttributeDict,
model: nn.Module,
decode_streams: List[DecodeStream],
) -> List[int]:
"""Decode one chunk frames of features for each decode_streams and
return the indexes of finished streams in a List.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
decode_streams:
A List of DecodeStream, each belonging to a utterance.
Returns:
Return a List containing which DecodeStreams are finished.
"""
device = model.device
features = []
feature_lens = []
states = []
processed_lens = []
for stream in decode_streams:
feat, feat_len = stream.get_feature_frames(params.decode_chunk_len)
features.append(feat)
feature_lens.append(feat_len)
states.append(stream.states)
processed_lens.append(stream.done_frames)
feature_lens = torch.tensor(feature_lens, device=device)
features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS)
# We subsample features with ((x_len - 7) // 2 + 1) // 2 and the max downsampling
# factor in encoders is 8.
# After feature embedding (x_len - 7) // 2, we have (23 - 7) // 2 = 8.
tail_length = 23
if features.size(1) < tail_length:
pad_length = tail_length - features.size(1)
feature_lens += pad_length
features = torch.nn.functional.pad(
features,
(0, 0, 0, pad_length),
mode="constant",
value=LOG_EPS,
)
states = stack_states(states)
processed_lens = torch.tensor(processed_lens, device=device)
encoder_out, encoder_out_lens, new_states = model.encoder.streaming_forward(
x=features,
x_lens=feature_lens,
states=states,
)
encoder_out = model.joiner.encoder_proj(encoder_out)
if params.decoding_method == "greedy_search":
greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams)
elif params.decoding_method == "fast_beam_search":
processed_lens = processed_lens + encoder_out_lens
fast_beam_search_one_best(
model=model,
encoder_out=encoder_out,
processed_lens=processed_lens,
streams=decode_streams,
beam=params.beam,
max_states=params.max_states,
max_contexts=params.max_contexts,
)
elif params.decoding_method == "modified_beam_search":
modified_beam_search(
model=model,
streams=decode_streams,
encoder_out=encoder_out,
num_active_paths=params.num_active_paths,
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
states = unstack_states(new_states)
finished_streams = []
for i in range(len(decode_streams)):
decode_streams[i].states = states[i]
decode_streams[i].done_frames += encoder_out_lens[i]
if decode_streams[i].done:
finished_streams.append(i)
return finished_streams
def decode_dataset(
cuts: CutSet,
params: AttributeDict,
model: nn.Module,
sp: Tokenizer,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
cuts:
Lhotse Cutset containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The tokenizer.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains three elements:
the cut ID, the reference transcript, and the predicted result.
"""
device = model.device
opts = FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = 16000
opts.mel_opts.num_bins = 80
log_interval = 50
decode_results = []
# Contain decode streams currently running.
decode_streams = []
for num, cut in enumerate(cuts):
# each utterance has a DecodeStream.
initial_states = model.encoder.get_init_state(device=device)
decode_stream = DecodeStream(
params=params,
cut_id=cut.id,
initial_states=initial_states,
decoding_graph=decoding_graph,
device=device,
)
audio: np.ndarray = cut.load_audio()
# audio.shape: (1, num_samples)
assert len(audio.shape) == 2
assert audio.shape[0] == 1, "Should be single channel"
assert audio.dtype == np.float32, audio.dtype
# The trained model is using normalized samples
assert audio.max() <= 1, "Should be normalized to [-1, 1])"
samples = torch.from_numpy(audio).squeeze(0)
fbank = Fbank(opts)
feature = fbank(samples.to(device))
decode_stream.set_features(feature, tail_pad_len=params.decode_chunk_len)
decode_stream.ground_truth = cut.supervisions[0].text
decode_streams.append(decode_stream)
while len(decode_streams) >= params.num_decode_streams:
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
sp.text2word(decode_streams[i].ground_truth),
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
)
)
del decode_streams[i]
if num % log_interval == 0:
logging.info(f"Cuts processed until now is {num}.")
# decode final chunks of last sequences
while len(decode_streams):
finished_streams = decode_one_chunk(
params=params, model=model, decode_streams=decode_streams
)
for i in sorted(finished_streams, reverse=True):
decode_results.append(
(
decode_streams[i].id,
sp.text2word(decode_streams[i].ground_truth),
sp.text2word(sp.decode(decode_streams[i].decoding_result())),
)
)
del decode_streams[i]
if params.decoding_method == "greedy_search":
key = "greedy_search"
elif params.decoding_method == "fast_beam_search":
key = (
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
)
elif params.decoding_method == "modified_beam_search":
key = f"num_active_paths_{params.num_active_paths}"
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
return {key: decode_results}
@torch.no_grad()
def main():
parser = get_parser()
ReazonSpeechAsrDataModule.add_arguments(parser)
Tokenizer.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
if not params.res_dir:
params.res_dir = params.exp_dir / "streaming" / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
# for streaming
params.suffix += f"-streaming-chunk-size-{params.decode_chunk_len}"
# for fast_beam_search
if params.decoding_method == "fast_beam_search":
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", params.gpu)
logging.info(f"Device: {device}")
sp = Tokenizer.load(params.lang, params.lang_type)
# <blk> and <unk> is defined in local/prepare_lang_char.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
logging.info("About to create model")
model = get_transducer_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
model.device = device
decoding_graph = None
if params.decoding_graph:
decoding_graph = k2.Fsa.from_dict(
torch.load(params.decoding_graph, map_location=device)
)
elif params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
args.return_cuts = True
reazonspeech_corpus = ReazonSpeechAsrDataModule(args)
for subdir in ["valid"]:
results_dict = decode_dataset(
cuts=getattr(reazonspeech_corpus, f"{subdir}_cuts")(),
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
tot_err = save_results(
params=params, test_set_name=subdir, results_dict=results_dict
)
with (
params.res_dir
/ (
f"{subdir}-{params.decode_chunk_len}"
f"_{params.avg}_{params.epoch}.cer"
)
).open("w") as fout:
if len(tot_err) == 1:
fout.write(f"{tot_err[0][1]}")
else:
fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
logging.info("Done!")
if __name__ == "__main__":
main()

egs/reazonspeech/ASR/zipformer/subsampling.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/subsampling.py

egs/reazonspeech/ASR/zipformer/test_scaling.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/test_scaling.py

egs/reazonspeech/ASR/zipformer/test_subsampling.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/test_subsampling.py

egs/reazonspeech/ASR/zipformer/tokenizer.py Symbolic link
@@ -0,0 +1 @@
../local/utils/tokenizer.py

File diff suppressed because it is too large.

egs/reazonspeech/ASR/zipformer/zipformer.py Symbolic link
@@ -0,0 +1 @@
../../../librispeech/ASR/zipformer/zipformer.py