New recipe: tiny_transducer_ctc (#848)

* initial commit
* update readme
* Update README.md
* change bool to str2bool for arg parser
* run validation only at the end of epoch
* black format

parent 161ab90dfb, commit c970df512b

egs/librispeech/ASR/tiny_transducer_ctc/README.md (new file, 184 lines)

## Introduction

This recipe is intended for streaming ASR on very low-cost devices, with model sizes in the range of 1-2M parameters. It uses a small convolutional network as the encoder and is trained with combined transducer and CTC losses, supporting both phone and BPE lexicons. With the phone lexicon you can also run transducer decoding with LG-based methods (e.g. fast_beam_search_LG in decode.py), but the results were poor.
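
The relative weight of the two losses is set by the `--ctc-loss-scale` flag used in the training commands below. As a rough sketch only (the exact formula is defined in `train.py` and may differ), the combination looks roughly like this:

```python
import torch


def combined_loss(
    transducer_loss: torch.Tensor,
    ctc_loss: torch.Tensor,
    ctc_loss_scale: float = 0.7,  # --ctc-loss-scale: 0.7 for the phone runs, 0.2 for the BPE runs below
) -> torch.Tensor:
    # One plausible weighting: a convex combination of the CTC and transducer
    # losses. See train.py in this directory for the actual objective.
    return ctc_loss_scale * ctc_loss + (1.0 - ctc_loss_scale) * transducer_loss
```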

The encoder consists of 2 subsampling layers followed by a stack of Conv1d-batchnorm-activation-causal_squeeze_excite blocks, with optional skip connections. To reduce latency (at the cost of slightly higher WER), half of the blocks use causal convolution.

A few remarks & observations:

1. The phone lexicon works better than BPE for CTC decoding (with HLG) but worse for transducer decoding.

2. SpecAugment is not helpful for very small models, as they tend to underfit rather than overfit. For the large model, a less aggressive SpecAugment (see asr_datamodule.py) improved the result a little.

3. Squeeze-and-excitation worked like a charm! It reduces WER quite a bit with a marginal increase in parameters and MAC ops. To make it causal, I changed the global average pooling layer to a moving-average filter, so only historical context is used (see the sketch below).
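
As a minimal sketch of the idea (not the recipe's actual encoder code), a causal squeeze-and-excitation block can replace global average pooling with a left-padded moving average, so the channel gate at frame `t` only depends on frames up to `t`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalSqueezeExcite(nn.Module):
    """Sketch of a causal squeeze-and-excitation gate for (N, C, T) features.

    The usual global average pooling is replaced by a moving average over the
    last `context` frames, so no future frames are used. Illustrative only;
    the recipe's real encoder may differ.
    """

    def __init__(self, channels: int, reduction: int = 8, context: int = 16):
        super().__init__()
        self.context = context
        self.fc1 = nn.Linear(channels, channels // reduction)
        self.fc2 = nn.Linear(channels // reduction, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, C, T). Left-pad so the average at frame t sees frames <= t.
        padded = F.pad(x, (self.context - 1, 0))
        avg = F.avg_pool1d(padded, kernel_size=self.context, stride=1)  # (N, C, T)
        s = F.relu(self.fc1(avg.transpose(1, 2)))        # (N, T, C // reduction)
        s = torch.sigmoid(self.fc2(s)).transpose(1, 2)   # (N, C, T) channel gate
        return x * s
```

A longer averaging window behaves more like the usual global pooling while keeping the block streamable, which is what makes the squeeze-and-excitation trick usable in this low-latency setting.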

## Pretrained models

You can find pretrained models, training logs, decoding logs, and decoding results at:
<https://huggingface.co/wangtiance/tiny_transducer_ctc/tree/main>

## Results on full libri

I tried 3 different sizes of the encoder, with around 1M, 2M and 4M parameters, respectively. For CTC decoding, whole-lattice-rescoring frequently causes OOM errors, so its results are not shown.

### Small encoder

The small encoder uses 10 1D convolution blocks with 256 channels, without skip connections. The encoder, decoder and joiner dims are 256. Algorithmic latency is 280 ms. Multiply-add ops for the encoder are 22.0 Mops. It is best suited to ASR products with a limited vocabulary (like a fixed set of phrases or short sentences).

#### CTC decoding with phone lexicon

Total parameters: 1073392

Parameters for CTC decoding: 865816

|                 | test-clean | test-other | comment            |
|-----------------|------------|------------|--------------------|
| 1best           | 9.68       | 24.9       | --epoch 30 --avg 2 |
| nbest-rescoring | 8.2        | 22.7       | --epoch 30 --avg 2 |

The training command is:

```bash
./tiny_transducer_ctc/train.py \
  --num-epochs 30 \
  --full-libri 1 \
  --max-duration 600 \
  --exp-dir tiny_transducer_ctc/exp_small_phone \
  --ctc-loss-scale 0.7 \
  --enable-spec-aug 0 \
  --lang-dir lang_phone \
  --encoder-dim 256 \
  --decoder-dim 256 \
  --joiner-dim 256 \
  --conv-layers 10 \
  --channels 256 \
  --skip-add 0
```

#### Transducer decoding with BPE 500 lexicon

Total parameters: 1623264

Parameters for transducer decoding: 1237764

|                      | test-clean | test-other | comment            |
|----------------------|------------|------------|--------------------|
| greedy_search        | 14.47      | 32.03      | --epoch 30 --avg 1 |
| fast_beam_search     | 13.38      | 29.61      | --epoch 30 --avg 1 |
| modified_beam_search | 13.02      | 29.32      | --epoch 30 --avg 1 |

The training command is:

```bash
./tiny_transducer_ctc/train.py \
  --num-epochs 30 \
  --full-libri 1 \
  --max-duration 600 \
  --exp-dir tiny_transducer_ctc/exp_small_bpe \
  --ctc-loss-scale 0.2 \
  --enable-spec-aug 0 \
  --lang-dir lang_bpe_500 \
  --encoder-dim 256 \
  --decoder-dim 256 \
  --joiner-dim 256 \
  --conv-layers 10 \
  --channels 256 \
  --skip-add 0
```

### Middle encoder

The middle encoder uses 18 1D convolution blocks with 300 channels, with skip connections. The encoder, decoder and joiner dims are 256. Algorithmic latency is 440 ms. Multiply-add ops for the encoder are 50.1 Mops. Note that the nbest-rescoring result is better than that of the tdnn_lstm_ctc recipe with whole-lattice-rescoring.

#### CTC decoding with phone lexicon

Total parameters: 2186242

Parameters for CTC decoding: 1978666

|                 | test-clean | test-other | comment            |
|-----------------|------------|------------|--------------------|
| 1best           | 7.48       | 18.94      | --epoch 30 --avg 1 |
| nbest-rescoring | 6.31       | 16.89      | --epoch 30 --avg 1 |

The training command is:

```bash
./tiny_transducer_ctc/train.py \
  --num-epochs 30 \
  --full-libri 1 \
  --max-duration 600 \
  --exp-dir tiny_transducer_ctc/exp_middle_phone \
  --ctc-loss-scale 0.7 \
  --enable-spec-aug 0 \
  --lang-dir lang_phone \
  --encoder-dim 256 \
  --decoder-dim 256 \
  --joiner-dim 256 \
  --conv-layers 18 \
  --channels 300 \
  --skip-add 1
```

#### Transducer decoding with BPE 500 lexicon

Total parameters: 2735794

Parameters for transducer decoding: 2350294

|                      | test-clean | test-other | comment            |
|----------------------|------------|------------|--------------------|
| greedy_search        | 10.26      | 25.13      | --epoch 30 --avg 2 |
| fast_beam_search     | 9.69       | 23.58      | --epoch 30 --avg 2 |
| modified_beam_search | 9.43       | 23.53      | --epoch 30 --avg 2 |

The training command is:

```bash
./tiny_transducer_ctc/train.py \
  --num-epochs 30 \
  --full-libri 1 \
  --max-duration 600 \
  --exp-dir tiny_transducer_ctc/exp_middle_bpe \
  --ctc-loss-scale 0.2 \
  --enable-spec-aug 0 \
  --lang-dir lang_bpe_500 \
  --encoder-dim 256 \
  --decoder-dim 256 \
  --joiner-dim 256 \
  --conv-layers 18 \
  --channels 300 \
  --skip-add 1
```

### Large encoder

The large encoder uses 18 1D convolution blocks with 400 channels, with skip connections. The encoder, decoder and joiner dims are 400. Algorithmic latency is 440 ms. Multiply-add ops for the encoder are 88.8 Mops. It would be interesting to see how large the gap is if more complicated models like Zipformer or Emformer were simply scaled down to this size.

#### Transducer decoding with BPE 500 lexicon

Total parameters: 4821330

Parameters for transducer decoding: 4219830

|                      | test-clean | test-other | comment            |
|----------------------|------------|------------|--------------------|
| greedy_search        | 8.29       | 21.11      | --epoch 30 --avg 1 |
| fast_beam_search     | 7.91       | 20.1       | --epoch 30 --avg 1 |
| modified_beam_search | 7.74       | 19.89      | --epoch 30 --avg 1 |

The training command is:

```bash
./tiny_transducer_ctc/train.py \
  --num-epochs 30 \
  --full-libri 1 \
  --max-duration 600 \
  --exp-dir tiny_transducer_ctc/exp_large_bpe \
  --ctc-loss-scale 0.2 \
  --enable-spec-aug 1 \
  --lang-dir lang_bpe_500 \
  --encoder-dim 400 \
  --decoder-dim 400 \
  --joiner-dim 400 \
  --conv-layers 18 \
  --channels 400 \
  --skip-add 1
```

egs/librispeech/ASR/tiny_transducer_ctc/asr_datamodule.py (new file, 454 lines)
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import inspect
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
|
||||
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
DynamicBucketingSampler,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||
AudioSamples,
|
||||
OnTheFlyFeatures,
|
||||
)
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class _SeedWorkers:
|
||||
def __init__(self, seed: int):
|
||||
self.seed = seed
|
||||
|
||||
def __call__(self, worker_id: int):
|
||||
fix_random_seed(self.seed + worker_id)
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule:
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
def __init__(self, args: argparse.Namespace):
|
||||
self.args = args
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--manifest-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the DynamicBucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--drop-last",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to drop last batch. Used by sampler.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-spec-aug",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use SpecAugment for training dataset.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--spec-aug-time-warp-factor",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Used only when --enable-spec-aug is True. "
|
||||
"It specifies the factor for time warping in SpecAugment. "
|
||||
"Larger values mean more warping. "
|
||||
"A value less than 1 means to disable time warp.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--enable-musan",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, select noise from MUSAN and mix it"
|
||||
"with training dataset. ",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--input-strategy",
|
||||
type=str,
|
||||
default="PrecomputedFeatures",
|
||||
help="AudioSamples or PrecomputedFeatures",
|
||||
)
|
||||
|
||||
def train_dataloaders(
|
||||
self,
|
||||
cuts_train: CutSet,
|
||||
sampler_state_dict: Optional[Dict[str, Any]] = None,
|
||||
) -> DataLoader:
|
||||
"""
|
||||
Args:
|
||||
cuts_train:
|
||||
CutSet for training.
|
||||
sampler_state_dict:
|
||||
The state dict for the training sampler.
|
||||
"""
|
||||
transforms = []
|
||||
if self.args.enable_musan:
|
||||
logging.info("Enable MUSAN")
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")
|
||||
transforms.append(
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable MUSAN")
|
||||
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = []
|
||||
if self.args.enable_spec_aug:
|
||||
logging.info("Enable SpecAugment")
|
||||
logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
|
||||
# Set the value of num_frame_masks according to Lhotse's version.
|
||||
# In different Lhotse's versions, the default of num_frame_masks is
|
||||
# different.
|
||||
num_frame_masks = 10
|
||||
num_frame_masks_parameter = inspect.signature(
|
||||
SpecAugment.__init__
|
||||
).parameters["num_frame_masks"]
|
||||
if num_frame_masks_parameter.default == 1:
|
||||
num_frame_masks = 2
|
||||
logging.info(f"Num frame mask: {num_frame_masks}")
|
||||
input_transforms.append(
|
||||
SpecAugment(
|
||||
time_warp_factor=self.args.spec_aug_time_warp_factor,
|
||||
num_feature_masks=2,
|
||||
features_mask_size=5,
|
||||
num_frame_masks=num_frame_masks,  # use the version-dependent value computed above
|
||||
frames_mask_size=5,
|
||||
p=0.5,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logging.info("Disable SpecAugment")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
train = K2SpeechRecognitionDataset(
|
||||
input_strategy=eval(self.args.input_strategy)(),
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using DynamicBucketingSampler.")
|
||||
train_sampler = DynamicBucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
drop_last=self.args.drop_last,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
if sampler_state_dict is not None:
|
||||
logging.info("Loading sampler state dict")
|
||||
train_sampler.load_state_dict(sampler_state_dict)
|
||||
|
||||
# 'seed' is derived from the current random state, which will have
|
||||
# previously been set in the main process.
|
||||
seed = torch.randint(0, 100000, ()).item()
|
||||
worker_init_fn = _SeedWorkers(seed)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
worker_init_fn=worker_init_fn,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = DynamicBucketingSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||
if self.args.on_the_fly_feats
|
||||
else eval(self.args.input_strategy)(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = DynamicBucketingSampler(
|
||||
cuts,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test,
|
||||
batch_size=None,
|
||||
sampler=sampler,
|
||||
num_workers=self.args.num_workers,
|
||||
)
|
||||
return test_dl
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_100_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-100 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_clean_360_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-clean-360 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_other_500_cuts(self) -> CutSet:
|
||||
logging.info("About to get train-other-500 cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def train_all_shuf_cuts(self) -> CutSet:
|
||||
logging.info(
|
||||
"About to get the shuffled train-clean-100, \
|
||||
train-clean-360 and train-other-500 cuts"
|
||||
)
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_train-all-shuf.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def dev_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_clean_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-clean cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
|
||||
)
|
||||
|
||||
@lru_cache()
|
||||
def test_other_cuts(self) -> CutSet:
|
||||
logging.info("About to get test-other cuts")
|
||||
return load_manifest_lazy(
|
||||
self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
|
||||
)
|

egs/librispeech/ASR/tiny_transducer_ctc/beam_search.py (symbolic link, 1 line)

../pruned_transducer_stateless2/beam_search.py

egs/librispeech/ASR/tiny_transducer_ctc/ctc_decode.py (new file, 770 lines)
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||
# Liyong Guo,
|
||||
# Quandong Wang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import pprint
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="tiny_transducer_ctc/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_phone",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="1best",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- (1) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||
It needs neither a lexicon nor an n-gram LM.
|
||||
- (2) 1best. Extract the best path from the decoding lattice as the
|
||||
decoding result.
|
||||
- (3) nbest. Extract n paths from the decoding lattice; the path
|
||||
with the highest score is the decoding result.
|
||||
- (4) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
|
||||
the highest score is the decoding result.
|
||||
- (5) whole-lattice-rescoring. Rescore the decoding lattice with an
|
||||
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
|
||||
is the decoding result.
|
||||
|
||||
- (6) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
rescoring method can achieve. Useful for debugging n-best
|
||||
rescoring method.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""Number of paths for n-best based decoding method.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, and nbest-oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="""The scale to be applied to `lattice.scores`.
|
||||
It's needed if you use any kinds of n-best based rescoring.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, and nbest-oracle
|
||||
A smaller value results in more unique paths.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hlg-scale",
|
||||
type=float,
|
||||
default=0.7,
|
||||
help="""The scale to be applied to `hlg.scores`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lm-dir",
|
||||
type=str,
|
||||
default="data/lm",
|
||||
help="""The n-gram LM dir.
|
||||
It should contain either G_4_gram.pt or G_4_gram.fst.txt
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_decoding_params() -> AttributeDict:
|
||||
"""Parameters for decoding."""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"search_beam": 20,
|
||||
"output_beam": 8,
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
"use_double_scores": True,
|
||||
"context_size": 2,
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
batch: dict,
|
||||
word_table: k2.SymbolTable,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if no rescoring is used, the key is the string `no_rescore`.
|
||||
If LM rescoring is used, the key is the string `lm_scale_xxx`,
|
||||
where `xxx` is the value of `lm_scale`. An example key is
|
||||
`lm_scale_0.7`
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
|
||||
- params.decoding_method is "1best", it uses 1best decoding without LM rescoring.
|
||||
- params.decoding_method is "nbest", it uses nbest decoding without LM rescoring.
|
||||
- params.decoding_method is "nbest-rescoring", it uses nbest LM rescoring.
|
||||
- params.decoding_method is "whole-lattice-rescoring", it uses whole lattice LM
|
||||
rescoring.
|
||||
|
||||
model:
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.decoding_method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.decoding_method is ctc-decoding.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
G:
|
||||
An LM. It is not None when params.decoding_method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict. Note: If it decodes to nothing, then return None.
|
||||
"""
|
||||
if HLG is not None:
|
||||
device = HLG.device
|
||||
else:
|
||||
device = H.device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, _ = model.encoder(feature, feature_lens)
|
||||
nnet_output = model.ctc_output(encoder_out)
|
||||
# nnet_output is (N, T, C)
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
torch.div(
|
||||
supervisions["start_frame"],
|
||||
params.subsampling_factor,
|
||||
rounding_mode="trunc",
|
||||
),
|
||||
torch.div(
|
||||
supervisions["num_frames"],
|
||||
params.subsampling_factor,
|
||||
rounding_mode="trunc",
|
||||
),
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
|
||||
|
||||
if H is None:
|
||||
assert HLG is not None
|
||||
decoding_graph = HLG
|
||||
else:
|
||||
assert HLG is None
|
||||
assert bpe_model is not None
|
||||
decoding_graph = H
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=decoding_graph,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
if params.decoding_method == "ctc-decoding":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
# Note: `best_path.aux_labels` contains token IDs, not word IDs
|
||||
# since we are using H, not HLG here.
|
||||
#
|
||||
# token_ids is a list-of-lists of IDs
|
||||
token_ids = get_texts(best_path)
|
||||
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
key = "ctc-decoding"
|
||||
return {key: hyps}
|
||||
|
||||
if params.decoding_method == "nbest-oracle":
|
||||
# Note: You can also pass rescored lattices to it.
|
||||
# We choose the HLG decoded lattice for speed reasons
|
||||
# as HLG decoding is faster and the oracle WER
|
||||
# is only slightly worse than that of rescored lattices.
|
||||
best_path = nbest_oracle(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=supervisions["text"],
|
||||
word_table=word_table,
|
||||
nbest_scale=params.nbest_scale,
|
||||
oov="<UNK>",
|
||||
)
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
|
||||
return {key: hyps}
|
||||
|
||||
if params.decoding_method in ["1best", "nbest"]:
|
||||
if params.decoding_method == "1best":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
key = "no_rescore"
|
||||
else:
|
||||
best_path = nbest_decoding(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
use_double_scores=params.use_double_scores,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
return {key: hyps}
|
||||
|
||||
assert params.decoding_method in [
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
]
|
||||
|
||||
# lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||
# lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
|
||||
# lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
|
||||
lm_scale_list = [0.6, 0.7, 0.8, 0.9]
|
||||
|
||||
if params.decoding_method == "nbest-rescoring":
|
||||
best_path_dict = rescore_with_n_best_list(
|
||||
lattice=lattice,
|
||||
G=G,
|
||||
num_paths=params.num_paths,
|
||||
lm_scale_list=lm_scale_list,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
elif params.decoding_method == "whole-lattice-rescoring":
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=lm_scale_list,
|
||||
)
|
||||
else:
|
||||
assert False, f"Unsupported decoding method: {params.decoding_method}"
|
||||
|
||||
ans = dict()
|
||||
if best_path_dict is not None:
|
||||
for lm_scale_str, best_path in best_path_dict.items():
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
ans[lm_scale_str] = hyps
|
||||
else:
|
||||
ans = None
|
||||
return ans
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
word_table: k2.SymbolTable,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.decoding_method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.decoding_method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.decoding_method is ctc-decoding.
|
||||
word_table:
|
||||
It is the word symbol table.
|
||||
G:
|
||||
An LM. It is not None when params.decoding_method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return a dict, whose key may be "no-rescore" if no LM rescoring
|
||||
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
|
||||
Its value is a list of tuples. Each tuple contains three elements:
the cut ID, the reference transcript, and the predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
batch=batch,
|
||||
word_table=word_table,
|
||||
G=G,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(f, f"{test_set_name}-{key}", results)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
args.lm_dir = Path(args.lm_dir)
|
||||
|
||||
params = get_params()
|
||||
# add decoding params
|
||||
params.update(get_decoding_params())
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"ctc-decoding",
|
||||
"1best",
|
||||
"nbest",
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"nbest-oracle",
|
||||
)
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
params.suffix += f"-hlg-scale-{params.hlg_scale}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-uam"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
logging.info(pprint.pformat(params, indent=2))
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
|
||||
params.vocab_size = num_classes
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = 0
|
||||
|
||||
if params.decoding_method == "ctc-decoding":
|
||||
assert "lang_bpe" in str(
|
||||
params.lang_dir
|
||||
), "ctc-decoding only supports BPE lexicons."
|
||||
HLG = None
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
||||
else:
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location=device)
|
||||
)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
HLG.scores *= params.hlg_scale
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.decoding_method in (
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
):
|
||||
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
||||
logging.info("Loading G_4_gram.fst.txt")
|
||||
logging.warning("It may take 8 minutes.")
|
||||
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
# G.aux_labels is not needed in later computations, so
|
||||
# remove it here.
|
||||
del G.aux_labels
|
||||
# CAUTION: The following line is crucial.
|
||||
# Arcs entering the back-off state have label equal to #0.
|
||||
# We have to change it to 0 here.
|
||||
G.labels[G.labels >= first_word_disambig_id] = 0
|
||||
# See https://github.com/k2-fsa/k2/issues/874
|
||||
# for why we need to set G.properties to None
|
||||
G.__dict__["_properties"] = None
|
||||
G = k2.Fsa.from_fsas([G]).to(device)
|
||||
G = k2.arc_sort(G)
|
||||
# Save a dummy value so that it can be loaded in C++.
|
||||
# See https://github.com/pytorch/pytorch/issues/67902
|
||||
# for why we need to do this.
|
||||
G.dummy = 1
|
||||
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location=device)
|
||||
G = k2.Fsa.from_dict(d)
|
||||
|
||||
if params.decoding_method == "whole-lattice-rescoring":
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
G = G.to(device)
|
||||
|
||||
# G.lm_scores is used to replace HLG.lm_scores during
|
||||
# LM rescoring.
|
||||
G.lm_scores = G.scores.clone()
|
||||
else:
|
||||
G = None
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
enc_param = sum([p.numel() for p in model.encoder.parameters()])
|
||||
ctc_param = sum([p.numel() for p in model.ctc_output.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
logging.info(f"Parameters for CTC decoding: {enc_param + ctc_param}")
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
word_table=lexicon.word_table,
|
||||
G=G,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|

egs/librispeech/ASR/tiny_transducer_ctc/decode.py (new file, 717 lines)
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
import pprint
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG,
|
||||
fast_beam_search_nbest_oracle,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
LOG_EPS = math.log(1e-10)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="tiny_transducer_ctc/exp",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=Path,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir containing word table and LG graph",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="fast_beam_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
- fast_beam_search_LG
|
||||
- fast_beam_search_nbest
|
||||
- fast_beam_search_nbest_oracle
|
||||
- fast_beam_search_nbest_LG
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --decoding-method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=20.0,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --decoding-method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest, fast_beam_search_nbest_LG,
|
||||
and fast_beam_search_nbest_oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=0.1,
|
||||
help="""
|
||||
Used only when --decoding_method is fast_beam_search_LG or
|
||||
fast_beam_search_nbest_LG.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=64,
|
||||
help="""Used only when --decoding-method is
|
||||
fast_beam_search, fast_beam_search_LG, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame.
|
||||
Used only when --decoding_method is greedy_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""Number of paths for nbest decoding.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""Scale applied to lattice scores when computing nbest paths.
|
||||
Used only when the decoding method is fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
batch: dict,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if greedy_search is used, it would be "greedy_search"
|
||||
If beam search with a beam size of 7 is used, it would be
|
||||
"beam_7"
|
||||
- value: It contains the decoding result. `len(value)` equals to
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_LG,
|
||||
fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
device = next(model.parameters()).device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
|
||||
|
||||
hyps = []
|
||||
|
||||
if (
|
||||
params.decoding_method == "fast_beam_search"
|
||||
or params.decoding_method == "fast_beam_search_LG"
|
||||
):
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
|
||||
if params.decoding_method == "fast_beam_search":
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest_LG":
|
||||
hyp_tokens = fast_beam_search_nbest_LG(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in hyp_tokens:
|
||||
hyps.append([word_table[i] for i in hyp])
|
||||
elif params.decoding_method == "fast_beam_search_nbest":
|
||||
hyp_tokens = fast_beam_search_nbest(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "fast_beam_search_nbest_oracle":
|
||||
hyp_tokens = fast_beam_search_nbest_oracle(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=sp.encode(supervisions["text"]),
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.decoding_method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
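# Fall back to per-utterance decoding for the methods that have no batched
# implementation here: greedy_search with max_sym_per_frame > 1 and plain
# beam_search.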
batch_size = encoder_out.size(0)
|
||||
|
||||
for i in range(batch_size):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.decoding_method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.decoding_method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported decoding method: {params.decoding_method}"
|
||||
)
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
return {"greedy_search": hyps}
|
||||
elif "fast_beam_search" in params.decoding_method:
|
||||
key = f"beam_{params.beam}_"
|
||||
key += f"max_contexts_{params.max_contexts}_"
|
||||
key += f"max_states_{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
key += f"_num_paths_{params.num_paths}_"
|
||||
key += f"nbest_scale_{params.nbest_scale}"
|
||||
if "LG" in params.decoding_method:
|
||||
key += f"_ngram_lm_scale_{params.ngram_lm_scale}"
|
||||
|
||||
return {key: hyps}
|
||||
else:
|
||||
return {f"beam_size_{params.beam_size}": hyps}
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
sp: spm.SentencePieceProcessor,
|
||||
word_table: Optional[k2.SymbolTable] = None,
|
||||
decoding_graph: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
sp:
|
||||
The BPE model.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
decoding_graph:
|
||||
The decoding graph. Can be either a `k2.trivial_graph` or LG. Used
|
||||
only when --decoding_method is fast_beam_search, fast_beam_search_nbest,
|
||||
fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG.
|
||||
Returns:
|
||||
Return a dict, whose key may be "greedy_search" if greedy search
|
||||
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||
Its value is a list of tuples. Each tuple contains three elements:
the cut id, the reference transcript, and the predicted result.
|
||||
"""
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
if params.decoding_method == "greedy_search":
|
||||
log_interval = 50
|
||||
else:
|
||||
log_interval = 20
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
decoding_graph=decoding_graph,
|
||||
word_table=word_table,
|
||||
batch=batch,
|
||||
)
|
||||
|
||||
for name, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((cut_id, ref_words, hyp_words))
|
||||
|
||||
results[name].extend(this_batch)
|
||||
|
||||
num_cuts += len(texts)
|
||||
|
||||
if batch_idx % log_interval == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||
):
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = (
|
||||
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
results = sorted(results)
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = (
|
||||
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
)
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=True
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
assert params.decoding_method in (
|
||||
"greedy_search",
|
||||
"beam_search",
|
||||
"fast_beam_search",
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest",
|
||||
"fast_beam_search_nbest_LG",
|
||||
"fast_beam_search_nbest_oracle",
|
||||
"modified_beam_search",
|
||||
)
|
||||
if "lang_phone" in str(params.lang_dir):
|
||||
assert params.decoding_method in (
|
||||
"fast_beam_search_LG",
|
||||
"fast_beam_search_nbest_LG",
|
||||
), "For phone lexicon, use a decoding method with LG."
|
||||
|
||||
params.res_dir = params.exp_dir / params.decoding_method
|
||||
|
||||
if params.iter > 0:
|
||||
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||
else:
|
||||
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
params.suffix += f"-beam-{params.beam}"
|
||||
params.suffix += f"-max-contexts-{params.max_contexts}"
|
||||
params.suffix += f"-max-states-{params.max_states}"
|
||||
if "nbest" in params.decoding_method:
|
||||
params.suffix += f"-nbest-scale-{params.nbest_scale}"
|
||||
params.suffix += f"-num-paths-{params.num_paths}"
|
||||
if "LG" in params.decoding_method:
|
||||
params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"
|
||||
elif "beam_search" in params.decoding_method:
|
||||
params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}"
|
||||
else:
|
||||
params.suffix += f"-context-{params.context_size}"
|
||||
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
|
||||
|
||||
if params.use_averaged_model:
|
||||
params.suffix += "-uam"
|
||||
|
||||
setup_logger(f"{params.res_dir}/log-{params.suffix}")
|
||||
logging.info("Decoding started")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"Device: {device}")
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
if "lang_bpe" in str(params.lang_dir):
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(str(params.lang_dir / "bpe.model"))
|
||||
|
||||
# <blk> and <unk> are defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
else:
|
||||
params.blank_id = lexicon.token_table.get("<eps>")
|
||||
params.unk_id = lexicon.token_table.get("SPN")
|
||||
params.vocab_size = max(lexicon.tokens) + 1
|
||||
sp = None
|
||||
|
||||
logging.info(pprint.pformat(params, indent=2))
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
if "fast_beam_search" in params.decoding_method:
|
||||
if "LG" in params.decoding_method:
|
||||
word_table = lexicon.word_table
|
||||
lg_filename = params.lang_dir / "LG.pt"
|
||||
logging.info(f"Loading {lg_filename}")
|
||||
decoding_graph = k2.Fsa.from_dict(
|
||||
torch.load(lg_filename, map_location=device)
|
||||
)
|
||||
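# Scale the LG arc scores up front so the n-gram LM contribution can be
# tuned against the transducer scores via --ngram-lm-scale.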
decoding_graph.scores *= params.ngram_lm_scale
|
||||
else:
|
||||
word_table = None
|
||||
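# k2.trivial_graph accepts any token sequence (self-loops over all
# non-blank tokens), so the search is unconstrained when no LG is used.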
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
|
||||
else:
|
||||
decoding_graph = None
|
||||
word_table = None
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
enc_param = sum([p.numel() for p in model.encoder.parameters()])
|
||||
dec_param = sum([p.numel() for p in model.decoder.parameters()])
|
||||
join_param = sum([p.numel() for p in model.joiner.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
logging.info(
|
||||
f"Parameters for transducer decoding: {enc_param + dec_param + join_param}"
|
||||
)
|
||||
|
||||
# we need cut ids to display recognition results.
|
||||
args.return_cuts = True
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
|
||||
test_clean_cuts = librispeech.test_clean_cuts()
|
||||
test_other_cuts = librispeech.test_other_cuts()
|
||||
|
||||
test_clean_dl = librispeech.test_dataloaders(test_clean_cuts)
|
||||
test_other_dl = librispeech.test_dataloaders(test_other_cuts)
|
||||
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
test_dl = [test_clean_dl, test_other_dl]
|
||||
|
||||
for test_set, test_dl in zip(test_sets, test_dl):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
sp=sp,
|
||||
word_table=word_table,
|
||||
decoding_graph=decoding_graph,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params,
|
||||
test_set_name=test_set,
|
||||
results_dict=results_dict,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
1
egs/librispeech/ASR/tiny_transducer_ctc/decoder.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/decoder.py
|
379
egs/librispeech/ASR/tiny_transducer_ctc/encoder.py
Normal file
@ -0,0 +1,379 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2022 Spacetouch Inc. (author: Tiance Wang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ActivationBalancer, DoubleSwish
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Conv1dNet(EncoderInterface):
|
||||
"""
|
||||
1D Convolution network with causal squeeze and excitation
|
||||
module and optional skip connections.
|
||||
|
||||
Latency: 80ms + (conv_layers+1) // 2 * 40ms, assuming 10ms stride.
|
||||
|
||||
Args:
|
||||
output_dim (int): Number of output channels of the last layer.
|
||||
input_dim (int): Number of input features
|
||||
conv_layers (int): Number of convolution layers,
|
||||
excluding the subsampling layers.
|
||||
channels (int): Number of output channels for each layer,
|
||||
except the last layer.
|
||||
subsampling_factor (int): The subsampling factor for the model.
|
||||
skip_add (bool): Whether to use skip connection for each convolution layer.
|
||||
dscnn (bool): Whether to use depthwise separable convolution.
|
||||
activation (str): Activation function type.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
output_dim: int,
|
||||
input_dim: int = 80,
|
||||
conv_layers: int = 10,
|
||||
channels: int = 256,
|
||||
subsampling_factor: int = 4,
|
||||
skip_add: bool = False,
|
||||
dscnn: bool = True,
|
||||
activation: str = "relu",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert subsampling_factor == 4, "Only support subsampling = 4"
|
||||
|
||||
self.conv_layers = conv_layers
|
||||
self.skip_add = skip_add
|
||||
# 80ms latency for subsample_layer
|
||||
self.subsample_layer = nn.Sequential(
|
||||
conv1d_bn_block(
|
||||
input_dim, channels, 9, stride=2, activation=activation, dscnn=dscnn
|
||||
),
|
||||
conv1d_bn_block(
|
||||
channels, channels, 5, stride=2, activation=activation, dscnn=dscnn
|
||||
),
|
||||
)
|
||||
|
||||
self.conv_blocks = nn.ModuleList()
|
||||
cin = [channels] * conv_layers
|
||||
cout = [channels] * (conv_layers - 1) + [output_dim]
|
||||
|
||||
# Use causal and standard convolution alternatively
|
||||
for ly in range(conv_layers):
|
||||
self.conv_blocks.append(
|
||||
nn.Sequential(
|
||||
conv1d_bn_block(
|
||||
cin[ly],
|
||||
cout[ly],
|
||||
3,
|
||||
activation=activation,
|
||||
dscnn=dscnn,
|
||||
causal=ly % 2,
|
||||
),
|
||||
CausalSqueezeExcite1d(cout[ly], 16, 30),
|
||||
)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The input tensor. Its shape is (batch_size, seq_len, feature_dim).
|
||||
x_lens:
|
||||
A tensor of shape (batch_size,) containing the number of frames in
|
||||
`x` before padding.
|
||||
Returns:
|
||||
Return a tuple containing 2 tensors:
|
||||
- embeddings: its shape is (batch_size, output_seq_len, encoder_dims)
|
||||
- lengths, a tensor of shape (batch_size,) containing the number
|
||||
of frames in `embeddings` before padding.
|
||||
"""
|
||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||
x = self.subsample_layer(x)
|
||||
for idx, layer in enumerate(self.conv_blocks):
|
||||
if self.skip_add and 0 < idx < self.conv_layers - 1:
|
||||
x = layer(x) + x
|
||||
else:
|
||||
x = layer(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||
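# x_lens >> 2 equals x_lens // 4, the overall subsampling factor of the
# two stride-2 blocks in self.subsample_layer.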
lengths = x_lens >> 2
|
||||
return x, lengths
|
||||
|
||||
|
||||
def get_activation(
|
||||
name: str,
|
||||
channels: int,
|
||||
channel_dim: int = -1,
|
||||
min_val: int = 0,
|
||||
max_val: int = 1,
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Get activation function from name in string.
|
||||
|
||||
Args:
|
||||
name: activation function name
|
||||
channels: only used for PReLU, should be equal to x.shape[1].
|
||||
channel_dim: the axis/dimension corresponding to the channel,
|
||||
interpreted as an offset from the input's ndim if negative.
|
||||
e.g. for NCHW tensor, channel_dim = 1
|
||||
min_val: minimum value of hardtanh
|
||||
max_val: maximum value of hardtanh
|
||||
|
||||
Returns:
|
||||
The activation function module
|
||||
|
||||
"""
|
||||
act_layer = nn.Identity()
|
||||
name = name.lower()
|
||||
if name == "prelu":
|
||||
act_layer = nn.PReLU(channels)
|
||||
elif name == "relu":
|
||||
act_layer = nn.ReLU()
|
||||
elif name == "relu6":
|
||||
act_layer = nn.ReLU6()
|
||||
elif name == "hardtanh":
|
||||
act_layer = nn.Hardtanh(min_val, max_val)
|
||||
elif name in ["swish", "silu"]:
|
||||
act_layer = nn.SiLU()
|
||||
elif name == "elu":
|
||||
act_layer = nn.ELU()
|
||||
elif name == "doubleswish":
|
||||
act_layer = nn.Sequential(
|
||||
ActivationBalancer(num_channels=channels, channel_dim=channel_dim),
|
||||
DoubleSwish(),
|
||||
)
|
||||
elif name == "":
|
||||
act_layer = nn.Identity()
|
||||
else:
|
||||
raise Exception(f"Unknown activation function: {name}")
|
||||
|
||||
return act_layer
|
||||
|
||||
|
||||
class CausalSqueezeExcite1d(nn.Module):
|
||||
"""
|
||||
Causal squeeze and excitation module with input and output shape
|
||||
(batch, channels, time). The global average pooling in the original
|
||||
SE module is replaced by a causal filter, so
|
||||
the layer does not introduce any algorithmic latency.
|
||||
|
||||
Args:
|
||||
channels (int): Number of channels
|
||||
reduction (int): channel reduction rate
|
||||
context_window (int): Context window size for the moving average operation.
|
||||
For EMA, the smoothing factor is 1 / context_window.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
reduction: int = 16,
|
||||
context_window: int = 10,
|
||||
) -> None:
|
||||
super(CausalSqueezeExcite1d, self).__init__()
|
||||
|
||||
assert channels >= reduction
|
||||
|
||||
self.context_window = context_window
|
||||
c_squeeze = channels // reduction
|
||||
self.linear1 = nn.Linear(channels, c_squeeze, bias=True)
|
||||
self.act1 = nn.ReLU(inplace=True)
|
||||
self.linear2 = nn.Linear(c_squeeze, channels, bias=True)
|
||||
self.act2 = nn.Sigmoid()
|
||||
|
||||
# EMA worked better than MA empirically
|
||||
# self.avg_filter = self.moving_avg
|
||||
self.avg_filter = self.exponential_moving_avg
|
||||
self.ema_matrix = torch.tensor([0])
|
||||
self.ema_matrix_size = 0
|
||||
|
||||
def _precompute_ema_matrix(self, N: int, device: torch.device):
|
||||
a = 1.0 / self.context_window # smoothing factor
|
||||
w = [[(1 - a) ** k * a for k in range(n, n - N, -1)] for n in range(N)]
|
||||
w = torch.tensor(w).to(device).tril()
|
||||
w[:, 0] *= self.context_window
|
||||
self.ema_matrix = w.T
|
||||
self.ema_matrix_size = N
|
||||
|
||||
def exponential_moving_avg(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Exponential moving average filter, which is calculated as:
|
||||
y[t] = (1-a) * y[t-1] + a * x[t]
|
||||
where a = 1 / self.context_window is the smoothing factor.
|
||||
|
||||
For training, the iterative version is too slow. A better way is
|
||||
to expand y[t] as a function of x[0..t] only and use matrix-vector multiplication.
|
||||
The weight matrix can be precomputed if the smoothing factor is fixed.
|
||||
"""
|
||||
if self.training:
|
||||
# use matrix version to speed up training
|
||||
N = x.shape[-1]
|
||||
if N > self.ema_matrix_size:
|
||||
self._precompute_ema_matrix(N, x.device)
|
||||
y = torch.matmul(x, self.ema_matrix[:N, :N])
|
||||
else:
|
||||
# use iterative version to save memory
|
||||
a = 1.0 / self.context_window
|
||||
y = torch.empty_like(x)
|
||||
y[:, :, 0] = x[:, :, 0]
|
||||
for t in range(1, y.shape[-1]):
|
||||
y[:, :, t] = (1 - a) * y[:, :, t - 1] + a * x[:, :, t]
|
||||
return y
|
||||
|
||||
def moving_avg(self, x: Tensor) -> Tensor:
|
||||
"""
|
||||
Simple moving average with context_window as window size.
|
||||
"""
|
||||
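# The first k-1 output frames average only the samples seen so far
# (frame t averages x[0..t]); from frame k-1 onward each frame is the
# mean of the k most recent samples, so the filter stays causal.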
y = torch.empty_like(x)
|
||||
k = min(x.shape[2], self.context_window)
|
||||
w = [[1 / n] * n + [0] * (k - n - 1) for n in range(1, k)]
|
||||
w = torch.tensor(w, device=x.device)
|
||||
y[:, :, : k - 1] = torch.matmul(x[:, :, : k - 1], w.T)
|
||||
y[:, :, k - 1 :] = F.avg_pool1d(x, k, 1)
|
||||
return y
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
||||
assert len(x.shape) == 3, "Input is not a 3D tensor!"
|
||||
y = self.exponential_moving_avg(x)
|
||||
y = y.permute(0, 2, 1) # make channel last for squeeze op
|
||||
y = self.act1(self.linear1(y))
|
||||
y = self.act2(self.linear2(y))
|
||||
y = y.permute(0, 2, 1) # back to original shape
|
||||
y = x * y
|
||||
return y
|
||||
|
||||
|
||||
def conv1d_bn_block(
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
activation: str = "relu",
|
||||
dscnn: bool = False,
|
||||
causal: bool = False,
|
||||
) -> nn.Sequential:
|
||||
"""
|
||||
Conv1d - batchnorm - activation block.
|
||||
If kernel size is even, output length = input length + 1.
|
||||
Otherwise, output and input lengths are equal.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels
|
||||
out_channels (int): Number of output channels
|
||||
kernel_size (int): kernel size
|
||||
stride (int): convolution stride
|
||||
dilation (int): convolution dilation rate
|
||||
dscnn (bool): Use depthwise separable convolution.
|
||||
causal (bool): Use causal convolution
|
||||
activation (str): Activation function type.
|
||||
|
||||
"""
|
||||
if dscnn:
|
||||
return nn.Sequential(
|
||||
CausalConv1d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
groups=in_channels,
|
||||
bias=False,
|
||||
)
|
||||
if causal
|
||||
else nn.Conv1d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size // 2) * dilation,
|
||||
dilation=dilation,
|
||||
groups=in_channels,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm1d(in_channels),
|
||||
get_activation(activation, in_channels),
|
||||
nn.Conv1d(in_channels, out_channels, 1, bias=False),
|
||||
nn.BatchNorm1d(out_channels),
|
||||
get_activation(activation, out_channels),
|
||||
)
|
||||
else:
|
||||
return nn.Sequential(
|
||||
CausalConv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
dilation=dilation,
|
||||
bias=False,
|
||||
)
|
||||
if causal
|
||||
else nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size // 2) * dilation,
|
||||
dilation=dilation,
|
||||
bias=False,
|
||||
),
|
||||
nn.BatchNorm1d(out_channels),
|
||||
get_activation(activation, out_channels),
|
||||
)
|
||||
|
||||
|
||||
class CausalConv1d(nn.Module):
|
||||
"""
|
||||
Causal convolution with padding automatically chosen to match input/output length.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
super(CausalConv1d, self).__init__()
|
||||
assert kernel_size > 2
|
||||
|
||||
self.padding = dilation * (kernel_size - 1)
|
||||
self.stride = stride
|
||||
|
||||
self.conv = nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
self.padding,
|
||||
dilation,
|
||||
groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
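# nn.Conv1d pads `self.padding` zeros on both sides. The left padding makes
# every output frame depend only on current and past samples; the right
# padding only adds extra trailing frames, which are trimmed here so the
# output length matches a purely left-padded (causal) convolution.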
return self.conv(x)[:, :, : -self.padding // self.stride]
|
1
egs/librispeech/ASR/tiny_transducer_ctc/encoder_interface.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless2/encoder_interface.py
|
316
egs/librispeech/ASR/tiny_transducer_ctc/export.py
Executable file
@ -0,0 +1,316 @@
|
||||
#!/usr/bin/env python3
|
||||
#
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# This script converts several saved checkpoints
|
||||
# to a single one using model averaging.
|
||||
"""
|
||||
|
||||
Usage:
|
||||
|
||||
(1) Export to torchscript model using torch.jit.script()
|
||||
|
||||
./tiny_transducer_ctc/export.py \
|
||||
--exp-dir ./tiny_transducer_ctc/exp \
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
--epoch 30 \
|
||||
--avg 2 \
|
||||
--jit 1
|
||||
|
||||
It will generate a file `cpu_jit.pt` in the given `exp_dir`. You can later
|
||||
load it by `torch.jit.load("cpu_jit.pt")`.
|
||||
|
||||
Note `cpu` in the name `cpu_jit.pt` means the parameters when loaded into Python
|
||||
are on CPU. You can use `to("cuda")` to move them to a CUDA device.
|
||||
|
||||
Check
|
||||
https://github.com/k2-fsa/sherpa
|
||||
for how to use the exported models outside of icefall.
|
||||
|
||||
(2) Export `model.state_dict()`
|
||||
|
||||
./tiny_transducer_ctc/export.py \
|
||||
--exp-dir ./tiny_transducer_ctc/exp \
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
--epoch 30 \
|
||||
--avg 2
|
||||
|
||||
It will generate a file `pretrained.pt` in the given `exp_dir`. You can later
|
||||
load it by `icefall.checkpoint.load_checkpoint()`.
|
||||
|
||||
To use the generated file with `tiny_transducer_ctc/decode.py`,
|
||||
you can do:
|
||||
|
||||
cd /path/to/exp_dir
|
||||
ln -s pretrained.pt epoch-9999.pt
|
||||
|
||||
cd /path/to/egs/librispeech/ASR
|
||||
./tiny_transducer_ctc/decode.py \
|
||||
--exp-dir ./tiny_transducer_ctc/exp \
|
||||
--epoch 9999 \
|
||||
--use-averaged-model 0
|
||||
--avg 1 \
|
||||
--max-duration 600 \
|
||||
--decoding-method greedy_search \
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
|
||||
Check ./pretrained.py for its usage.
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
from icefall.checkpoint import (
|
||||
average_checkpoints,
|
||||
average_checkpoints_with_averaged_model,
|
||||
find_checkpoints,
|
||||
load_checkpoint,
|
||||
)
|
||||
from icefall.lexicon import UniqLexicon
|
||||
from icefall.utils import str2bool
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=30,
|
||||
help="""It specifies the checkpoint to use for decoding.
|
||||
Note: Epoch counts from 1.
|
||||
You can specify --avg to use more checkpoints for model averaging.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--iter",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""If positive, --epoch is ignored and it
|
||||
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||
You can specify --avg to use more checkpoints for model averaging.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch' and '--iter'",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--use-averaged-model",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Whether to load averaged model. Currently it only supports "
|
||||
"using --epoch. If True, it would decode with the averaged model "
|
||||
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||
"Actually only the models with epoch number of `epoch-avg` and "
|
||||
"`epoch` are loaded for averaging. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="tiny_transducer_ctc/exp_4m_bpe500_halfdelay_specaug",
|
||||
help="""It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe_500",
|
||||
help="""The lang dir
|
||||
It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--jit",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True to save a model after applying torch.jit.script.
|
||||
It will generate a file named cpu_jit.pt
|
||||
|
||||
Check ./jit_pretrained.py for how to use it.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_parser().parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
if "lang_bpe" in str(params.lang_dir):
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.lang_dir + "/bpe.model")
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
else:
|
||||
assert "lang_phone" in str(params.lang_dir)
|
||||
phone_lexicon = UniqLexicon(params.lang_dir)
|
||||
params.blank_id = 0
|
||||
params.vocab_size = max(phone_lexicon.tokens) + 1
|
||||
|
||||
logging.info(params)
|
||||
|
||||
logging.info("About to create model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
model.to(device)
|
||||
|
||||
if not params.use_averaged_model:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
elif params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 1:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.to(device)
|
||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||
else:
|
||||
if params.iter > 0:
|
||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||
: params.avg + 1
|
||||
]
|
||||
if len(filenames) == 0:
|
||||
raise ValueError(
|
||||
f"No checkpoints found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
elif len(filenames) < params.avg + 1:
|
||||
raise ValueError(
|
||||
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||
f" --iter {params.iter}, --avg {params.avg}"
|
||||
)
|
||||
filename_start = filenames[-1]
|
||||
filename_end = filenames[0]
|
||||
logging.info(
|
||||
"Calculating the averaged model over iteration checkpoints"
|
||||
f" from {filename_start} (excluded) to {filename_end}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
else:
|
||||
assert params.avg > 0, params.avg
|
||||
start = params.epoch - params.avg
|
||||
assert start >= 1, start
|
||||
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||
logging.info(
|
||||
f"Calculating the averaged model over epoch range from "
|
||||
f"{start} (excluded) to {params.epoch}"
|
||||
)
|
||||
model.to(device)
|
||||
model.load_state_dict(
|
||||
average_checkpoints_with_averaged_model(
|
||||
filename_start=filename_start,
|
||||
filename_end=filename_end,
|
||||
device=device,
|
||||
)
|
||||
)
|
||||
|
||||
model.to("cpu")
|
||||
model.eval()
|
||||
|
||||
if params.jit is True:
|
||||
logging.info("Using torch.jit.script()")
|
||||
# We won't use the forward() method of the model in C++, so just ignore
|
||||
# it here.
|
||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
||||
# torch scriptable.
|
||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||
logging.info("Using torch.jit.script")
|
||||
model = torch.jit.script(model)
|
||||
filename = params.exp_dir / "cpu_jit.pt"
|
||||
model.save(str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
else:
|
||||
logging.info("Not using torchscript. Export model.state_dict()")
|
||||
# Save it using a format so that it can be loaded
|
||||
# by :func:`load_checkpoint`
|
||||
filename = params.exp_dir / "pretrained.pt"
|
||||
torch.save({"model": model.state_dict()}, str(filename))
|
||||
logging.info(f"Saved to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
271
egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained.py
Executable file
@ -0,0 +1,271 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models, exported by `torch.jit.script()`
|
||||
and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./tiny_transducer_ctc/export.py \
|
||||
--exp-dir ./tiny_transducer_ctc/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--jit 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
./tiny_transducer_ctc/jit_pretrained.py \
|
||||
--nn-model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nn-model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the torchscript model cpu_jit.pt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float = 16000
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
def greedy_search(
|
||||
model: torch.jit.ScriptModule,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
) -> List[List[int]]:
|
||||
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
|
||||
Args:
|
||||
model:
|
||||
The transducer model.
|
||||
encoder_out:
|
||||
A 3-D tensor of shape (N, T, C)
|
||||
encoder_out_lens:
|
||||
A 1-D tensor of shape (N,).
|
||||
Returns:
|
||||
Return the decoded results for each utterance.
|
||||
"""
|
||||
assert encoder_out.ndim == 3
|
||||
assert encoder_out.size(0) >= 1, encoder_out.size(0)
|
||||
|
||||
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
|
||||
input=encoder_out,
|
||||
lengths=encoder_out_lens.cpu(),
|
||||
batch_first=True,
|
||||
enforce_sorted=False,
|
||||
)
|
||||
|
||||
device = encoder_out.device
|
||||
blank_id = 0 # hard-code to 0
|
||||
|
||||
batch_size_list = packed_encoder_out.batch_sizes.tolist()
|
||||
N = encoder_out.size(0)
|
||||
|
||||
assert torch.all(encoder_out_lens > 0), encoder_out_lens
|
||||
assert N == batch_size_list[0], (N, batch_size_list)
|
||||
|
||||
context_size = model.decoder.context_size
|
||||
hyps = [[blank_id] * context_size for _ in range(N)]
|
||||
|
||||
decoder_input = torch.tensor(
|
||||
hyps,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
) # (N, context_size)
|
||||
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
).squeeze(1)
|
||||
|
||||
offset = 0
|
||||
for batch_size in batch_size_list:
|
||||
start = offset
|
||||
end = offset + batch_size
|
||||
current_encoder_out = packed_encoder_out.data[start:end]
|
||||
current_encoder_out = current_encoder_out
|
||||
# current_encoder_out's shape: (batch_size, encoder_out_dim)
|
||||
offset = end
|
||||
|
||||
decoder_out = decoder_out[:batch_size]
|
||||
|
||||
logits = model.joiner(
|
||||
current_encoder_out,
|
||||
decoder_out,
|
||||
)
|
||||
# logits'shape (batch_size, vocab_size)
|
||||
|
||||
assert logits.ndim == 2, logits.shape
|
||||
y = logits.argmax(dim=1).tolist()
|
||||
emitted = False
|
||||
for i, v in enumerate(y):
|
||||
if v != blank_id:
|
||||
hyps[i].append(v)
|
||||
emitted = True
|
||||
if emitted:
|
||||
# update decoder output
|
||||
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
|
||||
decoder_input = torch.tensor(
|
||||
decoder_input,
|
||||
device=device,
|
||||
dtype=torch.int64,
|
||||
)
|
||||
decoder_out = model.decoder(
|
||||
decoder_input,
|
||||
need_pad=torch.tensor([False]),
|
||||
)
|
||||
decoder_out = decoder_out.squeeze(1)
|
||||
|
||||
sorted_ans = [h[context_size:] for h in hyps]
|
||||
ans = []
|
||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||
for i in range(N):
|
||||
ans.append(sorted_ans[unsorted_indices[i]])
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
logging.info(vars(args))
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = torch.jit.load(args.nn_model_filename)
|
||||
|
||||
model.eval()
|
||||
|
||||
model.to(device)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(args.bpe_model)
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = 16000
|
||||
opts.mel_opts.num_bins = 80
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {args.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=args.sound_files,
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(
|
||||
features,
|
||||
batch_first=True,
|
||||
padding_value=math.log(1e-10),
|
||||
)
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
|
||||
hyps = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
s = "\n"
|
||||
for filename, hyp in zip(args.sound_files, hyps):
|
||||
words = sp.decode(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
426
egs/librispeech/ASR/tiny_transducer_ctc/jit_pretrained_ctc.py
Executable file
@ -0,0 +1,426 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads torchscript models, exported by `torch.jit.script()`
|
||||
and uses them to decode waves.
|
||||
You can use the following command to get the exported models:
|
||||
|
||||
./tiny_transducer_ctc/export.py \
|
||||
--exp-dir ./tiny_transducer_ctc/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10 \
|
||||
--jit 1
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) ctc-decoding
|
||||
./tiny_transducer_ctc/jit_pretrained_ctc.py \
|
||||
--model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--method ctc-decoding \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) 1best
|
||||
./tiny_transducer_ctc/jit_pretrained_ctc.py \
|
||||
--model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--method 1best \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
|
||||
(3) nbest-rescoring
|
||||
./tiny_transducer_ctc/jit_pretrained_ctc.py \
|
||||
--model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--method nbest-rescoring \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
|
||||
(4) whole-lattice-rescoring
|
||||
./tiny_transducer_ctc/jit_pretrained_ctc.py \
|
||||
--model-filename ./tiny_transducer_ctc/exp/cpu_jit.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--method whole-lattice-rescoring \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from ctc_decode import get_decoding_params
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import get_params
|
||||
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.utils import get_texts
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model-filename",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the torchscript model.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--words-file",
|
||||
type=str,
|
||||
help="""Path to words.txt.
|
||||
Used only when method is not ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--HLG",
|
||||
type=str,
|
||||
help="""Path to HLG.pt.
|
||||
Used only when method is not ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.
|
||||
Used only when method is ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="1best",
|
||||
help="""Decoding method.
|
||||
Possible values are:
|
||||
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
||||
piece model, i.e., lang_dir/bpe.model, to convert
|
||||
word pieces to words. It needs neither a lexicon
|
||||
nor an n-gram LM.
|
||||
(1) 1best - Use the best path as decoding output. Only
|
||||
the encoder output is used for decoding.
|
||||
We call it HLG decoding.
|
||||
(2) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an LM, the path with
|
||||
the highest score is the decoding result.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
(3) whole-lattice-rescoring - Use an LM to rescore the
|
||||
decoding lattice and then use 1best to decode the
|
||||
rescored lattice.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--G",
|
||||
type=str,
|
||||
help="""An LM for rescoring.
|
||||
Used only when method is
|
||||
whole-lattice-rescoring or nbest-rescoring.
|
||||
It's usually a 4-gram LM.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""
|
||||
Used only when method is nbest-rescoring.
|
||||
It specifies the size of n-best list.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=1.3,
|
||||
help="""
|
||||
Used only when method is whole-lattice-rescoring and nbest-rescoring.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
(Note: You need to tune it on a dataset.)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""
|
||||
Used only when method is nbest-rescoring.
|
||||
It specifies the scale for lattice.scores when
|
||||
extracting n-best lists. A smaller value results in
|
||||
a larger number of unique paths, at the risk of missing
|
||||
the best path.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-classes",
|
||||
type=int,
|
||||
default=500,
|
||||
help="""
|
||||
Vocab size in the BPE model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float = 16000
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
# add decoding params
|
||||
params.update(get_decoding_params())
|
||||
params.update(vars(args))
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
model = torch.jit.load(args.model_filename)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
nnet_output = model.ctc_output(encoder_out)
|
||||
|
||||
batch_size = nnet_output.shape[0]
|
||||
supervision_segments = torch.tensor(
|
||||
[
|
||||
[i, 0, feature_lengths[i] // params.subsampling_factor]
|
||||
for i in range(batch_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
)
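# Each row of supervision_segments is (utterance_index, start_frame,
# num_frames), where num_frames is measured after the encoder's
# subsampling (hence the division by params.subsampling_factor).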
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
logging.info("Use CTC decoding")
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(params.bpe_model)
|
||||
max_token_id = params.num_classes - 1
|
||||
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
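# H is the CTC topology over all BPE tokens (token 0 is the blank);
# modified=False selects the standard CTC topology rather than the
# memory-saving modified variant.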
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=H,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
token_ids = get_texts(best_path)
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
hyps = [s.split() for s in hyps]
|
||||
elif params.method in [
|
||||
"1best",
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.method in [
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
G = G.to(device)
|
||||
if params.method == "whole-lattice-rescoring":
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
|
||||
# G.lm_scores is used to replace HLG.lm_scores during
|
||||
# LM rescoring.
|
||||
G.lm_scores = G.scores.clone()
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=HLG,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
if params.method == "1best":
|
||||
logging.info("Use HLG decoding")
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
if params.method == "nbest-rescoring":
|
||||
logging.info("Use HLG decoding + LM rescoring")
|
||||
best_path_dict = rescore_with_n_best_list(
|
||||
lattice=lattice,
|
||||
G=G,
|
||||
num_paths=params.num_paths,
|
||||
lm_scale_list=[params.ngram_lm_scale],
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
elif params.method == "whole-lattice-rescoring":
|
||||
logging.info("Use HLG decoding + LM rescoring")
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=[params.ngram_lm_scale],
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/librispeech/ASR/tiny_transducer_ctc/joiner.py
Symbolic link
1
egs/librispeech/ASR/tiny_transducer_ctc/joiner.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/joiner.py
|
1
egs/librispeech/ASR/tiny_transducer_ctc/model.py
Symbolic link
1
egs/librispeech/ASR/tiny_transducer_ctc/model.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7_ctc/model.py
|
357
egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py
Executable file
357
egs/librispeech/ASR/tiny_transducer_ctc/pretrained.py
Executable file
@ -0,0 +1,357 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./tiny_transducer_ctc/export.py \
|
||||
--exp-dir ./tiny_transducer_ctc/exp \
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) greedy search
|
||||
./tiny_transducer_ctc/pretrained.py \
|
||||
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
--method greedy_search \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) beam search
|
||||
./tiny_transducer_ctc/pretrained.py \
|
||||
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
--method beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) modified beam search
|
||||
./tiny_transducer_ctc/pretrained.py \
|
||||
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
--method modified_beam_search \
|
||||
--beam-size 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(4) fast beam search
|
||||
./tiny_transducer_ctc/pretrained.py \
|
||||
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
--method fast_beam_search \
|
||||
--beam 4 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
You can also use `./tiny_transducer_ctc/exp/epoch-xx.pt`.
|
||||
|
||||
Note: ./tiny_transducer_ctc/exp/pretrained.pt is generated by
|
||||
./tiny_transducer_ctc/export.py
|
||||
"""
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from beam_search import (
|
||||
beam_search,
|
||||
fast_beam_search_one_best,
|
||||
greedy_search,
|
||||
greedy_search_batch,
|
||||
modified_beam_search,
|
||||
)
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe_500",
|
||||
help="""The lang dir
|
||||
It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="""Possible values are:
|
||||
- greedy_search
|
||||
- beam_search
|
||||
- modified_beam_search
|
||||
- fast_beam_search
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam-size",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""An integer indicating how many candidates we will keep for each
|
||||
frame. Used only when --method is beam_search or
|
||||
modified_beam_search.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--beam",
|
||||
type=float,
|
||||
default=4,
|
||||
help="""A floating point value to calculate the cutoff score during beam
|
||||
search (i.e., `cutoff = max-score - beam`), which is the same as the
|
||||
`beam` in Kaldi.
|
||||
Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-contexts",
|
||||
type=int,
|
||||
default=4,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-states",
|
||||
type=int,
|
||||
default=8,
|
||||
help="""Used only when --method is fast_beam_search""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
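# Note: --context-size (like the model sizes added by add_model_arguments
# below) should match the value used during training; otherwise the
# checkpoint weights may fail to load.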
|
||||
parser.add_argument(
|
||||
"--max-sym-per-frame",
|
||||
type=int,
|
||||
default=1,
|
||||
help="""Maximum number of symbols per frame. Used only when
|
||||
--method is greedy_search.
|
||||
""",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert (
|
||||
sample_rate == expected_sample_rate
|
||||
), f"Expected sample rate: {expected_sample_rate}. Given: {sample_rate}"
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
|
||||
params.update(vars(args))
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(params.lang_dir + "/bpe.model")
|
||||
|
||||
# <blk> is defined in local/train_bpe_model.py
|
||||
params.blank_id = sp.piece_to_id("<blk>")
|
||||
params.unk_id = sp.piece_to_id("<unk>")
|
||||
params.vocab_size = sp.get_piece_size()
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model.device = device
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
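# Pad with log(1e-10), i.e. a very small log-mel energy, so that padded
# frames look like (near) silence to the model.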
|
||||
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths)
|
||||
|
||||
num_waves = encoder_out.size(0)
|
||||
hyps = []
|
||||
msg = f"Using {params.method}"
|
||||
if params.method == "beam_search":
|
||||
msg += f" with beam size {params.beam_size}"
|
||||
logging.info(msg)
|
||||
|
||||
if params.method == "fast_beam_search":
|
||||
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
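# The trivial graph accepts any token sequence (vocab_size - 1 is the
# largest token id), so this fast_beam_search run is not constrained by an
# external LG graph.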
|
||||
hyp_tokens = fast_beam_search_one_best(
|
||||
model=model,
|
||||
decoding_graph=decoding_graph,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam,
|
||||
max_contexts=params.max_contexts,
|
||||
max_states=params.max_states,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "modified_beam_search":
|
||||
hyp_tokens = modified_beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
elif params.method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||
hyp_tokens = greedy_search_batch(
|
||||
model=model,
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
)
|
||||
for hyp in sp.decode(hyp_tokens):
|
||||
hyps.append(hyp.split())
|
||||
else:
|
||||
for i in range(num_waves):
|
||||
# fmt: off
|
||||
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
|
||||
# fmt: on
|
||||
if params.method == "greedy_search":
|
||||
hyp = greedy_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
max_sym_per_frame=params.max_sym_per_frame,
|
||||
)
|
||||
elif params.method == "beam_search":
|
||||
hyp = beam_search(
|
||||
model=model,
|
||||
encoder_out=encoder_out_i,
|
||||
beam=params.beam_size,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported method: {params.method}")
|
||||
|
||||
hyps.append(sp.decode(hyp).split())
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
444
egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py
Executable file
444
egs/librispeech/ASR/tiny_transducer_ctc/pretrained_ctc.py
Executable file
@ -0,0 +1,444 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This script loads a checkpoint
|
||||
and uses it to decode waves.
|
||||
You can generate the checkpoint with the following command:
|
||||
|
||||
./tiny_transducer_ctc/export.py \
|
||||
--exp-dir ./tiny_transducer_ctc/exp \
|
||||
--bpe-model data/lang_bpe_500/bpe.model \
|
||||
--epoch 20 \
|
||||
--avg 10
|
||||
|
||||
Usage of this script:
|
||||
|
||||
(1) ctc-decoding
|
||||
./tiny_transducer_ctc/pretrained_ctc.py \
|
||||
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
--method ctc-decoding \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(2) 1best
|
||||
./tiny_transducer_ctc/pretrained_ctc.py \
|
||||
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--method 1best \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
(3) nbest-rescoring
|
||||
./tiny_transducer_ctc/pretrained_ctc.py \
|
||||
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--method nbest-rescoring \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
|
||||
|
||||
(4) whole-lattice-rescoring
|
||||
./tiny_transducer_ctc/pretrained_ctc.py \
|
||||
--checkpoint ./tiny_transducer_ctc/exp/pretrained.pt \
|
||||
--HLG data/lang_bpe_500/HLG.pt \
|
||||
--words-file data/lang_bpe_500/words.txt \
|
||||
--G data/lm/G_4_gram.pt \
|
||||
--method whole-lattice-rescoring \
|
||||
--sample-rate 16000 \
|
||||
/path/to/foo.wav \
|
||||
/path/to/bar.wav
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import kaldifeat
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torchaudio
|
||||
from ctc_decode import get_decoding_params
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from train import add_model_arguments, get_params, get_transducer_model
|
||||
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
one_best_decoding,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.utils import get_texts
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the checkpoint. "
|
||||
"The checkpoint is assumed to be saved by "
|
||||
"icefall.checkpoint.save_checkpoint().",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--words-file",
|
||||
type=str,
|
||||
help="""Path to words.txt.
|
||||
Used only when method is not ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--HLG",
|
||||
type=str,
|
||||
help="""Path to HLG.pt.
|
||||
Used only when method is not ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-model",
|
||||
type=str,
|
||||
help="""Path to bpe.model.
|
||||
Used only when method is ctc-decoding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="1best",
|
||||
help="""Decoding method.
|
||||
Possible values are:
|
||||
(0) ctc-decoding - Use CTC decoding. It uses a sentence
|
||||
piece model, i.e., the file given by --bpe-model, to convert
|
||||
word pieces to words. It needs neither a lexicon
|
||||
nor an n-gram LM.
|
||||
(1) 1best - Use the best path as decoding output. Only
|
||||
the encoder output is used for decoding.
|
||||
We call it HLG decoding.
|
||||
(2) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an LM, the path with
|
||||
the highest score is the decoding result.
|
||||
We call it HLG decoding + n-gram LM rescoring.
|
||||
(3) whole-lattice-rescoring - Use an LM to rescore the
|
||||
decoding lattice and then use 1best to decode the
|
||||
rescored lattice.
|
||||
We call it HLG decoding + whole-lattice n-gram LM rescoring.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--G",
|
||||
type=str,
|
||||
help="""An LM for rescoring.
|
||||
Used only when method is
|
||||
whole-lattice-rescoring or nbest-rescoring.
|
||||
It's usually a 4-gram LM.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""
|
||||
Used only when method is nbest-rescoring.
|
||||
It specifies the size of the n-best list.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ngram-lm-scale",
|
||||
type=float,
|
||||
default=1.3,
|
||||
help="""
|
||||
Used only when method is whole-lattice-rescoring or nbest-rescoring.
|
||||
It specifies the scale for n-gram LM scores.
|
||||
(Note: You need to tune it on a dataset.)
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""
|
||||
Used only when method is nbest-rescoring.
|
||||
It specifies the scale for lattice.scores when
|
||||
extracting n-best lists. A smaller value results in
|
||||
more unique paths, at the risk of missing
|
||||
the best path.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-classes",
|
||||
type=int,
|
||||
default=500,
|
||||
help="""
|
||||
Vocab size in the BPE model.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="The sample rate of the input sound file",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to transcribe. "
|
||||
"Supported formats are those supported by torchaudio.load(). "
|
||||
"For example, wav and flac are supported. "
|
||||
"The sample rate has to be 16kHz.",
|
||||
)
|
||||
|
||||
add_model_arguments(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def read_sound_files(
|
||||
filenames: List[str], expected_sample_rate: float = 16000
|
||||
) -> List[torch.Tensor]:
|
||||
"""Read a list of sound files into a list 1-D float32 torch tensors.
|
||||
Args:
|
||||
filenames:
|
||||
A list of sound filenames.
|
||||
expected_sample_rate:
|
||||
The expected sample rate of the sound files.
|
||||
Returns:
|
||||
Return a list of 1-D float32 torch tensors.
|
||||
"""
|
||||
ans = []
|
||||
for f in filenames:
|
||||
wave, sample_rate = torchaudio.load(f)
|
||||
assert sample_rate == expected_sample_rate, (
|
||||
f"expected sample rate: {expected_sample_rate}. " f"Given: {sample_rate}"
|
||||
)
|
||||
# We use only the first channel
|
||||
ans.append(wave[0])
|
||||
return ans
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
args = parser.parse_args()
|
||||
|
||||
params = get_params()
|
||||
# add decoding params
|
||||
params.update(get_decoding_params())
|
||||
params.update(vars(args))
|
||||
params.vocab_size = params.num_classes
|
||||
params.blank_id = 0
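# The blank is token 0 and the vocabulary size equals --num-classes, so
# the model can be built without loading bpe.model; bpe.model is only
# needed for the ctc-decoding method below.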
|
||||
|
||||
logging.info(f"{params}")
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
logging.info("Creating model")
|
||||
model = get_transducer_model(params)
|
||||
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
checkpoint = torch.load(args.checkpoint, map_location="cpu")
|
||||
model.load_state_dict(checkpoint["model"], strict=False)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
logging.info("Constructing Fbank computer")
|
||||
opts = kaldifeat.FbankOptions()
|
||||
opts.device = device
|
||||
opts.frame_opts.dither = 0
|
||||
opts.frame_opts.snip_edges = False
|
||||
opts.frame_opts.samp_freq = params.sample_rate
|
||||
opts.mel_opts.num_bins = params.feature_dim
|
||||
|
||||
fbank = kaldifeat.Fbank(opts)
|
||||
|
||||
logging.info(f"Reading sound files: {params.sound_files}")
|
||||
waves = read_sound_files(
|
||||
filenames=params.sound_files, expected_sample_rate=params.sample_rate
|
||||
)
|
||||
waves = [w.to(device) for w in waves]
|
||||
|
||||
logging.info("Decoding started")
|
||||
features = fbank(waves)
|
||||
feature_lengths = [f.size(0) for f in features]
|
||||
|
||||
features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10))
|
||||
feature_lengths = torch.tensor(feature_lengths, device=device)
|
||||
|
||||
encoder_out, encoder_out_lens = model.encoder(
|
||||
x=features,
|
||||
x_lens=feature_lengths,
|
||||
)
|
||||
nnet_output = model.ctc_output(encoder_out)
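# nnet_output has shape (batch, num_frames, num_classes) and holds
# log-probabilities over the BPE tokens (including the blank).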
|
||||
|
||||
batch_size = nnet_output.shape[0]
|
||||
supervision_segments = torch.tensor(
|
||||
[
|
||||
[i, 0, feature_lengths[i] // params.subsampling_factor]
|
||||
for i in range(batch_size)
|
||||
],
|
||||
dtype=torch.int32,
|
||||
)
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
logging.info("Use CTC decoding")
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(params.bpe_model)
|
||||
max_token_id = params.num_classes - 1
|
||||
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=H,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
token_ids = get_texts(best_path)
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
hyps = [s.split() for s in hyps]
|
||||
elif params.method in [
|
||||
"1best",
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading HLG from {params.HLG}")
|
||||
HLG = k2.Fsa.from_dict(torch.load(params.HLG, map_location="cpu"))
|
||||
HLG = HLG.to(device)
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
# For whole-lattice-rescoring and attention-decoder
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.method in [
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
]:
|
||||
logging.info(f"Loading G from {params.G}")
|
||||
G = k2.Fsa.from_dict(torch.load(params.G, map_location="cpu"))
|
||||
G = G.to(device)
|
||||
if params.method == "whole-lattice-rescoring":
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
|
||||
# G.lm_scores is used to replace HLG.lm_scores during
|
||||
# LM rescoring.
|
||||
G.lm_scores = G.scores.clone()
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=HLG,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
if params.method == "1best":
|
||||
logging.info("Use HLG decoding")
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
if params.method == "nbest-rescoring":
|
||||
logging.info("Use HLG decoding + LM rescoring")
|
||||
best_path_dict = rescore_with_n_best_list(
|
||||
lattice=lattice,
|
||||
G=G,
|
||||
num_paths=params.num_paths,
|
||||
lm_scale_list=[params.ngram_lm_scale],
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
elif params.method == "whole-lattice-rescoring":
|
||||
logging.info("Use HLG decoding + LM rescoring")
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=[params.ngram_lm_scale],
|
||||
)
|
||||
best_path = next(iter(best_path_dict.values()))
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
word_sym_table = k2.SymbolTable.from_file(params.words_file)
|
||||
hyps = [[word_sym_table[i] for i in ids] for ids in hyps]
|
||||
else:
|
||||
raise ValueError(f"Unsupported decoding method: {params.method}")
|
||||
|
||||
s = "\n"
|
||||
for filename, hyp in zip(params.sound_files, hyps):
|
||||
words = " ".join(hyp)
|
||||
s += f"{filename}:\n{words}\n\n"
|
||||
logging.info(s)
|
||||
|
||||
logging.info("Decoding Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
|
||||
|
||||
logging.basicConfig(format=formatter, level=logging.INFO)
|
||||
main()
|
1
egs/librispeech/ASR/tiny_transducer_ctc/scaling.py
Symbolic link
1
egs/librispeech/ASR/tiny_transducer_ctc/scaling.py
Symbolic link
@ -0,0 +1 @@
|
||||
../pruned_transducer_stateless7/scaling.py
|
1251
egs/librispeech/ASR/tiny_transducer_ctc/train.py
Normal file
1251
egs/librispeech/ASR/tiny_transducer_ctc/train.py
Normal file
File diff suppressed because it is too large