mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Add MMI training with word pieces as modelling unit. (#6)
* Fix an error in TDNN-LSTM training. * WIP: Refactoring * Refactor transformer.py * Remove unused code. * Minor fixes. * Fix decoder padding mask. * Add MMI training with word pieces. * Remove unused files. * Minor fixes. * Refactoring. * Minor fixes. * Use pre-computed alignments in LF-MMI training. * Minor fixes. * Update decoding script. * Add doc about how to check and use extracted alignments. * Fix style issues. * Fix typos. * Fix style issues. * Disable macOS tests for now.
This commit is contained in:
parent
4890e27b45
commit
53b79fafa7
5
.flake8
5
.flake8
@ -4,8 +4,9 @@ statistics=true
|
||||
max-line-length = 80
|
||||
per-file-ignores =
|
||||
# line too long
|
||||
egs/librispeech/ASR/conformer_ctc/conformer.py: E501,
|
||||
egs/librispeech/ASR/*/conformer.py: E501,
|
||||
|
||||
exclude =
|
||||
.git,
|
||||
**/data/**
|
||||
**/data/**,
|
||||
icefall/shared/make_kn_lm.py
|
||||
|
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@ -29,7 +29,9 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-18.04, macos-10.15]
|
||||
# os: [ubuntu-18.04, macos-10.15]
|
||||
# disable macOS test for now.
|
||||
os: [ubuntu-18.04]
|
||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||
torch: ["1.8.1"]
|
||||
k2-version: ["1.9.dev20210919"]
|
||||
|
@ -27,10 +27,10 @@ avg=15
|
||||
--bucketing-sampler 0 \
|
||||
--full-libri 1 \
|
||||
--exp-dir conformer_ctc/exp \
|
||||
--lang-dir data/lang_bpe_5000 \
|
||||
--ali-dir data/ali_5000
|
||||
--lang-dir data/lang_bpe_500 \
|
||||
--ali-dir data/ali_500
|
||||
```
|
||||
and you will get four files inside the folder `data/ali_5000`:
|
||||
and you will get four files inside the folder `data/ali_500`:
|
||||
|
||||
```
|
||||
$ ls -lh data/ali_500
|
||||
@ -51,3 +51,27 @@ in `conformer_ctc/train.py`.
|
||||
Search `./conformer_ctc/asr_datamodule.py` for `preserve_id`.
|
||||
|
||||
**TODO:** Add doc about how to use the extracted alignment in the other pull-request.
|
||||
|
||||
### Step 3: Check your extracted alignments
|
||||
|
||||
There is a file `test_ali.py` in `icefall/test` that can be used to test your
|
||||
alignments. It uses pre-computed alignments to modify a randomly generated
|
||||
`nnet_output` and it checks that we can decode the correct transcripts
|
||||
from the resulting `nnet_output`.
|
||||
|
||||
You should get something like the following if you run that script:
|
||||
|
||||
```
|
||||
$ ./test/test_ali.py
|
||||
['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO DWELL THAT THIS PASSAGE THAT LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"]
|
||||
['THE GOOD NATURED AUDIENCE IN PITY TO FALLEN MAJESTY SHOWED FOR ONCE GREATER DEFERENCE TO THE KING THAN TO THE MINISTER AND SUNG THE PSALM WHICH THE FORMER HAD CALLED FOR', 'THE OLD SERVANT TOLD HIM QUIETLY AS THEY CREPT BACK TO GAMEWELL THAT THIS PASSAGE WAY LED FROM THE HUT IN THE PLEASANCE TO SHERWOOD AND THAT GEOFFREY FOR THE TIME WAS HIDING WITH THE OUTLAWS IN THE FOREST', 'FOR A WHILE SHE LAY IN HER CHAIR IN HAPPY DREAMY PLEASURE AT SUN AND BIRD AND TREE', "BUT THE ESSENCE OF LUTHER'S LECTURES IS THERE"]
|
||||
```
|
||||
|
||||
### Step 4: Use your alignments in training
|
||||
|
||||
Please refer to `conformer_mmi/train.py` for how usage. Some useful
|
||||
functions are:
|
||||
|
||||
- `load_alignments()`, it loads alignment saved by `conformer_ctc/ali.py`
|
||||
- `convert_alignments_to_tensor()`, it converts alignments to PyTorch tensors
|
||||
- `lookup_alignments()`, it returns the alignments of utterances by giving the cut ID of the utterances.
|
||||
|
@ -129,7 +129,7 @@ def get_params() -> AttributeDict:
|
||||
"""Return a dict containing training parameters.
|
||||
|
||||
All training related parameters that are not passed from the commandline
|
||||
is saved in the variable `params`.
|
||||
are saved in the variable `params`.
|
||||
|
||||
Commandline options are merged into `params` after they are parsed, so
|
||||
you can also access them via `params`.
|
||||
|
0
egs/librispeech/ASR/conformer_mmi/__init__.py
Normal file
0
egs/librispeech/ASR/conformer_mmi/__init__.py
Normal file
356
egs/librispeech/ASR/conformer_mmi/asr_datamodule.py
Normal file
356
egs/librispeech/ASR/conformer_mmi/asr_datamodule.py
Normal file
@ -0,0 +1,356 @@
|
||||
# Copyright 2021 Piotr Żelasko
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
from lhotse import CutSet, Fbank, FbankConfig, load_manifest
|
||||
from lhotse.dataset import (
|
||||
BucketingSampler,
|
||||
CutConcatenate,
|
||||
CutMix,
|
||||
K2SpeechRecognitionDataset,
|
||||
PrecomputedFeatures,
|
||||
SingleCutSampler,
|
||||
SpecAugment,
|
||||
)
|
||||
from lhotse.dataset.input_strategies import OnTheFlyFeatures
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.dataset.datamodule import DataModule
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
class LibriSpeechAsrDataModule(DataModule):
|
||||
"""
|
||||
DataModule for k2 ASR experiments.
|
||||
It assumes there is always one train and valid dataloader,
|
||||
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
|
||||
and test-other).
|
||||
|
||||
It contains all the common data pipeline modules used in ASR
|
||||
experiments, e.g.:
|
||||
- dynamic batch size,
|
||||
- bucketing samplers,
|
||||
- cut concatenation,
|
||||
- augmentation,
|
||||
- on-the-fly feature extraction
|
||||
|
||||
This class should be derived for specific corpora used in ASR tasks.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: argparse.ArgumentParser):
|
||||
super().add_arguments(parser)
|
||||
group = parser.add_argument_group(
|
||||
title="ASR data related options",
|
||||
description="These options are used for the preparation of "
|
||||
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
|
||||
"effective batch sizes, sampling strategies, applied data "
|
||||
"augmentations, etc.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--full-libri",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, use 960h LibriSpeech. "
|
||||
"Otherwise, use 100h subset.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--feature-dir",
|
||||
type=Path,
|
||||
default=Path("data/fbank"),
|
||||
help="Path to directory with train/valid/test cuts.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--max-duration",
|
||||
type=int,
|
||||
default=200.0,
|
||||
help="Maximum pooled recordings duration (seconds) in a "
|
||||
"single batch. You can reduce it if it causes CUDA OOM.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--bucketing-sampler",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, the batches will come from buckets of "
|
||||
"similar duration (saves padding frames).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--num-buckets",
|
||||
type=int,
|
||||
default=30,
|
||||
help="The number of buckets for the BucketingSampler"
|
||||
"(you might want to increase it for larger datasets).",
|
||||
)
|
||||
group.add_argument(
|
||||
"--concatenate-cuts",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, utterances (cuts) will be concatenated "
|
||||
"to minimize the amount of padding.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--duration-factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Determines the maximum duration of a concatenated cut "
|
||||
"relative to the duration of the longest cut in a batch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--gap",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="The amount of padding (in seconds) inserted between "
|
||||
"concatenated cuts. This padding is filled with noise when "
|
||||
"noise augmentation is used.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--on-the-fly-feats",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="When enabled, use on-the-fly cut mixing and feature "
|
||||
"extraction. Will drop existing precomputed feature manifests "
|
||||
"if available.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--shuffle",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled (=default), the examples will be "
|
||||
"shuffled for each epoch.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--return-cuts",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="When enabled, each batch will have the "
|
||||
"field: batch['supervisions']['cut'] with the cuts that "
|
||||
"were used to construct it.",
|
||||
)
|
||||
|
||||
group.add_argument(
|
||||
"--num-workers",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The number of training dataloader workers that "
|
||||
"collect the batches.",
|
||||
)
|
||||
|
||||
def train_dataloaders(self) -> DataLoader:
|
||||
logging.info("About to get train cuts")
|
||||
cuts_train = self.train_cuts()
|
||||
|
||||
logging.info("About to get Musan cuts")
|
||||
cuts_musan = load_manifest(self.args.feature_dir / "cuts_musan.json.gz")
|
||||
|
||||
logging.info("About to create train dataset")
|
||||
transforms = [
|
||||
CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True)
|
||||
]
|
||||
if self.args.concatenate_cuts:
|
||||
logging.info(
|
||||
f"Using cut concatenation with duration factor "
|
||||
f"{self.args.duration_factor} and gap {self.args.gap}."
|
||||
)
|
||||
# Cut concatenation should be the first transform in the list,
|
||||
# so that if we e.g. mix noise in, it will fill the gaps between
|
||||
# different utterances.
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
input_transforms = [
|
||||
SpecAugment(
|
||||
num_frame_masks=2,
|
||||
features_mask_size=27,
|
||||
num_feature_masks=2,
|
||||
frames_mask_size=100,
|
||||
)
|
||||
]
|
||||
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.on_the_fly_feats:
|
||||
# NOTE: the PerturbSpeed transform should be added only if we
|
||||
# remove it from data prep stage.
|
||||
# Add on-the-fly speed perturbation; since originally it would
|
||||
# have increased epoch size by 3, we will apply prob 2/3 and use
|
||||
# 3x more epochs.
|
||||
# Speed perturbation probably should come first before
|
||||
# concatenation, but in principle the transforms order doesn't have
|
||||
# to be strict (e.g. could be randomized)
|
||||
# transforms = [PerturbSpeed(factors=[0.9, 1.1], p=2/3)] + transforms # noqa
|
||||
# Drop feats to be on the safe side.
|
||||
train = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
input_transforms=input_transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
|
||||
if self.args.bucketing_sampler:
|
||||
logging.info("Using BucketingSampler.")
|
||||
train_sampler = BucketingSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
num_buckets=self.args.num_buckets,
|
||||
bucket_method="equal_duration",
|
||||
drop_last=True,
|
||||
)
|
||||
else:
|
||||
logging.info("Using SingleCutSampler.")
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=self.args.shuffle,
|
||||
)
|
||||
logging.info("About to create train dataloader")
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=self.args.num_workers,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return train_dl
|
||||
|
||||
def valid_dataloaders(self) -> DataLoader:
|
||||
logging.info("About to get dev cuts")
|
||||
cuts_valid = self.valid_cuts()
|
||||
|
||||
transforms = []
|
||||
if self.args.concatenate_cuts:
|
||||
transforms = [
|
||||
CutConcatenate(
|
||||
duration_factor=self.args.duration_factor, gap=self.args.gap
|
||||
)
|
||||
] + transforms
|
||||
|
||||
logging.info("About to create dev dataset")
|
||||
if self.args.on_the_fly_feats:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
else:
|
||||
validate = K2SpeechRecognitionDataset(
|
||||
cut_transforms=transforms,
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
valid_sampler = SingleCutSampler(
|
||||
cuts_valid,
|
||||
max_duration=self.args.max_duration,
|
||||
shuffle=False,
|
||||
)
|
||||
logging.info("About to create dev dataloader")
|
||||
valid_dl = DataLoader(
|
||||
validate,
|
||||
sampler=valid_sampler,
|
||||
batch_size=None,
|
||||
num_workers=2,
|
||||
persistent_workers=False,
|
||||
)
|
||||
|
||||
return valid_dl
|
||||
|
||||
def test_dataloaders(self) -> Union[DataLoader, List[DataLoader]]:
|
||||
cuts = self.test_cuts()
|
||||
is_list = isinstance(cuts, list)
|
||||
test_loaders = []
|
||||
if not is_list:
|
||||
cuts = [cuts]
|
||||
|
||||
for cuts_test in cuts:
|
||||
logging.debug("About to create test dataset")
|
||||
test = K2SpeechRecognitionDataset(
|
||||
input_strategy=OnTheFlyFeatures(
|
||||
Fbank(FbankConfig(num_mel_bins=80))
|
||||
)
|
||||
if self.args.on_the_fly_feats
|
||||
else PrecomputedFeatures(),
|
||||
return_cuts=self.args.return_cuts,
|
||||
)
|
||||
sampler = SingleCutSampler(
|
||||
cuts_test, max_duration=self.args.max_duration
|
||||
)
|
||||
logging.debug("About to create test dataloader")
|
||||
test_dl = DataLoader(
|
||||
test, batch_size=None, sampler=sampler, num_workers=1
|
||||
)
|
||||
test_loaders.append(test_dl)
|
||||
|
||||
if is_list:
|
||||
return test_loaders
|
||||
else:
|
||||
return test_loaders[0]
|
||||
|
||||
@lru_cache()
|
||||
def train_cuts(self) -> CutSet:
|
||||
logging.info("About to get train cuts")
|
||||
cuts_train = load_manifest(
|
||||
self.args.feature_dir / "cuts_train-clean-100.json.gz"
|
||||
)
|
||||
if self.args.full_libri:
|
||||
cuts_train = (
|
||||
cuts_train
|
||||
+ load_manifest(
|
||||
self.args.feature_dir / "cuts_train-clean-360.json.gz"
|
||||
)
|
||||
+ load_manifest(
|
||||
self.args.feature_dir / "cuts_train-other-500.json.gz"
|
||||
)
|
||||
)
|
||||
return cuts_train
|
||||
|
||||
@lru_cache()
|
||||
def valid_cuts(self) -> CutSet:
|
||||
logging.info("About to get dev cuts")
|
||||
cuts_valid = load_manifest(
|
||||
self.args.feature_dir / "cuts_dev-clean.json.gz"
|
||||
) + load_manifest(self.args.feature_dir / "cuts_dev-other.json.gz")
|
||||
return cuts_valid
|
||||
|
||||
@lru_cache()
|
||||
def test_cuts(self) -> List[CutSet]:
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
cuts = []
|
||||
for test_set in test_sets:
|
||||
logging.debug("About to get test cuts")
|
||||
cuts.append(
|
||||
load_manifest(
|
||||
self.args.feature_dir / f"cuts_{test_set}.json.gz"
|
||||
)
|
||||
)
|
||||
return cuts
|
916
egs/librispeech/ASR/conformer_mmi/conformer.py
Normal file
916
egs/librispeech/ASR/conformer_mmi/conformer.py
Normal file
@ -0,0 +1,916 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
from transformer import Supervisions, Transformer, encoder_padding_mask
|
||||
|
||||
|
||||
class Conformer(Transformer):
|
||||
"""
|
||||
Args:
|
||||
num_features (int): Number of input features
|
||||
num_classes (int): Number of output classes
|
||||
subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers)
|
||||
d_model (int): attention dimension
|
||||
nhead (int): number of head
|
||||
dim_feedforward (int): feedforward dimention
|
||||
num_encoder_layers (int): number of encoder layers
|
||||
num_decoder_layers (int): number of decoder layers
|
||||
dropout (float): dropout rate
|
||||
cnn_module_kernel (int): Kernel size of convolution module
|
||||
normalize_before (bool): whether to use layer_norm before the first block.
|
||||
vgg_frontend (bool): whether to use vgg frontend.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_features: int,
|
||||
num_classes: int,
|
||||
subsampling_factor: int = 4,
|
||||
d_model: int = 256,
|
||||
nhead: int = 4,
|
||||
dim_feedforward: int = 2048,
|
||||
num_encoder_layers: int = 12,
|
||||
num_decoder_layers: int = 6,
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: int = 31,
|
||||
normalize_before: bool = True,
|
||||
vgg_frontend: bool = False,
|
||||
use_feat_batchnorm: bool = False,
|
||||
) -> None:
|
||||
super(Conformer, self).__init__(
|
||||
num_features=num_features,
|
||||
num_classes=num_classes,
|
||||
subsampling_factor=subsampling_factor,
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
dim_feedforward=dim_feedforward,
|
||||
num_encoder_layers=num_encoder_layers,
|
||||
num_decoder_layers=num_decoder_layers,
|
||||
dropout=dropout,
|
||||
normalize_before=normalize_before,
|
||||
vgg_frontend=vgg_frontend,
|
||||
use_feat_batchnorm=use_feat_batchnorm,
|
||||
)
|
||||
|
||||
self.encoder_pos = RelPositionalEncoding(d_model, dropout)
|
||||
|
||||
encoder_layer = ConformerEncoderLayer(
|
||||
d_model,
|
||||
nhead,
|
||||
dim_feedforward,
|
||||
dropout,
|
||||
cnn_module_kernel,
|
||||
normalize_before,
|
||||
)
|
||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
||||
self.normalize_before = normalize_before
|
||||
if self.normalize_before:
|
||||
self.after_norm = nn.LayerNorm(d_model)
|
||||
else:
|
||||
# Note: TorchScript detects that self.after_norm could be used inside forward()
|
||||
# and throws an error without this change.
|
||||
self.after_norm = identity
|
||||
|
||||
def run_encoder(
|
||||
self, x: Tensor, supervisions: Optional[Supervisions] = None
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The model input. Its shape is (N, T, C).
|
||||
supervisions:
|
||||
Supervision in lhotse format.
|
||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
||||
CAUTION: It contains length information, i.e., start and number of
|
||||
frames, before subsampling
|
||||
It is read directly from the batch, without any sorting. It is used
|
||||
to compute encoder padding mask, which is used as memory key padding
|
||||
mask for the decoder.
|
||||
|
||||
Returns:
|
||||
Tensor: Predictor tensor of dimension (input_length, batch_size, d_model).
|
||||
Tensor: Mask tensor of dimension (batch_size, input_length)
|
||||
"""
|
||||
x = self.encoder_embed(x)
|
||||
x, pos_emb = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||
mask = encoder_padding_mask(x.size(0), supervisions)
|
||||
if mask is not None:
|
||||
mask = mask.to(x.device)
|
||||
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, B, F)
|
||||
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
|
||||
return x, mask
|
||||
|
||||
|
||||
class ConformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
ConformerEncoderLayer is made up of self-attn, feedforward and convolution networks.
|
||||
See: "Conformer: Convolution-augmented Transformer for Speech Recognition"
|
||||
|
||||
Args:
|
||||
d_model: the number of expected features in the input (required).
|
||||
nhead: the number of heads in the multiheadattention models (required).
|
||||
dim_feedforward: the dimension of the feedforward network model (default=2048).
|
||||
dropout: the dropout value (default=0.1).
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
normalize_before: whether to use layer_norm before the first block.
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> pos_emb = torch.rand(32, 19, 512)
|
||||
>>> out = encoder_layer(src, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
nhead: int,
|
||||
dim_feedforward: int = 2048,
|
||||
dropout: float = 0.1,
|
||||
cnn_module_kernel: int = 31,
|
||||
normalize_before: bool = True,
|
||||
) -> None:
|
||||
super(ConformerEncoderLayer, self).__init__()
|
||||
self.self_attn = RelPositionMultiheadAttention(
|
||||
d_model, nhead, dropout=0.0
|
||||
)
|
||||
|
||||
self.feed_forward = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
Swish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim_feedforward, d_model),
|
||||
)
|
||||
|
||||
self.feed_forward_macaron = nn.Sequential(
|
||||
nn.Linear(d_model, dim_feedforward),
|
||||
Swish(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(dim_feedforward, d_model),
|
||||
)
|
||||
|
||||
self.conv_module = ConvolutionModule(d_model, cnn_module_kernel)
|
||||
|
||||
self.norm_ff_macaron = nn.LayerNorm(
|
||||
d_model
|
||||
) # for the macaron style FNN module
|
||||
self.norm_ff = nn.LayerNorm(d_model) # for the FNN module
|
||||
self.norm_mha = nn.LayerNorm(d_model) # for the MHA module
|
||||
|
||||
self.ff_scale = 0.5
|
||||
|
||||
self.norm_conv = nn.LayerNorm(d_model) # for the CNN module
|
||||
self.norm_final = nn.LayerNorm(
|
||||
d_model
|
||||
) # for the final output of the block
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
src_mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder layer (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, N is the batch size, E is the feature number
|
||||
"""
|
||||
|
||||
# macaron style feed forward module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
src = residual + self.ff_scale * self.dropout(
|
||||
self.feed_forward_macaron(src)
|
||||
)
|
||||
if not self.normalize_before:
|
||||
src = self.norm_ff_macaron(src)
|
||||
|
||||
# multi-headed self-attention module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_mha(src)
|
||||
src_att = self.self_attn(
|
||||
src,
|
||||
src,
|
||||
src,
|
||||
pos_emb=pos_emb,
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask,
|
||||
)[0]
|
||||
src = residual + self.dropout(src_att)
|
||||
if not self.normalize_before:
|
||||
src = self.norm_mha(src)
|
||||
|
||||
# convolution module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_conv(src)
|
||||
src = residual + self.dropout(self.conv_module(src))
|
||||
if not self.normalize_before:
|
||||
src = self.norm_conv(src)
|
||||
|
||||
# feed forward module
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm_ff(src)
|
||||
src = residual + self.ff_scale * self.dropout(self.feed_forward(src))
|
||||
if not self.normalize_before:
|
||||
src = self.norm_ff(src)
|
||||
|
||||
if self.normalize_before:
|
||||
src = self.norm_final(src)
|
||||
|
||||
return src
|
||||
|
||||
|
||||
class ConformerEncoder(nn.TransformerEncoder):
|
||||
r"""ConformerEncoder is a stack of N encoder layers
|
||||
|
||||
Args:
|
||||
encoder_layer: an instance of the ConformerEncoderLayer() class (required).
|
||||
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||
norm: the layer normalization component (optional).
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = ConformerEncoderLayer(d_model=512, nhead=8)
|
||||
>>> conformer_encoder = ConformerEncoder(encoder_layer, num_layers=6)
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> pos_emb = torch.rand(32, 19, 512)
|
||||
>>> out = conformer_encoder(src, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, encoder_layer: nn.Module, num_layers: int, norm: nn.Module = None
|
||||
) -> None:
|
||||
super(ConformerEncoder, self).__init__(
|
||||
encoder_layer=encoder_layer, num_layers=num_layers, norm=norm
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: Tensor,
|
||||
pos_emb: Tensor,
|
||||
mask: Optional[Tensor] = None,
|
||||
src_key_padding_mask: Optional[Tensor] = None,
|
||||
) -> Tensor:
|
||||
r"""Pass the input through the encoder layers in turn.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder (required).
|
||||
pos_emb: Positional embedding tensor (required).
|
||||
mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional).
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
pos_emb: (N, 2*S-1, E)
|
||||
mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number
|
||||
|
||||
"""
|
||||
output = src
|
||||
|
||||
for mod in self.layers:
|
||||
output = mod(
|
||||
output,
|
||||
pos_emb,
|
||||
src_mask=mask,
|
||||
src_key_padding_mask=src_key_padding_mask,
|
||||
)
|
||||
|
||||
if self.norm is not None:
|
||||
output = self.norm(output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class RelPositionalEncoding(torch.nn.Module):
|
||||
"""Relative positional encoding module.
|
||||
|
||||
See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py
|
||||
|
||||
Args:
|
||||
d_model: Embedding dimension.
|
||||
dropout_rate: Dropout rate.
|
||||
max_len: Maximum input length.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, d_model: int, dropout_rate: float, max_len: int = 5000
|
||||
) -> None:
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(RelPositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x: Tensor) -> None:
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
# Note: TorchScript doesn't implement operator== for torch.Device
|
||||
if self.pe.dtype != x.dtype or str(self.pe.device) != str(
|
||||
x.device
|
||||
):
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||
# position of key vector. We use position relative positions when keys
|
||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
||||
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
||||
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
||||
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
||||
|
||||
# Reserve the order of positive indices and concat both positive and
|
||||
# negative indices. This is used to support the shifting trick
|
||||
# as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[Tensor, Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Encoded tensor (batch, 2*time-1, `*`).
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2
|
||||
- x.size(1)
|
||||
+ 1 : self.pe.size(1) // 2 # noqa E203
|
||||
+ x.size(1),
|
||||
]
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
|
||||
class RelPositionMultiheadAttention(nn.Module):
|
||||
r"""Multi-Head Attention layer with relative position encoding
|
||||
|
||||
See reference: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
|
||||
|
||||
Args:
|
||||
embed_dim: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embed_dim: int,
|
||||
num_heads: int,
|
||||
dropout: float = 0.0,
|
||||
) -> None:
|
||||
super(RelPositionMultiheadAttention, self).__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.num_heads = num_heads
|
||||
self.dropout = dropout
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
|
||||
self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
|
||||
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
|
||||
|
||||
# linear transformation for positional encoding.
|
||||
self.linear_pos = nn.Linear(embed_dim, embed_dim, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(num_heads, self.head_dim))
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self) -> None:
|
||||
nn.init.xavier_uniform_(self.in_proj.weight)
|
||||
nn.init.constant_(self.in_proj.bias, 0.0)
|
||||
nn.init.constant_(self.out_proj.bias, 0.0)
|
||||
|
||||
nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
pos_emb: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
pos_emb: Positional embedding tensor
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. When given a binary mask and a value is True,
|
||||
the corresponding value on the attention layer will be ignored. When given
|
||||
a byte mask and a value is non-zero, the corresponding value on the attention
|
||||
layer will be ignored
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
|
||||
Shape:
|
||||
- Inputs:
|
||||
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- pos_emb: :math:`(N, 2*L-1, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
||||
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
||||
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||
S is the source sequence length. attn_mask ensure that position i is allowed to attend the unmasked
|
||||
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||
is provided, it will be added to the attention weight.
|
||||
|
||||
- Outputs:
|
||||
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||
E is the embedding dimension.
|
||||
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||
L is the target sequence length, S is the source sequence length.
|
||||
"""
|
||||
return self.multi_head_attention_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
pos_emb,
|
||||
self.embed_dim,
|
||||
self.num_heads,
|
||||
self.in_proj.weight,
|
||||
self.in_proj.bias,
|
||||
self.dropout,
|
||||
self.out_proj.weight,
|
||||
self.out_proj.bias,
|
||||
training=self.training,
|
||||
key_padding_mask=key_padding_mask,
|
||||
need_weights=need_weights,
|
||||
attn_mask=attn_mask,
|
||||
)
|
||||
|
||||
def rel_shift(self, x: Tensor) -> Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x: Input tensor (batch, head, time1, 2*time1-1).
|
||||
time1 means the length of query vector.
|
||||
|
||||
Returns:
|
||||
Tensor: tensor of shape (batch, head, time1, time2)
|
||||
(note: time2 has the same value as time1, but it is for
|
||||
the key, while time1 is for the query).
|
||||
"""
|
||||
(batch_size, num_heads, time1, n) = x.shape
|
||||
assert n == 2 * time1 - 1
|
||||
# Note: TorchScript requires explicit arg for stride()
|
||||
batch_stride = x.stride(0)
|
||||
head_stride = x.stride(1)
|
||||
time1_stride = x.stride(2)
|
||||
n_stride = x.stride(3)
|
||||
return x.as_strided(
|
||||
(batch_size, num_heads, time1, time1),
|
||||
(batch_stride, head_stride, time1_stride - n_stride, n_stride),
|
||||
storage_offset=n_stride * (time1 - 1),
|
||||
)
|
||||
|
||||
def multi_head_attention_forward(
|
||||
self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
pos_emb: Tensor,
|
||||
embed_dim_to_check: int,
|
||||
num_heads: int,
|
||||
in_proj_weight: Tensor,
|
||||
in_proj_bias: Tensor,
|
||||
dropout_p: float,
|
||||
out_proj_weight: Tensor,
|
||||
out_proj_bias: Tensor,
|
||||
training: bool = True,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
pos_emb: Positional embedding tensor
|
||||
embed_dim_to_check: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
in_proj_weight, in_proj_bias: input projection weight and bias.
|
||||
dropout_p: probability of an element to be zeroed.
|
||||
out_proj_weight, out_proj_bias: the output projection weight and bias.
|
||||
training: apply dropout if is ``True``.
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. This is an binary mask. When the value is True,
|
||||
the corresponding value on the attention layer will be filled with -inf.
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
|
||||
Shape:
|
||||
Inputs:
|
||||
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension.
|
||||
- pos_emb: :math:`(N, 2*L-1, E)` or :math:`(1, 2*L-1, E)` where L is the target sequence
|
||||
length, N is the batch size, E is the embedding dimension.
|
||||
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||
If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
|
||||
will be unchanged. If a BoolTensor is provided, the positions with the
|
||||
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
||||
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||
are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||
is provided, it will be added to the attention weight.
|
||||
|
||||
Outputs:
|
||||
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||
E is the embedding dimension.
|
||||
- attn_output_weights: :math:`(N, L, S)` where N is the batch size,
|
||||
L is the target sequence length, S is the source sequence length.
|
||||
"""
|
||||
|
||||
tgt_len, bsz, embed_dim = query.size()
|
||||
assert embed_dim == embed_dim_to_check
|
||||
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
||||
|
||||
head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
head_dim * num_heads == embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
scaling = float(head_dim) ** -0.5
|
||||
|
||||
if torch.equal(query, key) and torch.equal(key, value):
|
||||
# self-attention
|
||||
q, k, v = nn.functional.linear(
|
||||
query, in_proj_weight, in_proj_bias
|
||||
).chunk(3, dim=-1)
|
||||
|
||||
elif torch.equal(key, value):
|
||||
# encoder-decoder attention
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = nn.functional.linear(query, _w, _b)
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = None
|
||||
_w = in_proj_weight[_start:, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:]
|
||||
k, v = nn.functional.linear(key, _w, _b).chunk(2, dim=-1)
|
||||
|
||||
else:
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = 0
|
||||
_end = embed_dim
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
q = nn.functional.linear(query, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim
|
||||
_end = embed_dim * 2
|
||||
_w = in_proj_weight[_start:_end, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:_end]
|
||||
k = nn.functional.linear(key, _w, _b)
|
||||
|
||||
# This is inline in_proj function with in_proj_weight and in_proj_bias
|
||||
_b = in_proj_bias
|
||||
_start = embed_dim * 2
|
||||
_end = None
|
||||
_w = in_proj_weight[_start:, :]
|
||||
if _b is not None:
|
||||
_b = _b[_start:]
|
||||
v = nn.functional.linear(value, _w, _b)
|
||||
|
||||
if attn_mask is not None:
|
||||
assert (
|
||||
attn_mask.dtype == torch.float32
|
||||
or attn_mask.dtype == torch.float64
|
||||
or attn_mask.dtype == torch.float16
|
||||
or attn_mask.dtype == torch.uint8
|
||||
or attn_mask.dtype == torch.bool
|
||||
), "Only float, byte, and bool types are supported for attn_mask, not {}".format(
|
||||
attn_mask.dtype
|
||||
)
|
||||
if attn_mask.dtype == torch.uint8:
|
||||
warnings.warn(
|
||||
"Byte tensor for attn_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
attn_mask = attn_mask.to(torch.bool)
|
||||
|
||||
if attn_mask.dim() == 2:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||
raise RuntimeError(
|
||||
"The size of the 2D attn_mask is not correct."
|
||||
)
|
||||
elif attn_mask.dim() == 3:
|
||||
if list(attn_mask.size()) != [
|
||||
bsz * num_heads,
|
||||
query.size(0),
|
||||
key.size(0),
|
||||
]:
|
||||
raise RuntimeError(
|
||||
"The size of the 3D attn_mask is not correct."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"attn_mask's dimension {} is not supported".format(
|
||||
attn_mask.dim()
|
||||
)
|
||||
)
|
||||
# attn_mask's dim is 3 now.
|
||||
|
||||
# convert ByteTensor key_padding_mask to bool
|
||||
if (
|
||||
key_padding_mask is not None
|
||||
and key_padding_mask.dtype == torch.uint8
|
||||
):
|
||||
warnings.warn(
|
||||
"Byte tensor for key_padding_mask is deprecated. Use bool tensor instead."
|
||||
)
|
||||
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||
|
||||
q = q.contiguous().view(tgt_len, bsz, num_heads, head_dim)
|
||||
k = k.contiguous().view(-1, bsz, num_heads, head_dim)
|
||||
v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
|
||||
|
||||
src_len = k.size(0)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz, "{} == {}".format(
|
||||
key_padding_mask.size(0), bsz
|
||||
)
|
||||
assert key_padding_mask.size(1) == src_len, "{} == {}".format(
|
||||
key_padding_mask.size(1), src_len
|
||||
)
|
||||
|
||||
q = q.transpose(0, 1) # (batch, time1, head, d_k)
|
||||
|
||||
pos_emb_bsz = pos_emb.size(0)
|
||||
assert pos_emb_bsz in (1, bsz) # actually it is 1
|
||||
p = self.linear_pos(pos_emb).view(pos_emb_bsz, -1, num_heads, head_dim)
|
||||
p = p.transpose(1, 2) # (batch, head, 2*time1-1, d_k)
|
||||
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time1, d_k)
|
||||
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time1, d_k)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3
|
||||
k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2)
|
||||
matrix_ac = torch.matmul(
|
||||
q_with_bias_u, k
|
||||
) # (batch, head, time1, time2)
|
||||
|
||||
# compute matrix b and matrix d
|
||||
matrix_bd = torch.matmul(
|
||||
q_with_bias_v, p.transpose(-2, -1)
|
||||
) # (batch, head, time1, 2*time1-1)
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
attn_output_weights = (
|
||||
matrix_ac + matrix_bd
|
||||
) * scaling # (batch, head, time1, time2)
|
||||
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, -1
|
||||
)
|
||||
|
||||
assert list(attn_output_weights.size()) == [
|
||||
bsz * num_heads,
|
||||
tgt_len,
|
||||
src_len,
|
||||
]
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_output_weights.masked_fill_(attn_mask, float("-inf"))
|
||||
else:
|
||||
attn_output_weights += attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, src_len
|
||||
)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
||||
float("-inf"),
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz * num_heads, tgt_len, src_len
|
||||
)
|
||||
|
||||
attn_output_weights = nn.functional.softmax(attn_output_weights, dim=-1)
|
||||
attn_output_weights = nn.functional.dropout(
|
||||
attn_output_weights, p=dropout_p, training=training
|
||||
)
|
||||
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1)
|
||||
.contiguous()
|
||||
.view(tgt_len, bsz, embed_dim)
|
||||
)
|
||||
attn_output = nn.functional.linear(
|
||||
attn_output, out_proj_weight, out_proj_bias
|
||||
)
|
||||
|
||||
if need_weights:
|
||||
# average attention weights over heads
|
||||
attn_output_weights = attn_output_weights.view(
|
||||
bsz, num_heads, tgt_len, src_len
|
||||
)
|
||||
return attn_output, attn_output_weights.sum(dim=1) / num_heads
|
||||
else:
|
||||
return attn_output, None
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model.
|
||||
Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/conformer/convolution.py
|
||||
|
||||
Args:
|
||||
channels (int): The number of channels of conv layers.
|
||||
kernel_size (int): Kernerl size of conv layers.
|
||||
bias (bool): Whether to use bias in conv layers (default=True).
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, channels: int, kernel_size: int, bias: bool = True
|
||||
) -> None:
|
||||
"""Construct an ConvolutionModule object."""
|
||||
super(ConvolutionModule, self).__init__()
|
||||
# kernerl_size should be a odd number for 'SAME' padding
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
|
||||
self.pointwise_conv1 = nn.Conv1d(
|
||||
channels,
|
||||
2 * channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
self.norm = nn.BatchNorm1d(channels)
|
||||
self.pointwise_conv2 = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.activation = Swish()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Compute convolution module.
|
||||
|
||||
Args:
|
||||
x: Input tensor (#time, batch, channels).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (#time, batch, channels).
|
||||
|
||||
"""
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.permute(1, 2, 0) # (#batch, channels, time).
|
||||
|
||||
# GLU mechanism
|
||||
x = self.pointwise_conv1(x) # (batch, 2*channels, time)
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channels, time)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
x = self.depthwise_conv(x)
|
||||
x = self.activation(self.norm(x))
|
||||
|
||||
x = self.pointwise_conv2(x) # (batch, channel, time)
|
||||
|
||||
return x.permute(2, 0, 1)
|
||||
|
||||
|
||||
class Swish(torch.nn.Module):
|
||||
"""Construct an Swish object."""
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
"""Return Swich activation function."""
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def identity(x):
|
||||
return x
|
694
egs/librispeech/ASR/conformer_mmi/decode.py
Executable file
694
egs/librispeech/ASR/conformer_mmi/decode.py
Executable file
@ -0,0 +1,694 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo, Fangjun Kuang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import average_checkpoints, load_checkpoint
|
||||
from icefall.decode import (
|
||||
get_lattice,
|
||||
nbest_decoding,
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_whole_lattice,
|
||||
)
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
get_texts,
|
||||
setup_logger,
|
||||
store_transcripts,
|
||||
str2bool,
|
||||
write_error_stats,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--epoch",
|
||||
type=int,
|
||||
default=34,
|
||||
help="It specifies the checkpoint to use for decoding."
|
||||
"Note: Epoch counts from 0.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--avg",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Number of checkpoints to average. Automatically select "
|
||||
"consecutive checkpoints before the checkpoint specified by "
|
||||
"'--epoch'. ",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--method",
|
||||
type=str,
|
||||
default="attention-decoder",
|
||||
help="""Decoding method.
|
||||
Supported values are:
|
||||
- (0) ctc-decoding. Use CTC decoding. It uses a sentence piece
|
||||
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||
It needs neither a lexicon nor an n-gram LM.
|
||||
- (1) 1best. Extract the best path from the decoding lattice as the
|
||||
decoding result.
|
||||
- (2) nbest. Extract n paths from the decoding lattice; the path
|
||||
with the highest score is the decoding result.
|
||||
- (3) nbest-rescoring. Extract n paths from the decoding lattice,
|
||||
rescore them with an n-gram LM (e.g., a 4-gram LM), the path with
|
||||
the highest score is the decoding result.
|
||||
- (4) whole-lattice-rescoring. Rescore the decoding lattice with an
|
||||
n-gram LM (e.g., a 4-gram LM), the best path of rescored lattice
|
||||
is the decoding result.
|
||||
- (5) attention-decoder. Extract n paths from the LM rescored
|
||||
lattice, the path with the highest score is the decoding result.
|
||||
- (6) nbest-oracle. Its WER is the lower bound of any n-best
|
||||
rescoring method can achieve. Useful for debugging n-best
|
||||
rescoring method.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-paths",
|
||||
type=int,
|
||||
default=100,
|
||||
help="""Number of paths for n-best based decoding method.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nbest-scale",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="""The scale to be applied to `lattice.scores`.
|
||||
It's needed if you use any kinds of n-best based rescoring.
|
||||
Used only when "method" is one of the following values:
|
||||
nbest, nbest-rescoring, attention-decoder, and nbest-oracle
|
||||
A smaller value results in more unique paths.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--export",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""When enabled, the averaged model is saved to
|
||||
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
|
||||
pretrained.pt contains a dict {"model": model.state_dict()},
|
||||
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--exp-dir",
|
||||
type=str,
|
||||
default="conformer_mmi/exp_500",
|
||||
help="The experiment dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
default="data/lang_bpe_500",
|
||||
help="The lang dir",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-decoder-layers",
|
||||
type=int,
|
||||
default=6,
|
||||
help="Number of attention decoder layers",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
params = AttributeDict(
|
||||
{
|
||||
"lm_dir": Path("data/lm"),
|
||||
# parameters for conformer
|
||||
"subsampling_factor": 4,
|
||||
"vgg_frontend": False,
|
||||
"use_feat_batchnorm": True,
|
||||
"feature_dim": 80,
|
||||
"nhead": 8,
|
||||
"attention_dim": 512,
|
||||
# parameters for decoding
|
||||
"search_beam": 20,
|
||||
"output_beam": 8,
|
||||
"min_active_states": 30,
|
||||
"max_active_states": 10000,
|
||||
"use_double_scores": True,
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
|
||||
def decode_one_batch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
batch: dict,
|
||||
word_table: k2.SymbolTable,
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[List[str]]]:
|
||||
"""Decode one batch and return the result in a dict. The dict has the
|
||||
following format:
|
||||
|
||||
- key: It indicates the setting used for decoding. For example,
|
||||
if no rescoring is used, the key is the string `no_rescore`.
|
||||
If LM rescoring is used, the key is the string `lm_scale_xxx`,
|
||||
where `xxx` is the value of `lm_scale`. An example key is
|
||||
`lm_scale_0.7`
|
||||
- value: It contains the decoding result. `len(value)` equals the
|
||||
batch size. `value[i]` is the decoding result for the i-th
|
||||
utterance in the given batch.
|
||||
Args:
|
||||
params:
|
||||
It's the return value of :func:`get_params`.
|
||||
|
||||
- If params.method is "1best", it uses 1best decoding without LM rescoring.
|
||||
- If params.method is "nbest", it uses nbest decoding without LM rescoring.
|
||||
- If params.method is "nbest-rescoring", it uses nbest LM rescoring.
|
||||
- If params.method is "whole-lattice-rescoring", it uses whole lattice LM
|
||||
rescoring.
|
||||
|
||||
model:
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.method is ctc-decoding.
|
||||
batch:
|
||||
It is the return value from iterating
|
||||
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||
for the format of the `batch`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
sos_id:
|
||||
The token ID of the SOS.
|
||||
eos_id:
|
||||
The token ID of the EOS.
|
||||
G:
|
||||
An LM. It is not None when params.method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return the decoding result. See above description for the format of
|
||||
the returned dict.
|
||||
"""
|
||||
if HLG is not None:
|
||||
device = HLG.device
|
||||
else:
|
||||
device = H.device
|
||||
feature = batch["inputs"]
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
# at entry, feature is (N, T, C)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
|
||||
nnet_output, memory, memory_key_padding_mask = model(feature, supervisions)
|
||||
# nnet_output is (N, T, C)
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
supervisions["start_frame"] // params.subsampling_factor,
|
||||
supervisions["num_frames"] // params.subsampling_factor,
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
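# An illustrative value of `supervision_segments` (hypothetical numbers) for
# a batch of two utterances with subsampling_factor == 4; each row is
# [sequence_idx, start_frame // 4, num_frames // 4]:
#
#   tensor([[  0,   0, 375],
#           [  1,   0, 298]], dtype=torch.int32)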
|
||||
|
||||
if H is None:
|
||||
assert HLG is not None
|
||||
decoding_graph = HLG
|
||||
else:
|
||||
assert HLG is None
|
||||
assert bpe_model is not None
|
||||
decoding_graph = H
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=decoding_graph,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=params.search_beam,
|
||||
output_beam=params.output_beam,
|
||||
min_active_states=params.min_active_states,
|
||||
max_active_states=params.max_active_states,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
)
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
# Note: `best_path.aux_labels` contains token IDs, not word IDs
|
||||
# since we are using H, not HLG here.
|
||||
#
|
||||
# token_ids is a list-of-list of IDs
|
||||
token_ids = get_texts(best_path)
|
||||
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
key = "ctc-decoding"
|
||||
return {key: hyps}
|
||||
|
||||
if params.method == "nbest-oracle":
|
||||
# Note: You can also pass rescored lattices to it.
|
||||
# We choose the HLG decoded lattice for speed reasons
|
||||
# as HLG decoding is faster and the oracle WER
|
||||
# is only slightly worse than that of rescored lattices.
|
||||
best_path = nbest_oracle(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
ref_texts=supervisions["text"],
|
||||
word_table=word_table,
|
||||
nbest_scale=params.nbest_scale,
|
||||
oov="<UNK>",
|
||||
)
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
key = f"oracle_{params.num_paths}_nbest_scale_{params.nbest_scale}" # noqa
|
||||
return {key: hyps}
|
||||
|
||||
if params.method in ["1best", "nbest"]:
|
||||
if params.method == "1best":
|
||||
best_path = one_best_decoding(
|
||||
lattice=lattice, use_double_scores=params.use_double_scores
|
||||
)
|
||||
key = "no_rescore"
|
||||
else:
|
||||
best_path = nbest_decoding(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
use_double_scores=params.use_double_scores,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
key = f"no_rescore-nbest-scale-{params.nbest_scale}-{params.num_paths}" # noqa
|
||||
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
return {key: hyps}
|
||||
|
||||
assert params.method in [
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
]
|
||||
|
||||
lm_scale_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
|
||||
lm_scale_list += [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
|
||||
lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
|
||||
|
||||
if params.method == "nbest-rescoring":
|
||||
best_path_dict = rescore_with_n_best_list(
|
||||
lattice=lattice,
|
||||
G=G,
|
||||
num_paths=params.num_paths,
|
||||
lm_scale_list=lm_scale_list,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
elif params.method == "whole-lattice-rescoring":
|
||||
best_path_dict = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=lm_scale_list,
|
||||
)
|
||||
elif params.method == "attention-decoder":
|
||||
# lattice uses a 3-gram LM. We rescore it with a 4-gram LM.
|
||||
rescored_lattice = rescore_with_whole_lattice(
|
||||
lattice=lattice,
|
||||
G_with_epsilon_loops=G,
|
||||
lm_scale_list=None,
|
||||
)
|
||||
# TODO: pass `lattice` instead of `rescored_lattice` to
|
||||
# `rescore_with_attention_decoder`
|
||||
|
||||
best_path_dict = rescore_with_attention_decoder(
|
||||
lattice=rescored_lattice,
|
||||
num_paths=params.num_paths,
|
||||
model=model,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
nbest_scale=params.nbest_scale,
|
||||
)
|
||||
else:
|
||||
assert False, f"Unsupported decoding method: {params.method}"
|
||||
|
||||
ans = dict()
|
||||
if best_path_dict is not None:
|
||||
for lm_scale_str, best_path in best_path_dict.items():
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
ans[lm_scale_str] = hyps
|
||||
else:
|
||||
for lm_scale in lm_scale_list:
|
||||
ans["empty"] = [[] * lattice.shape[0]]
|
||||
return ans
|
||||
|
||||
|
||||
def decode_dataset(
|
||||
dl: torch.utils.data.DataLoader,
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
HLG: Optional[k2.Fsa],
|
||||
H: Optional[k2.Fsa],
|
||||
bpe_model: Optional[spm.SentencePieceProcessor],
|
||||
word_table: k2.SymbolTable,
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
G: Optional[k2.Fsa] = None,
|
||||
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
|
||||
"""Decode dataset.
|
||||
|
||||
Args:
|
||||
dl:
|
||||
PyTorch's dataloader containing the dataset to decode.
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The neural model.
|
||||
HLG:
|
||||
The decoding graph. Used only when params.method is NOT ctc-decoding.
|
||||
H:
|
||||
The ctc topo. Used only when params.method is ctc-decoding.
|
||||
bpe_model:
|
||||
The BPE model. Used only when params.method is ctc-decoding.
|
||||
word_table:
|
||||
It is the word symbol table.
|
||||
sos_id:
|
||||
The token ID for SOS.
|
||||
eos_id:
|
||||
The token ID for EOS.
|
||||
G:
|
||||
An LM. It is not None when params.method is "nbest-rescoring"
|
||||
or "whole-lattice-rescoring". In general, the G in HLG
|
||||
is a 3-gram LM, while this G is a 4-gram LM.
|
||||
Returns:
|
||||
Return a dict, whose key may be "no_rescore" if no LM rescoring
|
||||
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
|
||||
Its value is a list of tuples. Each tuple contains two elements:
|
||||
The first is the reference transcript, and the second is the
|
||||
predicted result.
|
||||
"""
|
||||
results = []
|
||||
|
||||
num_cuts = 0
|
||||
|
||||
try:
|
||||
num_batches = len(dl)
|
||||
except TypeError:
|
||||
num_batches = "?"
|
||||
|
||||
results = defaultdict(list)
|
||||
for batch_idx, batch in enumerate(dl):
|
||||
texts = batch["supervisions"]["text"]
|
||||
|
||||
hyps_dict = decode_one_batch(
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
batch=batch,
|
||||
word_table=word_table,
|
||||
G=G,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
|
||||
for lm_scale, hyps in hyps_dict.items():
|
||||
this_batch = []
|
||||
assert len(hyps) == len(texts)
|
||||
for hyp_words, ref_text in zip(hyps, texts):
|
||||
ref_words = ref_text.split()
|
||||
this_batch.append((ref_words, hyp_words))
|
||||
|
||||
results[lm_scale].extend(this_batch)
|
||||
|
||||
num_cuts += len(batch["supervisions"]["text"])
|
||||
|
||||
if batch_idx % 100 == 0:
|
||||
batch_str = f"{batch_idx}/{num_batches}"
|
||||
|
||||
logging.info(
|
||||
f"batch {batch_str}, cuts processed until now is {num_cuts}"
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
def save_results(
|
||||
params: AttributeDict,
|
||||
test_set_name: str,
|
||||
results_dict: Dict[str, List[Tuple[List[str], List[str]]]],
|
||||
):
|
||||
if params.method == "attention-decoder":
|
||||
# Set it to False since there are too many logs.
|
||||
enable_log = False
|
||||
else:
|
||||
enable_log = True
|
||||
test_set_wers = dict()
|
||||
for key, results in results_dict.items():
|
||||
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
|
||||
store_transcripts(filename=recog_path, texts=results)
|
||||
if enable_log:
|
||||
logging.info(f"The transcripts are stored in {recog_path}")
|
||||
|
||||
# The following prints out WERs, per-word error statistics and aligned
|
||||
# ref/hyp pairs.
|
||||
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
|
||||
with open(errs_filename, "w") as f:
|
||||
wer = write_error_stats(
|
||||
f, f"{test_set_name}-{key}", results, enable_log=enable_log
|
||||
)
|
||||
test_set_wers[key] = wer
|
||||
|
||||
if enable_log:
|
||||
logging.info(
|
||||
"Wrote detailed error stats to {}".format(errs_filename)
|
||||
)
|
||||
|
||||
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
|
||||
with open(errs_info, "w") as f:
|
||||
print("settings\tWER", file=f)
|
||||
for key, val in test_set_wers:
|
||||
print("{}\t{}".format(key, val), file=f)
|
||||
|
||||
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||
note = "\tbest for {}".format(test_set_name)
|
||||
for key, val in test_set_wers:
|
||||
s += "{}\t{}{}\n".format(key, val, note)
|
||||
note = ""
|
||||
logging.info(s)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
args.exp_dir = Path(args.exp_dir)
|
||||
args.lang_dir = Path(args.lang_dir)
|
||||
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
|
||||
logging.info("Decoding started")
|
||||
logging.info(params)
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
|
||||
logging.info(f"device: {device}")
|
||||
|
||||
graph_compiler = BpeCtcTrainingGraphCompiler(
|
||||
params.lang_dir,
|
||||
device=device,
|
||||
sos_token="<sos/eos>",
|
||||
eos_token="<sos/eos>",
|
||||
)
|
||||
sos_id = graph_compiler.sos_id
|
||||
eos_id = graph_compiler.eos_id
|
||||
|
||||
if params.method == "ctc-decoding":
|
||||
HLG = None
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
modified=False,
|
||||
device=device,
|
||||
)
|
||||
bpe_model = spm.SentencePieceProcessor()
|
||||
bpe_model.load(str(params.lang_dir / "bpe.model"))
|
||||
else:
|
||||
H = None
|
||||
bpe_model = None
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
|
||||
)
|
||||
HLG = HLG.to(device)
|
||||
assert HLG.requires_grad is False
|
||||
|
||||
if not hasattr(HLG, "lm_scores"):
|
||||
HLG.lm_scores = HLG.scores.clone()
|
||||
|
||||
if params.method in (
|
||||
"nbest-rescoring",
|
||||
"whole-lattice-rescoring",
|
||||
"attention-decoder",
|
||||
):
|
||||
if not (params.lm_dir / "G_4_gram.pt").is_file():
|
||||
logging.info("Loading G_4_gram.fst.txt")
|
||||
logging.warning("It may take 8 minutes.")
|
||||
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
|
||||
first_word_disambig_id = lexicon.word_table["#0"]
|
||||
|
||||
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
# G.aux_labels is not needed in later computations, so
|
||||
# remove it here.
|
||||
del G.aux_labels
|
||||
# CAUTION: The following line is crucial.
|
||||
# Arcs entering the back-off state have label equal to #0.
|
||||
# We have to change it to 0 here.
|
||||
G.labels[G.labels >= first_word_disambig_id] = 0
|
||||
G = k2.Fsa.from_fsas([G]).to(device)
|
||||
G = k2.arc_sort(G)
|
||||
torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
|
||||
else:
|
||||
logging.info("Loading pre-compiled G_4_gram.pt")
|
||||
d = torch.load(params.lm_dir / "G_4_gram.pt", map_location="cpu")
|
||||
G = k2.Fsa.from_dict(d).to(device)
|
||||
|
||||
if params.method in ["whole-lattice-rescoring", "attention-decoder"]:
|
||||
# Add epsilon self-loops to G as we will compose
|
||||
# it with the whole lattice later
|
||||
G = k2.add_epsilon_self_loops(G)
|
||||
G = k2.arc_sort(G)
|
||||
G = G.to(device)
|
||||
|
||||
# G.lm_scores is used to replace HLG.lm_scores during
|
||||
# LM rescoring.
|
||||
G.lm_scores = G.scores.clone()
|
||||
else:
|
||||
G = None
|
||||
|
||||
model = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
nhead=params.nhead,
|
||||
d_model=params.attention_dim,
|
||||
num_classes=num_classes,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
num_decoder_layers=params.num_decoder_layers,
|
||||
vgg_frontend=params.vgg_frontend,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
|
||||
if params.avg == 1:
|
||||
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||
else:
|
||||
start = params.epoch - params.avg + 1
|
||||
filenames = []
|
||||
for i in range(start, params.epoch + 1):
|
||||
if i >= 0:
|
||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||
logging.info(f"averaging {filenames}")
|
||||
model.load_state_dict(average_checkpoints(filenames))
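# A rough sketch of what checkpoint averaging amounts to (an assumption
# about icefall.checkpoint.average_checkpoints, not its actual code): an
# element-wise mean of the "model" state_dicts of the selected epochs:
#
#   avg = torch.load(filenames[0], map_location="cpu")["model"]
#   for f in filenames[1:]:
#       state = torch.load(f, map_location="cpu")["model"]
#       for k in avg:
#           avg[k] = avg[k] + state[k]
#   for k in avg:
#       avg[k] = avg[k] / len(filenames)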
|
||||
|
||||
if params.export:
|
||||
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
|
||||
torch.save(
|
||||
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
|
||||
)
|
||||
return
|
||||
|
||||
model.to(device)
|
||||
model.eval()
|
||||
num_param = sum([p.numel() for p in model.parameters()])
|
||||
logging.info(f"Number of model parameters: {num_param}")
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
# CAUTION: `test_sets` is for displaying only.
|
||||
# If you want to skip test-clean, you have to skip
|
||||
# it inside the for loop. That is, use
|
||||
#
|
||||
# if test_set == 'test-clean': continue
|
||||
#
|
||||
test_sets = ["test-clean", "test-other"]
|
||||
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
|
||||
results_dict = decode_dataset(
|
||||
dl=test_dl,
|
||||
params=params,
|
||||
model=model,
|
||||
HLG=HLG,
|
||||
H=H,
|
||||
bpe_model=bpe_model,
|
||||
word_table=lexicon.word_table,
|
||||
G=G,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
|
||||
save_results(
|
||||
params=params, test_set_name=test_set, results_dict=results_dict
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
144
egs/librispeech/ASR/conformer_mmi/subsampling.py
Normal file
@ -0,0 +1,144 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class Conv2dSubsampling(nn.Module):
|
||||
"""Convolutional 2D subsampling (to 1/4 length).
|
||||
|
||||
Convert an input of shape [N, T, idim] to an output
|
||||
with shape [N, T', odim], where
|
||||
T' = ((T-1)//2 - 1)//2, which approximates T' == T//4
|
||||
|
||||
It is based on
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int) -> None:
|
||||
"""
|
||||
Args:
|
||||
idim:
|
||||
Input dim. The input shape is [N, T, idim].
|
||||
Caution: It requires: T >=7, idim >=7
|
||||
odim:
|
||||
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
|
||||
"""
|
||||
assert idim >= 7
|
||||
super().__init__()
|
||||
self.conv = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
in_channels=1, out_channels=odim, kernel_size=3, stride=2
|
||||
),
|
||||
nn.ReLU(),
|
||||
nn.Conv2d(
|
||||
in_channels=odim, out_channels=odim, kernel_size=3, stride=2
|
||||
),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is [N, T, idim].
|
||||
|
||||
Returns:
|
||||
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
|
||||
"""
|
||||
# On entry, x is [N, T, idim]
|
||||
x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W]
|
||||
x = self.conv(x)
|
||||
# Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2]
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
# Now x is of shape [N, ((T-1)//2 - 1)//2, odim]
|
||||
return x
|
||||
|
||||
|
||||
class VggSubsampling(nn.Module):
|
||||
"""Trying to follow the setup described in the following paper:
|
||||
https://arxiv.org/pdf/1910.09799.pdf
|
||||
|
||||
This paper is not 100% explicit so I am guessing to some extent,
|
||||
and trying to compare with other VGG implementations.
|
||||
|
||||
Convert an input of shape [N, T, idim] to an output
|
||||
with shape [N, T', odim], where
|
||||
T' = ((T-1)//2 - 1)//2, which approximates T' = T//4
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int) -> None:
|
||||
"""Construct a VggSubsampling object.
|
||||
|
||||
This uses 2 VGG blocks with 2 Conv2d layers each,
|
||||
subsampling its input by a factor of 4 in the time dimension.
|
||||
|
||||
Args:
|
||||
idim:
|
||||
Input dim. The input shape is [N, T, idim].
|
||||
Caution: It requires: T >=7, idim >=7
|
||||
odim:
|
||||
Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim]
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
cur_channels = 1
|
||||
layers = []
|
||||
block_dims = [32, 64]
|
||||
|
||||
# The decision to use padding=1 for the 1st convolution, then padding=0
|
||||
# for the 2nd and for the max-pooling, and ceil_mode=True, was driven by
|
||||
# a back-compatibility concern so that the number of frames at the
|
||||
# output would be equal to:
|
||||
# (((T-1)//2)-1)//2.
|
||||
# We can consider changing this by using padding=1 on the
|
||||
# 2nd convolution, so the num-frames at the output would be T//4.
|
||||
for block_dim in block_dims:
|
||||
layers.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=cur_channels,
|
||||
out_channels=block_dim,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=1,
|
||||
)
|
||||
)
|
||||
layers.append(torch.nn.ReLU())
|
||||
layers.append(
|
||||
torch.nn.Conv2d(
|
||||
in_channels=block_dim,
|
||||
out_channels=block_dim,
|
||||
kernel_size=3,
|
||||
padding=0,
|
||||
stride=1,
|
||||
)
|
||||
)
|
||||
layers.append(
|
||||
torch.nn.MaxPool2d(
|
||||
kernel_size=2, stride=2, padding=0, ceil_mode=True
|
||||
)
|
||||
)
|
||||
cur_channels = block_dim
|
||||
|
||||
self.layers = nn.Sequential(*layers)
|
||||
|
||||
self.out = nn.Linear(
|
||||
block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is [N, T, idim].
|
||||
|
||||
Returns:
|
||||
Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim]
|
||||
"""
|
||||
x = x.unsqueeze(1)
|
||||
x = self.layers(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
return x
|
33
egs/librispeech/ASR/conformer_mmi/test_subsampling.py
Executable file
@ -0,0 +1,33 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
from subsampling import Conv2dSubsampling
|
||||
from subsampling import VggSubsampling
|
||||
import torch
|
||||
|
||||
|
||||
def test_conv2d_subsampling():
|
||||
N = 3
|
||||
odim = 2
|
||||
|
||||
for T in range(7, 19):
|
||||
for idim in range(7, 20):
|
||||
model = Conv2dSubsampling(idim=idim, odim=odim)
|
||||
x = torch.empty(N, T, idim)
|
||||
y = model(x)
|
||||
assert y.shape[0] == N
|
||||
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
|
||||
assert y.shape[2] == odim
|
||||
|
||||
|
||||
def test_vgg_subsampling():
|
||||
N = 3
|
||||
odim = 2
|
||||
|
||||
for T in range(7, 19):
|
||||
for idim in range(7, 20):
|
||||
model = VggSubsampling(idim=idim, odim=odim)
|
||||
x = torch.empty(N, T, idim)
|
||||
y = model(x)
|
||||
assert y.shape[0] == N
|
||||
assert y.shape[1] == ((T - 1) // 2 - 1) // 2
|
||||
assert y.shape[2] == odim
|
89
egs/librispeech/ASR/conformer_mmi/test_transformer.py
Normal file
@ -0,0 +1,89 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import torch
|
||||
from transformer import (
|
||||
Transformer,
|
||||
encoder_padding_mask,
|
||||
generate_square_subsequent_mask,
|
||||
decoder_padding_mask,
|
||||
add_sos,
|
||||
add_eos,
|
||||
)
|
||||
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
|
||||
def test_encoder_padding_mask():
|
||||
supervisions = {
|
||||
"sequence_idx": torch.tensor([0, 1, 2]),
|
||||
"start_frame": torch.tensor([0, 0, 0]),
|
||||
"num_frames": torch.tensor([18, 7, 13]),
|
||||
}
|
||||
|
||||
max_len = ((18 - 1) // 2 - 1) // 2
|
||||
mask = encoder_padding_mask(max_len, supervisions)
|
||||
expected_mask = torch.tensor(
|
||||
[
|
||||
[False, False, False], # ((18 - 1)//2 - 1)//2 = 3,
|
||||
[False, True, True], # ((7 - 1)//2 - 1)//2 = 1,
|
||||
[False, False, True], # ((13 - 1)//2 - 1)//2 = 2,
|
||||
]
|
||||
)
|
||||
assert torch.all(torch.eq(mask, expected_mask))
|
||||
|
||||
|
||||
def test_transformer():
|
||||
num_features = 40
|
||||
num_classes = 87
|
||||
model = Transformer(num_features=num_features, num_classes=num_classes)
|
||||
|
||||
N = 31
|
||||
|
||||
for T in range(7, 30):
|
||||
x = torch.rand(N, T, num_features)
|
||||
y, _, _ = model(x)
|
||||
assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes)
|
||||
|
||||
|
||||
def test_generate_square_subsequent_mask():
|
||||
s = 5
|
||||
mask = generate_square_subsequent_mask(s)
|
||||
inf = float("inf")
|
||||
expected_mask = torch.tensor(
|
||||
[
|
||||
[0.0, -inf, -inf, -inf, -inf],
|
||||
[0.0, 0.0, -inf, -inf, -inf],
|
||||
[0.0, 0.0, 0.0, -inf, -inf],
|
||||
[0.0, 0.0, 0.0, 0.0, -inf],
|
||||
[0.0, 0.0, 0.0, 0.0, 0.0],
|
||||
]
|
||||
)
|
||||
assert torch.all(torch.eq(mask, expected_mask))
|
||||
|
||||
|
||||
def test_decoder_padding_mask():
|
||||
x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])]
|
||||
y = pad_sequence(x, batch_first=True, padding_value=-1)
|
||||
mask = decoder_padding_mask(y, ignore_id=-1)
|
||||
expected_mask = torch.tensor(
|
||||
[
|
||||
[False, False, True],
|
||||
[False, True, True],
|
||||
[False, False, False],
|
||||
]
|
||||
)
|
||||
assert torch.all(torch.eq(mask, expected_mask))
|
||||
|
||||
|
||||
def test_add_sos():
|
||||
x = [[1, 2], [3], [2, 5, 8]]
|
||||
y = add_sos(x, sos_id=0)
|
||||
expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]]
|
||||
assert y == expected_y
|
||||
|
||||
|
||||
def test_add_eos():
|
||||
x = [[1, 2], [3], [2, 5, 8]]
|
||||
y = add_eos(x, eos_id=0)
|
||||
expected_y = [[1, 2, 0], [3, 0], [2, 5, 8, 0]]
|
||||
assert y == expected_y
|
837
egs/librispeech/ASR/conformer_mmi/train-with-attention.py
Executable file
@ -0,0 +1,837 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Dict, Optional
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformer import Noam
|
||||
|
||||
from icefall.ali import (
|
||||
convert_alignments_to_tensor,
|
||||
load_alignments,
|
||||
lookup_alignments,
|
||||
)
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.mmi import LFMMILoss
|
||||
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--master-port",
|
||||
type=int,
|
||||
default=12354,
|
||||
help="Master port to use for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tensorboard",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Should various information be logged in tensorboard.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""Resume training from from this epoch.
|
||||
If it is positive, it will load checkpoint from
|
||||
conformer_mmi/exp/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ali-dir",
|
||||
type=str,
|
||||
default="data/ali_500",
|
||||
help="""This folder is expected to contain
|
||||
two files, train-960.pt and valid.pt, which
|
||||
contain framewise alignment information for
|
||||
the training set and validation set.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
"""Return a dict containing training parameters.
|
||||
|
||||
All training related parameters that are not passed from the commandline
|
||||
are saved in the variable `params`.
|
||||
|
||||
Commandline options are merged into `params` after they are parsed, so
|
||||
you can also access them via `params`.
|
||||
|
||||
Explanation of options saved in `params`:
|
||||
|
||||
- exp_dir: It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
|
||||
- lang_dir: It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
|
||||
- best_train_loss: Best training loss so far. It is used to select
|
||||
the model that has the lowest training loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_valid_loss: Best validation loss so far. It is used to select
|
||||
the model that has the lowest validation loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_train_epoch: It is the epoch that has the best training loss.
|
||||
|
||||
- best_valid_epoch: It is the epoch that has the best validation loss.
|
||||
|
||||
- batch_idx_train: Used for writing statistics to tensorboard. It
|
||||
contains number of batches trained so far across
|
||||
epochs.
|
||||
|
||||
- log_interval: Print training loss if batch_idx % log_interval is 0
|
||||
|
||||
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
|
||||
|
||||
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
||||
|
||||
- feature_dim: The model input dim. It has to match the one used
|
||||
in computing features.
|
||||
|
||||
- subsampling_factor: The subsampling factor for the model.
|
||||
|
||||
- use_feat_batchnorm: Whether to do batch normalization for the
|
||||
input features.
|
||||
|
||||
- attention_dim: Hidden dim for multi-head attention model.
|
||||
|
||||
- nhead: Number of heads of multi-head attention model.
|
||||
|
||||
- num_decoder_layers: Number of decoder layers of transformer decoder.
|
||||
|
||||
- weight_decay: The weight_decay for the optimizer.
|
||||
|
||||
- lr_factor: The lr_factor for Noam optimizer.
|
||||
|
||||
- warm_step: The warm_step for Noam optimizer.
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"exp_dir": Path("conformer_mmi/exp_500_with_attention"),
|
||||
"lang_dir": Path("data/lang_bpe_500"),
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000,
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4,
|
||||
"use_feat_batchnorm": True,
|
||||
"attention_dim": 512,
|
||||
"nhead": 8,
|
||||
# parameters for loss
|
||||
"beam_size": 6, # will change it to 8 after some batches (see code)
|
||||
"reduction": "sum",
|
||||
"use_double_scores": True,
|
||||
# "att_rate": 0.0,
|
||||
# "num_decoder_layers": 0,
|
||||
"att_rate": 0.7,
|
||||
"num_decoder_layers": 6,
|
||||
# parameters for Noam
|
||||
"weight_decay": 1e-6,
|
||||
"lr_factor": 5.0,
|
||||
"warm_step": 80000,
|
||||
"use_pruned_intersect": False,
|
||||
"den_scale": 1.0,
|
||||
# use alignments before this number of batches
|
||||
"use_ali_until": 13000,
|
||||
}
|
||||
)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def load_checkpoint_if_available(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
) -> Optional[Dict]:
|
||||
"""Load checkpoint from file.
|
||||
|
||||
If params.start_epoch is positive, it will load the checkpoint from
|
||||
`params.start_epoch - 1`. Otherwise, this function does nothing.
|
||||
|
||||
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
|
||||
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
and `best_valid_loss` in `params`.
|
||||
|
||||
Args:
|
||||
params:
|
||||
The return value of :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
optimizer:
|
||||
The optimizer that we are using.
|
||||
scheduler:
|
||||
The learning rate scheduler we are using.
|
||||
Returns:
|
||||
Return the saved checkpoint as a dict if it is loaded; otherwise return None.
|
||||
"""
|
||||
if params.start_epoch <= 0:
|
||||
return
|
||||
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
saved_params = load_checkpoint(
|
||||
filename,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
keys = [
|
||||
"best_train_epoch",
|
||||
"best_valid_epoch",
|
||||
"batch_idx_train",
|
||||
"best_train_loss",
|
||||
"best_valid_loss",
|
||||
]
|
||||
for k in keys:
|
||||
params[k] = saved_params[k]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||
save_checkpoint_impl(
|
||||
filename=filename,
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
batch: dict,
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
is_training: bool,
|
||||
ali: Optional[Dict[str, torch.Tensor]],
|
||||
):
|
||||
"""
|
||||
Compute LF-MMI loss given the model and its inputs.
|
||||
|
||||
Args:
|
||||
params:
|
||||
Parameters for training. See :func:`get_params`.
|
||||
model:
|
||||
The model for training. It is an instance of Conformer in our case.
|
||||
batch:
|
||||
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
||||
for the content in it.
|
||||
graph_compiler:
|
||||
It is used to build a decoding graph from a ctc topo and training
|
||||
transcript. The training transcript is contained in the given `batch`,
|
||||
while the ctc topo is built when this compiler is instantiated.
|
||||
is_training:
|
||||
True for training. False for validation. When it is True, this
|
||||
function enables autograd during computation; when it is False, it
|
||||
disables autograd.
|
||||
ali:
|
||||
Precomputed alignments.
|
||||
"""
|
||||
device = graph_compiler.device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
with torch.set_grad_enabled(is_training):
|
||||
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
|
||||
# nnet_output is (N, T, C)
|
||||
|
||||
# NOTE: We need `encode_supervisions` to sort sequences with
|
||||
# different duration in decreasing order, required by
|
||||
# `k2.intersect_dense` called in `LFMMILoss.forward()`
|
||||
supervision_segments, texts = encode_supervisions(
|
||||
supervisions, subsampling_factor=params.subsampling_factor
|
||||
)
|
||||
|
||||
if ali is not None and params.batch_idx_train < params.use_ali_until:
|
||||
cut_ids = [cut.id for cut in supervisions["cut"]]
|
||||
|
||||
# As encode_supervisions reorders cuts, we need
|
||||
# also to reorder cut IDs here
|
||||
new2old = supervision_segments[:, 0].tolist()
|
||||
cut_ids = [cut_ids[i] for i in new2old]
|
||||
|
||||
# Check that new2old is just a permutation,
|
||||
# i.e., each cut contains only one utterance
|
||||
new2old.sort()
|
||||
assert new2old == torch.arange(len(new2old)).tolist()
|
||||
mask = lookup_alignments(
|
||||
cut_ids=cut_ids,
|
||||
alignments=ali,
|
||||
num_classes=nnet_output.shape[2],
|
||||
).to(nnet_output)
|
||||
|
||||
min_len = min(nnet_output.shape[1], mask.shape[1])
|
||||
ali_scale = 500.0 / (params.batch_idx_train + 500)
|
||||
|
||||
nnet_output = nnet_output.clone()
|
||||
nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :]
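# Worked example of the decay of the alignment boost:
# ali_scale = 500 / (batch_idx_train + 500) is 1.0 at batch 0,
# 0.5 at batch 500 and 0.1 at batch 4500, so the pre-computed
# alignments gradually stop influencing the network output.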
|
||||
|
||||
if (
|
||||
params.batch_idx_train > params.use_ali_until
|
||||
and params.beam_size < 8
|
||||
):
|
||||
# logging.info("Change beam size to 8")
|
||||
params.beam_size = 8
|
||||
else:
|
||||
params.beam_size = 6
|
||||
|
||||
loss_fn = LFMMILoss(
|
||||
graph_compiler=graph_compiler,
|
||||
use_pruned_intersect=params.use_pruned_intersect,
|
||||
den_scale=params.den_scale,
|
||||
beam_size=params.beam_size,
|
||||
)
|
||||
|
||||
dense_fsa_vec = k2.DenseFsaVec(
|
||||
nnet_output,
|
||||
supervision_segments,
|
||||
allow_truncate=params.subsampling_factor - 1,
|
||||
)
|
||||
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
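# For reference, the textbook form of the LF-MMI objective computed above
# (a standard description, not copied from LFMMILoss):
#
#   loss = -sum_utt( log p_num(utt) - den_scale * log p_den(utt) )
#
# where the numerator graph encodes the utterance's transcript and the
# denominator graph is shared across utterances; with reduction == "sum"
# no per-frame normalization is applied here.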
|
||||
|
||||
if params.att_rate != 0.0:
|
||||
token_ids = graph_compiler.texts_to_ids(texts)
|
||||
with torch.set_grad_enabled(is_training):
|
||||
if hasattr(model, "module"):
|
||||
att_loss = model.module.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
else:
|
||||
att_loss = model.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
|
||||
else:
|
||||
loss = mmi_loss
|
||||
att_loss = torch.tensor([0])
|
||||
|
||||
# train_frames and valid_frames are used for printing.
|
||||
if is_training:
|
||||
params.train_frames = supervision_segments[:, 2].sum().item()
|
||||
else:
|
||||
params.valid_frames = supervision_segments[:, 2].sum().item()
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
return loss, mmi_loss.detach(), att_loss.detach()
|
||||
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
ali: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> None:
|
||||
"""Run the validation process. The validation loss
|
||||
is saved in `params.valid_loss`.
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
tot_loss = 0.0
|
||||
tot_mmi_loss = 0.0
|
||||
tot_att_loss = 0.0
|
||||
tot_frames = 0.0
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
loss, mmi_loss, att_loss = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
graph_compiler=graph_compiler,
|
||||
is_training=False,
|
||||
ali=ali,
|
||||
)
|
||||
assert loss.requires_grad is False
|
||||
assert mmi_loss.requires_grad is False
|
||||
assert att_loss.requires_grad is False
|
||||
|
||||
loss_cpu = loss.detach().cpu().item()
|
||||
tot_loss += loss_cpu
|
||||
|
||||
tot_mmi_loss += mmi_loss.detach().cpu().item()
|
||||
tot_att_loss += att_loss.detach().cpu().item()
|
||||
|
||||
tot_frames += params.valid_frames
|
||||
|
||||
if world_size > 1:
|
||||
s = torch.tensor(
|
||||
[tot_loss, tot_mmi_loss, tot_att_loss, tot_frames],
|
||||
device=loss.device,
|
||||
)
|
||||
dist.all_reduce(s, op=dist.ReduceOp.SUM)
|
||||
s = s.cpu().tolist()
|
||||
tot_loss = s[0]
|
||||
tot_mmi_loss = s[1]
|
||||
tot_att_loss = s[2]
|
||||
tot_frames = s[3]
|
||||
|
||||
params.valid_loss = tot_loss / tot_frames
|
||||
params.valid_mmi_loss = tot_mmi_loss / tot_frames
|
||||
params.valid_att_loss = tot_att_loss / tot_frames
|
||||
|
||||
if params.valid_loss < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = params.valid_loss
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
train_ali: Optional[Dict[str, torch.Tensor]],
|
||||
valid_ali: Optional[Dict[str, torch.Tensor]],
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
The training loss from the mean of all frames is saved in
|
||||
`params.train_loss`. It runs the validation process every
|
||||
`params.valid_interval` batches.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The model for training.
|
||||
optimizer:
|
||||
The optimizer we are using.
|
||||
graph_compiler:
|
||||
It is used to convert transcripts to FSAs.
|
||||
train_dl:
|
||||
Dataloader for the training dataset.
|
||||
valid_dl:
|
||||
Dataloader for the validation dataset.
|
||||
train_ali:
|
||||
Precomputed alignments for the training set.
|
||||
valid_ali:
|
||||
Precomputed alignments for the validation set.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
world_size:
|
||||
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
||||
"""
|
||||
model.train()
|
||||
|
||||
tot_loss = 0.0 # sum of losses over all batches
|
||||
tot_mmi_loss = 0.0
|
||||
tot_att_loss = 0.0
|
||||
|
||||
tot_frames = 0.0 # sum of frames over all batches
|
||||
params.tot_loss = 0.0
|
||||
params.tot_frames = 0.0
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
loss, mmi_loss, att_loss = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
graph_compiler=graph_compiler,
|
||||
is_training=True,
|
||||
ali=train_ali,
|
||||
)
|
||||
|
||||
# NOTE: We use reduction=="sum", so the loss is summed over utterances
|
||||
# in the batch; no normalization has been applied to it so far.
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
optimizer.step()
|
||||
|
||||
loss_cpu = loss.detach().cpu().item()
|
||||
mmi_loss_cpu = mmi_loss.detach().cpu().item()
|
||||
att_loss_cpu = att_loss.detach().cpu().item()
|
||||
|
||||
tot_frames += params.train_frames
|
||||
tot_loss += loss_cpu
|
||||
tot_mmi_loss += mmi_loss_cpu
|
||||
tot_att_loss += att_loss_cpu
|
||||
|
||||
params.tot_frames += params.train_frames
|
||||
params.tot_loss += loss_cpu
|
||||
|
||||
tot_avg_loss = tot_loss / tot_frames
|
||||
tot_avg_mmi_loss = tot_mmi_loss / tot_frames
|
||||
tot_avg_att_loss = tot_att_loss / tot_frames
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||
f"batch avg mmi loss {mmi_loss_cpu/params.train_frames:.4f}, "
|
||||
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
|
||||
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
|
||||
f"total avg mmiloss: {tot_avg_mmi_loss:.4f}, "
|
||||
f"total avg att loss: {tot_avg_att_loss:.4f}, "
|
||||
f"total avg loss: {tot_avg_loss:.4f}, "
|
||||
f"batch size: {batch_size}"
|
||||
)
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/current_mmi_loss",
|
||||
mmi_loss_cpu / params.train_frames,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/current_att_loss",
|
||||
att_loss_cpu / params.train_frames,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/current_loss",
|
||||
loss_cpu / params.train_frames,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/tot_avg_mmi_loss",
|
||||
tot_avg_mmi_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
|
||||
tb_writer.add_scalar(
|
||||
"train/tot_avg_att_loss",
|
||||
tot_avg_att_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/tot_avg_loss",
|
||||
tot_avg_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
|
||||
tot_loss = 0.0 # sum of losses over all batches
|
||||
tot_mmi_loss = 0.0
|
||||
tot_att_loss = 0.0
|
||||
|
||||
tot_frames = 0.0 # sum of frames over all batches
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
graph_compiler=graph_compiler,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
ali=valid_ali,
|
||||
)
|
||||
model.train()
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
f"valid mmi loss {params.valid_mmi_loss:.4f},"
|
||||
f"valid att loss {params.valid_att_loss:.4f},"
|
||||
f"valid loss {params.valid_loss:.4f},"
|
||||
f" best valid loss: {params.best_valid_loss:.4f} "
|
||||
f"best valid epoch: {params.best_valid_epoch}"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/valid_mmi_loss",
|
||||
params.valid_mmi_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/valid_att_loss",
|
||||
params.valid_att_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/valid_loss",
|
||||
params.valid_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
|
||||
params.train_loss = params.tot_loss / params.tot_frames
|
||||
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
params.best_train_loss = params.train_loss
|
||||
|
||||
|
||||
def run(rank, world_size, args):
|
||||
"""
|
||||
Args:
|
||||
rank:
|
||||
It is a value between 0 and `world_size-1`, which is
|
||||
passed automatically by `mp.spawn()` in :func:`main`.
|
||||
The node with rank 0 is responsible for saving checkpoint.
|
||||
world_size:
|
||||
Number of GPUs for DDP training.
|
||||
args:
|
||||
The return value of get_parser().parse_args()
|
||||
"""
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
fix_random_seed(42)
|
||||
if world_size > 1:
|
||||
setup_dist(rank, world_size, params.master_port)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
logging.info(params)
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
|
||||
graph_compiler = MmiTrainingGraphCompiler(
|
||||
params.lang_dir,
|
||||
uniq_filename="lexicon.txt",
|
||||
device=device,
|
||||
oov="<UNK>",
|
||||
sos_id=1,
|
||||
eos_id=1,
|
||||
)
|
||||
|
||||
logging.info("About to create model")
|
||||
if params.att_rate == 0:
|
||||
assert params.num_decoder_layers == 0, f"{params.num_decoder_layers}"
|
||||
|
||||
model = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
nhead=params.nhead,
|
||||
d_model=params.attention_dim,
|
||||
num_classes=num_classes,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
num_decoder_layers=params.num_decoder_layers,
|
||||
vgg_frontend=False,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
model = DDP(model, device_ids=[rank])
|
||||
|
||||
optimizer = Noam(
|
||||
model.parameters(),
|
||||
model_size=params.attention_dim,
|
||||
factor=params.lr_factor,
|
||||
warm_step=params.warm_step,
|
||||
weight_decay=params.weight_decay,
|
||||
)
|
||||
|
||||
if checkpoints:
|
||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||
|
||||
train_960_ali_filename = Path(params.ali_dir) / "train-960.pt"
|
||||
if (
|
||||
params.batch_idx_train < params.use_ali_until
|
||||
and train_960_ali_filename.is_file()
|
||||
):
|
||||
logging.info("Use pre-computed alignments")
|
||||
subsampling_factor, train_ali = load_alignments(train_960_ali_filename)
|
||||
assert subsampling_factor == params.subsampling_factor
|
||||
assert len(train_ali) == 843723, f"{len(train_ali)} vs 843723"
|
||||
|
||||
valid_ali_filename = Path(params.ali_dir) / "valid.pt"
|
||||
subsampling_factor, valid_ali = load_alignments(valid_ali_filename)
|
||||
assert subsampling_factor == params.subsampling_factor
|
||||
|
||||
train_ali = convert_alignments_to_tensor(train_ali, device=device)
|
||||
valid_ali = convert_alignments_to_tensor(valid_ali, device=device)
|
||||
else:
|
||||
logging.info("Not using alignments")
|
||||
train_ali = None
|
||||
valid_ali = None
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
train_dl = librispeech.train_dataloaders()
|
||||
valid_dl = librispeech.valid_dataloaders()
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
if (
|
||||
params.batch_idx_train >= params.use_ali_until
|
||||
and train_ali is not None
|
||||
):
|
||||
# Delete the alignments to save memory
|
||||
train_ali = None
|
||||
valid_ali = None
|
||||
|
||||
cur_lr = optimizer._rate
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
if rank == 0:
|
||||
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
||||
|
||||
params.cur_epoch = epoch
|
||||
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
graph_compiler=graph_compiler,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
train_ali=train_ali,
|
||||
valid_ali=valid_ali,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
save_checkpoint(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||
else:
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
837
egs/librispeech/ASR/conformer_mmi/train.py
Executable file
@ -0,0 +1,837 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Dict, Optional
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from conformer import Conformer
|
||||
from lhotse.utils import fix_random_seed
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from transformer import Noam
|
||||
|
||||
from icefall.ali import (
|
||||
convert_alignments_to_tensor,
|
||||
load_alignments,
|
||||
lookup_alignments,
|
||||
)
|
||||
from icefall.checkpoint import load_checkpoint
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.mmi import LFMMILoss
|
||||
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
|
||||
from icefall.utils import (
|
||||
AttributeDict,
|
||||
encode_supervisions,
|
||||
setup_logger,
|
||||
str2bool,
|
||||
)
|
||||
|
||||
|
||||
def get_parser():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of GPUs for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--master-port",
|
||||
type=int,
|
||||
default=12354,
|
||||
help="Master port to use for DDP training.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tensorboard",
|
||||
type=str2bool,
|
||||
default=True,
|
||||
help="Should various information be logged in tensorboard.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-epochs",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Number of epochs to train.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--start-epoch",
|
||||
type=int,
|
||||
default=0,
|
||||
help="""Resume training from from this epoch.
|
||||
If it is positive, it will load checkpoint from
|
||||
conformer_mmi/exp/epoch-{start_epoch-1}.pt
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ali-dir",
|
||||
type=str,
|
||||
default="data/ali_500",
|
||||
help="""This folder is expected to contain
|
||||
two files, train-960.pt and valid.pt, which
|
||||
contain framewise alignment information for
|
||||
the training set and validation set.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def get_params() -> AttributeDict:
|
||||
"""Return a dict containing training parameters.
|
||||
|
||||
All training related parameters that are not passed from the commandline
|
||||
are saved in the variable `params`.
|
||||
|
||||
Commandline options are merged into `params` after they are parsed, so
|
||||
you can also access them via `params`.
|
||||
|
||||
Explanation of options saved in `params`:
|
||||
|
||||
- exp_dir: It specifies the directory where all training related
|
||||
files, e.g., checkpoints, log, etc, are saved
|
||||
|
||||
- lang_dir: It contains language related input files such as
|
||||
"lexicon.txt"
|
||||
|
||||
- best_train_loss: Best training loss so far. It is used to select
|
||||
the model that has the lowest training loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_valid_loss: Best validation loss so far. It is used to select
|
||||
the model that has the lowest validation loss. It is
|
||||
updated during the training.
|
||||
|
||||
- best_train_epoch: It is the epoch that has the best training loss.
|
||||
|
||||
- best_valid_epoch: It is the epoch that has the best validation loss.
|
||||
|
||||
- batch_idx_train: Used for writing statistics to tensorboard. It
|
||||
contains the number of batches trained so far across
|
||||
epochs.
|
||||
|
||||
- log_interval: Print training loss if batch_idx % log_interval is 0
|
||||
|
||||
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
|
||||
|
||||
- valid_interval: Run validation if batch_idx % valid_interval is 0
|
||||
|
||||
- feature_dim: The model input dim. It has to match the one used
|
||||
in computing features.
|
||||
|
||||
- subsampling_factor: The subsampling factor for the model.
|
||||
|
||||
- use_feat_batchnorm: Whether to do batch normalization for the
|
||||
input features.
|
||||
|
||||
- attention_dim: Hidden dim for multi-head attention model.
|
||||
|
||||
- nhead: Number of heads in the multi-head attention model.
|
||||
|
||||
- num_decoder_layers: Number of decoder layers of the transformer decoder.
|
||||
|
||||
- weight_decay: The weight_decay for the optimizer.
|
||||
|
||||
- lr_factor: The lr_factor for Noam optimizer.
|
||||
|
||||
- warm_step: The warm_step for Noam optimizer.
|
||||
"""
|
||||
params = AttributeDict(
|
||||
{
|
||||
"exp_dir": Path("conformer_mmi/exp_500"),
|
||||
"lang_dir": Path("data/lang_bpe_500"),
|
||||
"best_train_loss": float("inf"),
|
||||
"best_valid_loss": float("inf"),
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 3000,
|
||||
# parameters for conformer
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4,
|
||||
"use_feat_batchnorm": True,
|
||||
"attention_dim": 512,
|
||||
"nhead": 8,
|
||||
# parameters for loss
|
||||
"beam_size": 6, # will change it to 8 after some batches (see code)
|
||||
"reduction": "sum",
|
||||
"use_double_scores": True,
|
||||
"att_rate": 0.0,
|
||||
"num_decoder_layers": 0,
|
||||
# "att_rate": 0.7,
|
||||
# "num_decoder_layers": 6,
|
||||
# parameters for Noam
|
||||
"weight_decay": 1e-6,
|
||||
"lr_factor": 5.0,
|
||||
"warm_step": 80000,
|
||||
"use_pruned_intersect": False,
|
||||
"den_scale": 1.0,
|
||||
# use alignments before this number of batches
|
||||
"use_ali_until": 13000,
|
||||
}
|
||||
)
|
||||
|
||||
return params
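For reference, the note in the docstring about merging commandline options can be seen end to end in a short, hypothetical sketch (the argument values below are arbitrary, and the data-module arguments added later in `main()` are omitted); it assumes icefall's `AttributeDict` exposes dict entries as attributes, as the docstring implies:

```
parser = get_parser()
args = parser.parse_args(["--world-size", "2", "--num-epochs", "30"])

params = get_params()
params.update(vars(args))  # commandline options are merged into params

# Both the built-in defaults and the parsed options are now attributes:
print(params.lang_dir)     # data/lang_bpe_500
print(params.world_size)   # 2
print(params.num_epochs)   # 30
```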
|
||||
|
||||
|
||||
def load_checkpoint_if_available(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
) -> Optional[Dict]:
|
||||
"""Load checkpoint from file.
|
||||
|
||||
If params.start_epoch is positive, it will load the checkpoint from
|
||||
`params.start_epoch - 1`. Otherwise, this function does nothing.
|
||||
|
||||
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
|
||||
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
|
||||
and `best_valid_loss` in `params`.
|
||||
|
||||
Args:
|
||||
params:
|
||||
The return value of :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
optimizer:
|
||||
The optimizer that we are using.
|
||||
scheduler:
|
||||
The learning rate scheduler we are using.
|
||||
Returns:
|
||||
Return the contents of the loaded checkpoint, or None if no checkpoint is loaded.
|
||||
"""
|
||||
if params.start_epoch <= 0:
|
||||
return
|
||||
|
||||
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
|
||||
saved_params = load_checkpoint(
|
||||
filename,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
keys = [
|
||||
"best_train_epoch",
|
||||
"best_valid_epoch",
|
||||
"batch_idx_train",
|
||||
"best_train_loss",
|
||||
"best_valid_loss",
|
||||
]
|
||||
for k in keys:
|
||||
params[k] = saved_params[k]
|
||||
|
||||
return saved_params
|
||||
|
||||
|
||||
def save_checkpoint(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
|
||||
rank: int = 0,
|
||||
) -> None:
|
||||
"""Save model, optimizer, scheduler and training stats to file.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The training model.
|
||||
"""
|
||||
if rank != 0:
|
||||
return
|
||||
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
|
||||
save_checkpoint_impl(
|
||||
filename=filename,
|
||||
model=model,
|
||||
params=params,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
if params.best_train_epoch == params.cur_epoch:
|
||||
best_train_filename = params.exp_dir / "best-train-loss.pt"
|
||||
copyfile(src=filename, dst=best_train_filename)
|
||||
|
||||
if params.best_valid_epoch == params.cur_epoch:
|
||||
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
|
||||
copyfile(src=filename, dst=best_valid_filename)
|
||||
|
||||
|
||||
def compute_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
batch: dict,
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
is_training: bool,
|
||||
ali: Optional[Dict[str, torch.Tensor]],
|
||||
):
|
||||
"""
|
||||
Compute LF-MMI loss given the model and its inputs.
|
||||
|
||||
Args:
|
||||
params:
|
||||
Parameters for training. See :func:`get_params`.
|
||||
model:
|
||||
The model for training. It is an instance of Conformer in our case.
|
||||
batch:
|
||||
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
||||
for the content in it.
|
||||
graph_compiler:
|
||||
It is used to build a decoding graph from a ctc topo and training
|
||||
transcript. The training transcript is contained in the given `batch`,
|
||||
while the ctc topo is built when this compiler is instantiated.
|
||||
is_training:
|
||||
True for training. False for validation. When it is True, this
|
||||
function enables autograd during computation; when it is False, it
|
||||
disables autograd.
|
||||
ali:
|
||||
Precomputed alignments.
|
||||
"""
|
||||
device = graph_compiler.device
|
||||
feature = batch["inputs"]
|
||||
# at entry, feature is (N, T, C)
|
||||
assert feature.ndim == 3
|
||||
feature = feature.to(device)
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
with torch.set_grad_enabled(is_training):
|
||||
nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
|
||||
# nnet_output is (N, T, C)
|
||||
|
||||
# NOTE: We need `encode_supervisions` to sort sequences with
|
||||
# different duration in decreasing order, required by
|
||||
# `k2.intersect_dense` called in `LFMMILoss.forward()`
|
||||
supervision_segments, texts = encode_supervisions(
|
||||
supervisions, subsampling_factor=params.subsampling_factor
|
||||
)
|
||||
|
||||
if ali is not None and params.batch_idx_train < params.use_ali_until:
|
||||
cut_ids = [cut.id for cut in supervisions["cut"]]
|
||||
|
||||
# As encode_supervisions reorders cuts, we need
|
||||
# also to reorder cut IDs here
|
||||
new2old = supervision_segments[:, 0].tolist()
|
||||
cut_ids = [cut_ids[i] for i in new2old]
|
||||
|
||||
# Check that new2old is just a permutation,
|
||||
# i.e., each cut contains only one utterance
|
||||
new2old.sort()
|
||||
assert new2old == torch.arange(len(new2old)).tolist()
|
||||
mask = lookup_alignments(
|
||||
cut_ids=cut_ids,
|
||||
alignments=ali,
|
||||
num_classes=nnet_output.shape[2],
|
||||
).to(nnet_output)
|
||||
|
||||
min_len = min(nnet_output.shape[1], mask.shape[1])
|
||||
ali_scale = 500.0 / (params.batch_idx_train + 500)
|
||||
|
||||
nnet_output = nnet_output.clone()
|
||||
nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :]
|
||||
|
||||
if (
|
||||
params.batch_idx_train > params.use_ali_until
|
||||
and params.beam_size < 8
|
||||
):
|
||||
logging.info("Change beam size to 8")
|
||||
params.beam_size = 8
|
||||
else:
|
||||
params.beam_size = 6
|
||||
|
||||
loss_fn = LFMMILoss(
|
||||
graph_compiler=graph_compiler,
|
||||
use_pruned_intersect=params.use_pruned_intersect,
|
||||
den_scale=params.den_scale,
|
||||
beam_size=params.beam_size,
|
||||
)
|
||||
|
||||
dense_fsa_vec = k2.DenseFsaVec(
|
||||
nnet_output,
|
||||
supervision_segments,
|
||||
allow_truncate=params.subsampling_factor - 1,
|
||||
)
|
||||
mmi_loss = loss_fn(dense_fsa_vec=dense_fsa_vec, texts=texts)
|
||||
|
||||
if params.att_rate != 0.0:
|
||||
token_ids = graph_compiler.texts_to_ids(texts)
|
||||
with torch.set_grad_enabled(is_training):
|
||||
if hasattr(model, "module"):
|
||||
att_loss = model.module.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
else:
|
||||
att_loss = model.decoder_forward(
|
||||
encoder_memory,
|
||||
memory_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=graph_compiler.sos_id,
|
||||
eos_id=graph_compiler.eos_id,
|
||||
)
|
||||
loss = (1.0 - params.att_rate) * mmi_loss + params.att_rate * att_loss
|
||||
else:
|
||||
loss = mmi_loss
|
||||
att_loss = torch.tensor([0])
|
||||
|
||||
# train_frames and valid_frames are used for printing.
|
||||
if is_training:
|
||||
params.train_frames = supervision_segments[:, 2].sum().item()
|
||||
else:
|
||||
params.valid_frames = supervision_segments[:, 2].sum().item()
|
||||
|
||||
assert loss.requires_grad == is_training
|
||||
|
||||
return loss, mmi_loss.detach(), att_loss.detach()
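The alignment branch above nudges the network output towards the pre-computed alignment with a weight that decays as training proceeds (ali_scale = 500 / (batch_idx_train + 500), roughly 1.0 at the start and about 0.04 by use_ali_until = 13000 batches). Below is a minimal, self-contained sketch of that single step; `boost_with_alignments` and the toy tensors are illustrative only, not part of the recipe:

```
import torch


def boost_with_alignments(nnet_output, mask, batch_idx_train):
    # Same decay as above: ~1.0 at the start of training, ~0.04 after 13k batches.
    ali_scale = 500.0 / (batch_idx_train + 500)
    min_len = min(nnet_output.shape[1], mask.shape[1])
    out = nnet_output.clone()
    out[:, :min_len, :] += ali_scale * mask[:, :min_len, :]
    return out


# Toy example: 1 utterance, 4 frames, 3 classes; a frame-level one-hot alignment.
log_probs = torch.randn(1, 4, 3).log_softmax(dim=-1)
ali_mask = torch.nn.functional.one_hot(torch.tensor([[0, 1, 1, 2]]), 3).float()
boosted = boost_with_alignments(log_probs, ali_mask, batch_idx_train=100)
```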
|
||||
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
ali: Optional[Dict[str, torch.Tensor]] = None,
|
||||
) -> None:
|
||||
"""Run the validation process. The validation loss
|
||||
is saved in `params.valid_loss`.
|
||||
"""
|
||||
model.eval()
|
||||
|
||||
tot_loss = 0.0
|
||||
tot_mmi_loss = 0.0
|
||||
tot_att_loss = 0.0
|
||||
tot_frames = 0.0
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
loss, mmi_loss, att_loss = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
graph_compiler=graph_compiler,
|
||||
is_training=False,
|
||||
ali=ali,
|
||||
)
|
||||
assert loss.requires_grad is False
|
||||
assert mmi_loss.requires_grad is False
|
||||
assert att_loss.requires_grad is False
|
||||
|
||||
loss_cpu = loss.detach().cpu().item()
|
||||
tot_loss += loss_cpu
|
||||
|
||||
tot_mmi_loss += mmi_loss.detach().cpu().item()
|
||||
tot_att_loss += att_loss.detach().cpu().item()
|
||||
|
||||
tot_frames += params.valid_frames
|
||||
|
||||
if world_size > 1:
|
||||
s = torch.tensor(
|
||||
[tot_loss, tot_mmi_loss, tot_att_loss, tot_frames],
|
||||
device=loss.device,
|
||||
)
|
||||
dist.all_reduce(s, op=dist.ReduceOp.SUM)
|
||||
s = s.cpu().tolist()
|
||||
tot_loss = s[0]
|
||||
tot_mmi_loss = s[1]
|
||||
tot_att_loss = s[2]
|
||||
tot_frames = s[3]
|
||||
|
||||
params.valid_loss = tot_loss / tot_frames
|
||||
params.valid_mmi_loss = tot_mmi_loss / tot_frames
|
||||
params.valid_att_loss = tot_att_loss / tot_frames
|
||||
|
||||
if params.valid_loss < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = params.valid_loss
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: nn.Module,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
train_dl: torch.utils.data.DataLoader,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
train_ali: Optional[Dict[str, torch.Tensor]],
|
||||
valid_ali: Optional[Dict[str, torch.Tensor]],
|
||||
tb_writer: Optional[SummaryWriter] = None,
|
||||
world_size: int = 1,
|
||||
) -> None:
|
||||
"""Train the model for one epoch.
|
||||
|
||||
The training loss, averaged over all frames, is saved in
|
||||
`params.train_loss`. It runs the validation process every
|
||||
`params.valid_interval` batches.
|
||||
|
||||
Args:
|
||||
params:
|
||||
It is returned by :func:`get_params`.
|
||||
model:
|
||||
The model for training.
|
||||
optimizer:
|
||||
The optimizer we are using.
|
||||
graph_compiler:
|
||||
It is used to convert transcripts to FSAs.
|
||||
train_dl:
|
||||
Dataloader for the training dataset.
|
||||
valid_dl:
|
||||
Dataloader for the validation dataset.
|
||||
train_ali:
|
||||
Precomputed alignments for the training set.
|
||||
valid_ali:
|
||||
Precomputed alignments for the validation set.
|
||||
tb_writer:
|
||||
Writer to write log messages to tensorboard.
|
||||
world_size:
|
||||
Number of nodes in DDP training. If it is 1, DDP is disabled.
|
||||
"""
|
||||
model.train()
|
||||
|
||||
tot_loss = 0.0 # sum of losses over all batches
|
||||
tot_mmi_loss = 0.0
|
||||
tot_att_loss = 0.0
|
||||
|
||||
tot_frames = 0.0 # sum of frames over all batches
|
||||
params.tot_loss = 0.0
|
||||
params.tot_frames = 0.0
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
|
||||
loss, mmi_loss, att_loss = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
batch=batch,
|
||||
graph_compiler=graph_compiler,
|
||||
is_training=True,
|
||||
ali=train_ali,
|
||||
)
|
||||
|
||||
# NOTE: We use reduction="sum", so the loss is summed over utterances
|
||||
# in the batch; no normalization is applied to it so far.
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
clip_grad_norm_(model.parameters(), 5.0, 2.0)
|
||||
optimizer.step()
|
||||
|
||||
loss_cpu = loss.detach().cpu().item()
|
||||
mmi_loss_cpu = mmi_loss.detach().cpu().item()
|
||||
att_loss_cpu = att_loss.detach().cpu().item()
|
||||
|
||||
tot_frames += params.train_frames
|
||||
tot_loss += loss_cpu
|
||||
tot_mmi_loss += mmi_loss_cpu
|
||||
tot_att_loss += att_loss_cpu
|
||||
|
||||
params.tot_frames += params.train_frames
|
||||
params.tot_loss += loss_cpu
|
||||
|
||||
tot_avg_loss = tot_loss / tot_frames
|
||||
tot_avg_mmi_loss = tot_mmi_loss / tot_frames
|
||||
tot_avg_att_loss = tot_att_loss / tot_frames
|
||||
|
||||
if batch_idx % params.log_interval == 0:
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
|
||||
f"batch avg mmi loss {mmi_loss_cpu/params.train_frames:.4f}, "
|
||||
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
|
||||
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
|
||||
f"total avg mmiloss: {tot_avg_mmi_loss:.4f}, "
|
||||
f"total avg att loss: {tot_avg_att_loss:.4f}, "
|
||||
f"total avg loss: {tot_avg_loss:.4f}, "
|
||||
f"batch size: {batch_size}"
|
||||
)
|
||||
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/current_mmi_loss",
|
||||
mmi_loss_cpu / params.train_frames,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/current_att_loss",
|
||||
att_loss_cpu / params.train_frames,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/current_loss",
|
||||
loss_cpu / params.train_frames,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/tot_avg_mmi_loss",
|
||||
tot_avg_mmi_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
|
||||
tb_writer.add_scalar(
|
||||
"train/tot_avg_att_loss",
|
||||
tot_avg_att_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/tot_avg_loss",
|
||||
tot_avg_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
if batch_idx > 0 and batch_idx % params.reset_interval == 0:
|
||||
tot_loss = 0.0 # sum of losses over all batches
|
||||
tot_mmi_loss = 0.0
|
||||
tot_att_loss = 0.0
|
||||
|
||||
tot_frames = 0.0 # sum of frames over all batches
|
||||
|
||||
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
|
||||
compute_validation_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
graph_compiler=graph_compiler,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
ali=valid_ali,
|
||||
)
|
||||
model.train()
|
||||
logging.info(
|
||||
f"Epoch {params.cur_epoch}, "
|
||||
f"valid mmi loss {params.valid_mmi_loss:.4f},"
|
||||
f"valid att loss {params.valid_att_loss:.4f},"
|
||||
f"valid loss {params.valid_loss:.4f},"
|
||||
f" best valid loss: {params.best_valid_loss:.4f} "
|
||||
f"best valid epoch: {params.best_valid_epoch}"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/valid_mmi_loss",
|
||||
params.valid_mmi_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/valid_att_loss",
|
||||
params.valid_att_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
tb_writer.add_scalar(
|
||||
"train/valid_loss",
|
||||
params.valid_loss,
|
||||
params.batch_idx_train,
|
||||
)
|
||||
|
||||
params.train_loss = params.tot_loss / params.tot_frames
|
||||
|
||||
if params.train_loss < params.best_train_loss:
|
||||
params.best_train_epoch = params.cur_epoch
|
||||
params.best_train_loss = params.train_loss
|
||||
|
||||
|
||||
def run(rank, world_size, args):
|
||||
"""
|
||||
Args:
|
||||
rank:
|
||||
It is a value between 0 and `world_size-1`, which is
|
||||
passed automatically by `mp.spawn()` in :func:`main`.
|
||||
The node with rank 0 is responsible for saving checkpoint.
|
||||
world_size:
|
||||
Number of GPUs for DDP training.
|
||||
args:
|
||||
The return value of get_parser().parse_args()
|
||||
"""
|
||||
params = get_params()
|
||||
params.update(vars(args))
|
||||
|
||||
fix_random_seed(42)
|
||||
if world_size > 1:
|
||||
setup_dist(rank, world_size, params.master_port)
|
||||
|
||||
setup_logger(f"{params.exp_dir}/log/log-train")
|
||||
logging.info("Training started")
|
||||
logging.info(params)
|
||||
|
||||
if args.tensorboard and rank == 0:
|
||||
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
|
||||
else:
|
||||
tb_writer = None
|
||||
|
||||
lexicon = Lexicon(params.lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", rank)
|
||||
|
||||
graph_compiler = MmiTrainingGraphCompiler(
|
||||
params.lang_dir,
|
||||
uniq_filename="lexicon.txt",
|
||||
device=device,
|
||||
oov="<UNK>",
|
||||
sos_id=1,
|
||||
eos_id=1,
|
||||
)
|
||||
|
||||
logging.info("About to create model")
|
||||
if params.att_rate == 0:
|
||||
assert params.num_decoder_layers == 0, f"{params.num_decoder_layers}"
|
||||
|
||||
model = Conformer(
|
||||
num_features=params.feature_dim,
|
||||
nhead=params.nhead,
|
||||
d_model=params.attention_dim,
|
||||
num_classes=num_classes,
|
||||
subsampling_factor=params.subsampling_factor,
|
||||
num_decoder_layers=params.num_decoder_layers,
|
||||
vgg_frontend=False,
|
||||
use_feat_batchnorm=params.use_feat_batchnorm,
|
||||
)
|
||||
|
||||
checkpoints = load_checkpoint_if_available(params=params, model=model)
|
||||
|
||||
model.to(device)
|
||||
if world_size > 1:
|
||||
model = DDP(model, device_ids=[rank])
|
||||
|
||||
optimizer = Noam(
|
||||
model.parameters(),
|
||||
model_size=params.attention_dim,
|
||||
factor=params.lr_factor,
|
||||
warm_step=params.warm_step,
|
||||
weight_decay=params.weight_decay,
|
||||
)
|
||||
|
||||
if checkpoints:
|
||||
optimizer.load_state_dict(checkpoints["optimizer"])
|
||||
|
||||
train_960_ali_filename = Path(params.ali_dir) / "train-960.pt"
|
||||
if (
|
||||
params.batch_idx_train < params.use_ali_until
|
||||
and train_960_ali_filename.is_file()
|
||||
):
|
||||
logging.info("Use pre-computed alignments")
|
||||
subsampling_factor, train_ali = load_alignments(train_960_ali_filename)
|
||||
assert subsampling_factor == params.subsampling_factor
|
||||
assert len(train_ali) == 843723, f"{len(train_ali)} vs 843723"
|
||||
|
||||
valid_ali_filename = Path(params.ali_dir) / "valid.pt"
|
||||
subsampling_factor, valid_ali = load_alignments(valid_ali_filename)
|
||||
assert subsampling_factor == params.subsampling_factor
|
||||
|
||||
train_ali = convert_alignments_to_tensor(train_ali, device=device)
|
||||
valid_ali = convert_alignments_to_tensor(valid_ali, device=device)
|
||||
else:
|
||||
logging.info("Not using alignments")
|
||||
train_ali = None
|
||||
valid_ali = None
|
||||
|
||||
librispeech = LibriSpeechAsrDataModule(args)
|
||||
train_dl = librispeech.train_dataloaders()
|
||||
valid_dl = librispeech.valid_dataloaders()
|
||||
|
||||
for epoch in range(params.start_epoch, params.num_epochs):
|
||||
train_dl.sampler.set_epoch(epoch)
|
||||
if (
|
||||
params.batch_idx_train >= params.use_ali_until
|
||||
and train_ali is not None
|
||||
):
|
||||
# Delete the alignments to save memory
|
||||
train_ali = None
|
||||
valid_ali = None
|
||||
|
||||
cur_lr = optimizer._rate
|
||||
if tb_writer is not None:
|
||||
tb_writer.add_scalar(
|
||||
"train/learning_rate", cur_lr, params.batch_idx_train
|
||||
)
|
||||
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
|
||||
|
||||
if rank == 0:
|
||||
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
|
||||
|
||||
params.cur_epoch = epoch
|
||||
|
||||
train_one_epoch(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
graph_compiler=graph_compiler,
|
||||
train_dl=train_dl,
|
||||
valid_dl=valid_dl,
|
||||
train_ali=train_ali,
|
||||
valid_ali=valid_ali,
|
||||
tb_writer=tb_writer,
|
||||
world_size=world_size,
|
||||
)
|
||||
|
||||
save_checkpoint(
|
||||
params=params,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
rank=rank,
|
||||
)
|
||||
|
||||
logging.info("Done!")
|
||||
|
||||
if world_size > 1:
|
||||
torch.distributed.barrier()
|
||||
cleanup_dist()
|
||||
|
||||
|
||||
def main():
|
||||
parser = get_parser()
|
||||
LibriSpeechAsrDataModule.add_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
world_size = args.world_size
|
||||
assert world_size >= 1
|
||||
if world_size > 1:
|
||||
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
|
||||
else:
|
||||
run(rank=0, world_size=1, args=args)
|
||||
|
||||
|
||||
torch.set_num_threads(1)
|
||||
torch.set_num_interop_threads(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
998
egs/librispeech/ASR/conformer_mmi/transformer.py
Normal file
@@ -0,0 +1,998 @@
|
||||
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
# Note: TorchScript requires Dict/List/etc. to be fully typed.
|
||||
Supervisions = Dict[str, torch.Tensor]
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_features: int,
|
||||
num_classes: int,
|
||||
subsampling_factor: int = 4,
|
||||
d_model: int = 256,
|
||||
nhead: int = 4,
|
||||
dim_feedforward: int = 2048,
|
||||
num_encoder_layers: int = 12,
|
||||
num_decoder_layers: int = 6,
|
||||
dropout: float = 0.1,
|
||||
normalize_before: bool = True,
|
||||
vgg_frontend: bool = False,
|
||||
use_feat_batchnorm: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
num_features:
|
||||
The input dimension of the model.
|
||||
num_classes:
|
||||
The output dimension of the model.
|
||||
subsampling_factor:
|
||||
Number of output frames is num_in_frames // subsampling_factor.
|
||||
Currently, subsampling_factor MUST be 4.
|
||||
d_model:
|
||||
Attention dimension.
|
||||
nhead:
|
||||
Number of heads in multi-head attention.
|
||||
Must satisfy d_model % nhead == 0.
|
||||
dim_feedforward:
|
||||
The output dimension of the feedforward layers in encoder/decoder.
|
||||
num_encoder_layers:
|
||||
Number of encoder layers.
|
||||
num_decoder_layers:
|
||||
Number of decoder layers.
|
||||
dropout:
|
||||
Dropout in encoder/decoder.
|
||||
normalize_before:
|
||||
If True, use pre-layer norm; False to use post-layer norm.
|
||||
vgg_frontend:
|
||||
True to use vgg style frontend for subsampling.
|
||||
use_feat_batchnorm:
|
||||
True to use batchnorm for the input layer.
|
||||
"""
|
||||
super().__init__()
|
||||
self.use_feat_batchnorm = use_feat_batchnorm
|
||||
if use_feat_batchnorm:
|
||||
self.feat_batchnorm = nn.BatchNorm1d(num_features)
|
||||
|
||||
self.num_features = num_features
|
||||
self.num_classes = num_classes
|
||||
self.subsampling_factor = subsampling_factor
|
||||
if subsampling_factor != 4:
|
||||
raise NotImplementedError("Support only 'subsampling_factor=4'.")
|
||||
|
||||
# self.encoder_embed converts the input of shape (N, T, num_classes)
|
||||
# to the shape (N, T//subsampling_factor, d_model).
|
||||
# That is, it does two things simultaneously:
|
||||
# (1) subsampling: T -> T//subsampling_factor
|
||||
# (2) embedding: num_classes -> d_model
|
||||
if vgg_frontend:
|
||||
self.encoder_embed = VggSubsampling(num_features, d_model)
|
||||
else:
|
||||
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
|
||||
|
||||
self.encoder_pos = PositionalEncoding(d_model, dropout)
|
||||
|
||||
encoder_layer = TransformerEncoderLayer(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
dim_feedforward=dim_feedforward,
|
||||
dropout=dropout,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
if normalize_before:
|
||||
encoder_norm = nn.LayerNorm(d_model)
|
||||
else:
|
||||
encoder_norm = None
|
||||
|
||||
self.encoder = nn.TransformerEncoder(
|
||||
encoder_layer=encoder_layer,
|
||||
num_layers=num_encoder_layers,
|
||||
norm=encoder_norm,
|
||||
)
|
||||
|
||||
# TODO(fangjun): remove dropout
|
||||
self.encoder_output_layer = nn.Sequential(
|
||||
nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
|
||||
)
|
||||
|
||||
if num_decoder_layers > 0:
|
||||
self.decoder_num_class = (
|
||||
self.num_classes
|
||||
) # bpe model already has sos/eos symbol
|
||||
|
||||
self.decoder_embed = nn.Embedding(
|
||||
num_embeddings=self.decoder_num_class, embedding_dim=d_model
|
||||
)
|
||||
self.decoder_pos = PositionalEncoding(d_model, dropout)
|
||||
|
||||
decoder_layer = TransformerDecoderLayer(
|
||||
d_model=d_model,
|
||||
nhead=nhead,
|
||||
dim_feedforward=dim_feedforward,
|
||||
dropout=dropout,
|
||||
normalize_before=normalize_before,
|
||||
)
|
||||
|
||||
if normalize_before:
|
||||
decoder_norm = nn.LayerNorm(d_model)
|
||||
else:
|
||||
decoder_norm = None
|
||||
|
||||
self.decoder = nn.TransformerDecoder(
|
||||
decoder_layer=decoder_layer,
|
||||
num_layers=num_decoder_layers,
|
||||
norm=decoder_norm,
|
||||
)
|
||||
|
||||
self.decoder_output_layer = torch.nn.Linear(
|
||||
d_model, self.decoder_num_class
|
||||
)
|
||||
|
||||
self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
|
||||
else:
|
||||
self.decoder_criterion = None
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, supervision: Optional[Supervisions] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The input tensor. Its shape is (N, T, C).
|
||||
supervision:
|
||||
Supervision in lhotse format.
|
||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
||||
(CAUTION: It contains length information, i.e., start and number of
|
||||
frames, before subsampling)
|
||||
|
||||
Returns:
|
||||
Return a tuple containing 3 tensors:
|
||||
- CTC output for ctc decoding. Its shape is (N, T, C)
|
||||
- Encoder output with shape (T, N, C). It can be used as key and
|
||||
value for the decoder.
|
||||
- Encoder output padding mask. It can be used as
|
||||
memory_key_padding_mask for the decoder. Its shape is (N, T).
|
||||
It is None if `supervision` is None.
|
||||
"""
|
||||
if self.use_feat_batchnorm:
|
||||
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
|
||||
x = self.feat_batchnorm(x)
|
||||
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
|
||||
encoder_memory, memory_key_padding_mask = self.run_encoder(
|
||||
x, supervision
|
||||
)
|
||||
x = self.ctc_output(encoder_memory)
|
||||
return x, encoder_memory, memory_key_padding_mask
|
||||
|
||||
def run_encoder(
|
||||
self, x: torch.Tensor, supervisions: Optional[Supervisions] = None
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""Run the transformer encoder.
|
||||
|
||||
Args:
|
||||
x:
|
||||
The model input. Its shape is (N, T, C).
|
||||
supervisions:
|
||||
Supervision in lhotse format.
|
||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
||||
CAUTION: It contains length information, i.e., start and number of
|
||||
frames, before subsampling
|
||||
It is read directly from the batch, without any sorting. It is used
|
||||
to compute the encoder padding mask, which is used as memory key
|
||||
padding mask for the decoder.
|
||||
Returns:
|
||||
Return a tuple with two tensors:
|
||||
- The encoder output, with shape (T, N, C)
|
||||
- encoder padding mask, with shape (N, T).
|
||||
The mask is None if `supervisions` is None.
|
||||
It is used as memory key padding mask in the decoder.
|
||||
"""
|
||||
x = self.encoder_embed(x)
|
||||
x = self.encoder_pos(x)
|
||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
mask = encoder_padding_mask(x.size(0), supervisions)
|
||||
mask = mask.to(x.device) if mask is not None else None
|
||||
x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C)
|
||||
|
||||
return x, mask
|
||||
|
||||
def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
The output tensor from the transformer encoder.
|
||||
Its shape is (T, N, C)
|
||||
|
||||
Returns:
|
||||
Return a tensor that can be used for CTC decoding.
|
||||
Its shape is (N, T, C)
|
||||
"""
|
||||
x = self.encoder_output_layer(x)
|
||||
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||
x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
|
||||
return x
|
||||
|
||||
def decoder_forward(
|
||||
self,
|
||||
memory: torch.Tensor,
|
||||
memory_key_padding_mask: torch.Tensor,
|
||||
token_ids: List[List[int]],
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
memory:
|
||||
It's the output of the encoder with shape (T, N, C)
|
||||
memory_key_padding_mask:
|
||||
The padding mask from the encoder.
|
||||
token_ids:
|
||||
A list-of-list of IDs. Each sublist contains IDs for an utterance.
|
||||
The IDs can be either phone IDs or word piece IDs.
|
||||
sos_id:
|
||||
sos token id
|
||||
eos_id:
|
||||
eos token id
|
||||
|
||||
Returns:
|
||||
A scalar, the **sum** of label smoothing loss over utterances
|
||||
in the batch without any normalization.
|
||||
"""
|
||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||
ys_in = [torch.tensor(y) for y in ys_in]
|
||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
|
||||
|
||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||
ys_out = [torch.tensor(y) for y in ys_out]
|
||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
|
||||
|
||||
device = memory.device
|
||||
ys_in_pad = ys_in_pad.to(device)
|
||||
ys_out_pad = ys_out_pad.to(device)
|
||||
|
||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
|
||||
device
|
||||
)
|
||||
|
||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||
# TODO: Use length information to create the decoder padding mask
|
||||
# We set the first column to False since the first column in ys_in_pad
|
||||
# contains sos_id, which is the same as eos_id in our current setting.
|
||||
tgt_key_padding_mask[:, 0] = False
|
||||
|
||||
tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
|
||||
tgt = self.decoder_pos(tgt)
|
||||
tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||
pred_pad = self.decoder(
|
||||
tgt=tgt,
|
||||
memory=memory,
|
||||
tgt_mask=tgt_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
) # (T, N, C)
|
||||
pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
|
||||
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
|
||||
|
||||
decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
|
||||
|
||||
return decoder_loss
|
||||
|
||||
def decoder_nll(
|
||||
self,
|
||||
memory: torch.Tensor,
|
||||
memory_key_padding_mask: torch.Tensor,
|
||||
token_ids: List[List[int]],
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
memory:
|
||||
It's the output of the encoder with shape (T, N, C)
|
||||
memory_key_padding_mask:
|
||||
The padding mask from the encoder.
|
||||
token_ids:
|
||||
A list-of-list of IDs (e.g., word piece IDs).
|
||||
Each sublist represents an utterance.
|
||||
sos_id:
|
||||
The token ID for SOS.
|
||||
eos_id:
|
||||
The token ID for EOS.
|
||||
Returns:
|
||||
A 2-D tensor of shape (len(token_ids), max_token_length)
|
||||
representing the cross entropy loss (i.e., negative log-likelihood).
|
||||
"""
|
||||
# The common part between this function and decoder_forward could be
|
||||
# extracted as a separate function.
|
||||
|
||||
ys_in = add_sos(token_ids, sos_id=sos_id)
|
||||
ys_in = [torch.tensor(y) for y in ys_in]
|
||||
ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id)
|
||||
|
||||
ys_out = add_eos(token_ids, eos_id=eos_id)
|
||||
ys_out = [torch.tensor(y) for y in ys_out]
|
||||
ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1)
|
||||
|
||||
device = memory.device
|
||||
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
|
||||
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
|
||||
|
||||
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
|
||||
device
|
||||
)
|
||||
|
||||
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
|
||||
# TODO: Use length information to create the decoder padding mask
|
||||
# We set the first column to False since the first column in ys_in_pad
|
||||
# contains sos_id, which is the same as eos_id in our current setting.
|
||||
tgt_key_padding_mask[:, 0] = False
|
||||
|
||||
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
|
||||
tgt = self.decoder_pos(tgt)
|
||||
tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
|
||||
pred_pad = self.decoder(
|
||||
tgt=tgt,
|
||||
memory=memory,
|
||||
tgt_mask=tgt_mask,
|
||||
tgt_key_padding_mask=tgt_key_padding_mask,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
) # (T, B, F)
|
||||
pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F)
|
||||
pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F)
|
||||
# nll: negative log-likelihood
|
||||
nll = torch.nn.functional.cross_entropy(
|
||||
pred_pad.view(-1, self.decoder_num_class),
|
||||
ys_out_pad.view(-1),
|
||||
ignore_index=-1,
|
||||
reduction="none",
|
||||
)
|
||||
|
||||
nll = nll.view(pred_pad.shape[0], -1)
|
||||
|
||||
return nll
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
"""
|
||||
Modified from torch.nn.TransformerEncoderLayer.
|
||||
Add support of normalize_before,
|
||||
i.e., use layer_norm before the first block.
|
||||
|
||||
Args:
|
||||
d_model:
|
||||
the number of expected features in the input (required).
|
||||
nhead:
|
||||
the number of heads in the multiheadattention models (required).
|
||||
dim_feedforward:
|
||||
the dimension of the feedforward network model (default=2048).
|
||||
dropout:
|
||||
the dropout value (default=0.1).
|
||||
activation:
|
||||
the activation function of intermediate layer, relu or
|
||||
gelu (default=relu).
|
||||
normalize_before:
|
||||
whether to use layer_norm before the first block.
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> out = encoder_layer(src)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
nhead: int,
|
||||
dim_feedforward: int = 2048,
|
||||
dropout: float = 0.1,
|
||||
activation: str = "relu",
|
||||
normalize_before: bool = True,
|
||||
) -> None:
|
||||
super(TransformerEncoderLayer, self).__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def __setstate__(self, state):
|
||||
if "activation" not in state:
|
||||
state["activation"] = nn.functional.relu
|
||||
super(TransformerEncoderLayer, self).__setstate__(state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
src: torch.Tensor,
|
||||
src_mask: Optional[torch.Tensor] = None,
|
||||
src_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Pass the input through the encoder layer.
|
||||
|
||||
Args:
|
||||
src: the sequence to the encoder layer (required).
|
||||
src_mask: the mask for the src sequence (optional).
|
||||
src_key_padding_mask: the mask for the src keys per batch (optional)
|
||||
|
||||
Shape:
|
||||
src: (S, N, E).
|
||||
src_mask: (S, S).
|
||||
src_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length,
|
||||
N is the batch size, E is the feature number
|
||||
"""
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm1(src)
|
||||
src2 = self.self_attn(
|
||||
src,
|
||||
src,
|
||||
src,
|
||||
attn_mask=src_mask,
|
||||
key_padding_mask=src_key_padding_mask,
|
||||
)[0]
|
||||
src = residual + self.dropout1(src2)
|
||||
if not self.normalize_before:
|
||||
src = self.norm1(src)
|
||||
|
||||
residual = src
|
||||
if self.normalize_before:
|
||||
src = self.norm2(src)
|
||||
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
||||
src = residual + self.dropout2(src2)
|
||||
if not self.normalize_before:
|
||||
src = self.norm2(src)
|
||||
return src
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
|
||||
"""
|
||||
Modified from torch.nn.TransformerDecoderLayer.
|
||||
Add support of normalize_before,
|
||||
i.e., use layer_norm before the first block.
|
||||
|
||||
Args:
|
||||
d_model:
|
||||
the number of expected features in the input (required).
|
||||
nhead:
|
||||
the number of heads in the multiheadattention models (required).
|
||||
dim_feedforward:
|
||||
the dimension of the feedforward network model (default=2048).
|
||||
dropout:
|
||||
the dropout value (default=0.1).
|
||||
activation:
|
||||
the activation function of intermediate layer, relu or
|
||||
gelu (default=relu).
|
||||
|
||||
Examples::
|
||||
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
|
||||
>>> memory = torch.rand(10, 32, 512)
|
||||
>>> tgt = torch.rand(20, 32, 512)
|
||||
>>> out = decoder_layer(tgt, memory)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
nhead: int,
|
||||
dim_feedforward: int = 2048,
|
||||
dropout: float = 0.1,
|
||||
activation: str = "relu",
|
||||
normalize_before: bool = True,
|
||||
) -> None:
|
||||
super(TransformerDecoderLayer, self).__init__()
|
||||
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
|
||||
self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
self.dropout3 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def __setstate__(self, state):
|
||||
if "activation" not in state:
|
||||
state["activation"] = nn.functional.relu
|
||||
super(TransformerDecoderLayer, self).__setstate__(state)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
tgt_mask: Optional[torch.Tensor] = None,
|
||||
memory_mask: Optional[torch.Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
memory_key_padding_mask: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Pass the inputs (and mask) through the decoder layer.
|
||||
|
||||
Args:
|
||||
tgt:
|
||||
the sequence to the decoder layer (required).
|
||||
memory:
|
||||
the sequence from the last layer of the encoder (required).
|
||||
tgt_mask:
|
||||
the mask for the tgt sequence (optional).
|
||||
memory_mask:
|
||||
the mask for the memory sequence (optional).
|
||||
tgt_key_padding_mask:
|
||||
the mask for the tgt keys per batch (optional).
|
||||
memory_key_padding_mask:
|
||||
the mask for the memory keys per batch (optional).
|
||||
|
||||
Shape:
|
||||
tgt: (T, N, E).
|
||||
memory: (S, N, E).
|
||||
tgt_mask: (T, T).
|
||||
memory_mask: (T, S).
|
||||
tgt_key_padding_mask: (N, T).
|
||||
memory_key_padding_mask: (N, S).
|
||||
S is the source sequence length, T is the target sequence length,
|
||||
N is the batch size, E is the feature number
|
||||
"""
|
||||
residual = tgt
|
||||
if self.normalize_before:
|
||||
tgt = self.norm1(tgt)
|
||||
tgt2 = self.self_attn(
|
||||
tgt,
|
||||
tgt,
|
||||
tgt,
|
||||
attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask,
|
||||
)[0]
|
||||
tgt = residual + self.dropout1(tgt2)
|
||||
if not self.normalize_before:
|
||||
tgt = self.norm1(tgt)
|
||||
|
||||
residual = tgt
|
||||
if self.normalize_before:
|
||||
tgt = self.norm2(tgt)
|
||||
tgt2 = self.src_attn(
|
||||
tgt,
|
||||
memory,
|
||||
memory,
|
||||
attn_mask=memory_mask,
|
||||
key_padding_mask=memory_key_padding_mask,
|
||||
)[0]
|
||||
tgt = residual + self.dropout2(tgt2)
|
||||
if not self.normalize_before:
|
||||
tgt = self.norm2(tgt)
|
||||
|
||||
residual = tgt
|
||||
if self.normalize_before:
|
||||
tgt = self.norm3(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = residual + self.dropout3(tgt2)
|
||||
if not self.normalize_before:
|
||||
tgt = self.norm3(tgt)
|
||||
return tgt
|
||||
|
||||
|
||||
def _get_activation_fn(activation: str):
|
||||
if activation == "relu":
|
||||
return nn.functional.relu
|
||||
elif activation == "gelu":
|
||||
return nn.functional.gelu
|
||||
|
||||
raise RuntimeError(
|
||||
"activation should be relu/gelu, not {}".format(activation)
|
||||
)
|
||||
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
"""This class implements the positional encoding
|
||||
proposed in the following paper:
|
||||
|
||||
- Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
|
||||
|
||||
PE(pos, 2i) = sin(pos / (10000^(2i/d_model)))
|
||||
PE(pos, 2i+1) = cos(pos / (10000^(2i/d_model)))
|
||||
|
||||
Note::
|
||||
|
||||
1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
|
||||
= exp(-1 * 2i / d_model * log(10000))
|
||||
= exp(2i * -(log(10000) / d_model))
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, dropout: float = 0.1) -> None:
|
||||
"""
|
||||
Args:
|
||||
d_model:
|
||||
Embedding dimension.
|
||||
dropout:
|
||||
Dropout probability to be applied to the output of this module.
|
||||
"""
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
self.pe = None
|
||||
|
||||
def extend_pe(self, x: torch.Tensor) -> None:
|
||||
"""Extend the time t in the positional encoding if required.
|
||||
|
||||
The shape of `self.pe` is (1, T1, d_model). The shape of the input x
|
||||
is (N, T, d_model). If T > T1, then we change the shape of self.pe
|
||||
to (1, T, d_model). Otherwise, nothing is done.
|
||||
|
||||
Args:
|
||||
x:
|
||||
It is a tensor of shape (N, T, C).
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
if self.pe is not None:
|
||||
if self.pe.size(1) >= x.size(1):
|
||||
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
# Now pe is of shape (1, T, d_model), where T is x.size(1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Add positional encoding.
|
||||
|
||||
Args:
|
||||
x:
|
||||
Its shape is (N, T, C)
|
||||
|
||||
Returns:
|
||||
Return a tensor of shape (N, T, C)
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale + self.pe[:, : x.size(1), :]
|
||||
return self.dropout(x)
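A small usage sketch of the module above (random input, shapes only):

```
import torch

pos_enc = PositionalEncoding(d_model=256, dropout=0.1)
x = torch.randn(8, 100, 256)  # (N, T, C)
y = pos_enc(x)                # scaled by sqrt(d_model), encoding added, dropout applied
assert y.shape == (8, 100, 256)
assert pos_enc.pe.shape == (1, 100, 256)  # the cached table grows lazily with T
```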
|
||||
|
||||
|
||||
class Noam(object):
|
||||
"""
|
||||
Implements Noam optimizer.
|
||||
|
||||
Proposed in
|
||||
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
|
||||
|
||||
Modified from
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
|
||||
|
||||
Args:
|
||||
params:
|
||||
iterable of parameters to optimize or dicts defining parameter groups
|
||||
model_size:
|
||||
attention dimension of the transformer model
|
||||
factor:
|
||||
learning rate factor
|
||||
warm_step:
|
||||
warmup steps
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
model_size: int = 256,
|
||||
factor: float = 10.0,
|
||||
warm_step: int = 25000,
|
||||
weight_decay=0,
|
||||
) -> None:
|
||||
"""Construct an Noam object."""
|
||||
self.optimizer = torch.optim.Adam(
|
||||
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
|
||||
)
|
||||
self._step = 0
|
||||
self.warmup = warm_step
|
||||
self.factor = factor
|
||||
self.model_size = model_size
|
||||
self._rate = 0
|
||||
|
||||
@property
|
||||
def param_groups(self):
|
||||
"""Return param_groups."""
|
||||
return self.optimizer.param_groups
|
||||
|
||||
def step(self):
|
||||
"""Update parameters and rate."""
|
||||
self._step += 1
|
||||
rate = self.rate()
|
||||
for p in self.optimizer.param_groups:
|
||||
p["lr"] = rate
|
||||
self._rate = rate
|
||||
self.optimizer.step()
|
||||
|
||||
def rate(self, step=None):
|
||||
"""Implement `lrate` above."""
|
||||
if step is None:
|
||||
step = self._step
|
||||
return (
|
||||
self.factor
|
||||
* self.model_size ** (-0.5)
|
||||
* min(step ** (-0.5), step * self.warmup ** (-1.5))
|
||||
)
|
||||
|
||||
def zero_grad(self):
|
||||
"""Reset gradient."""
|
||||
self.optimizer.zero_grad()
|
||||
|
||||
def state_dict(self):
|
||||
"""Return state_dict."""
|
||||
return {
|
||||
"_step": self._step,
|
||||
"warmup": self.warmup,
|
||||
"factor": self.factor,
|
||||
"model_size": self.model_size,
|
||||
"_rate": self._rate,
|
||||
"optimizer": self.optimizer.state_dict(),
|
||||
}
|
||||
|
||||
def load_state_dict(self, state_dict):
|
||||
"""Load state_dict."""
|
||||
for key, value in state_dict.items():
|
||||
if key == "optimizer":
|
||||
self.optimizer.load_state_dict(state_dict["optimizer"])
|
||||
else:
|
||||
setattr(self, key, value)
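The schedule implemented by `rate()` grows roughly linearly for the first `warm_step` steps and then decays as `step ** -0.5`, peaking at `step == warm_step`. Below is a standalone re-statement of the formula with the defaults used in this recipe (model_size=512, factor=5.0, warm_step=80000); the helper is purely illustrative:

```
def noam_lr(step, model_size=512, factor=5.0, warm_step=80000):
    # Same formula as Noam.rate(), written out for quick inspection.
    return factor * model_size ** (-0.5) * min(
        step ** (-0.5), step * warm_step ** (-1.5)
    )


for s in (1000, 40000, 80000, 160000):
    print(s, noam_lr(s))
# The peak is at step 80000 (about 7.8e-4); afterwards the rate decays ~ step^-0.5.
```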
|
||||
|
||||
|
||||
class LabelSmoothingLoss(nn.Module):
|
||||
"""
|
||||
Label-smoothing loss. KL-divergence between
|
||||
q_{smoothed ground truth prob.}(w)
|
||||
and p_{prob. computed by model}(w) is minimized.
|
||||
Modified from
|
||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa
|
||||
|
||||
Args:
|
||||
size: the number of classes
|
||||
padding_idx: ignored class id
|
||||
smoothing: smoothing rate (0.0 means the conventional CE)
|
||||
normalize_length: normalize loss by sequence length if True
|
||||
criterion: loss function to be smoothed
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
padding_idx: int = -1,
|
||||
smoothing: float = 0.1,
|
||||
normalize_length: bool = False,
|
||||
criterion: nn.Module = nn.KLDivLoss(reduction="none"),
|
||||
) -> None:
|
||||
"""Construct an LabelSmoothingLoss object."""
|
||||
super(LabelSmoothingLoss, self).__init__()
|
||||
self.criterion = criterion
|
||||
self.padding_idx = padding_idx
|
||||
assert 0.0 < smoothing <= 1.0
|
||||
self.confidence = 1.0 - smoothing
|
||||
self.smoothing = smoothing
|
||||
self.size = size
|
||||
self.true_dist = None
|
||||
self.normalize_length = normalize_length
|
||||
|
||||
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Compute loss between x and target.
|
||||
|
||||
Args:
|
||||
x:
|
||||
prediction of dimension
|
||||
(batch_size, input_length, number_of_classes).
|
||||
target:
|
||||
target masked with self.padding_idx of
|
||||
dimension (batch_size, input_length).
|
||||
|
||||
Returns:
|
||||
A scalar tensor containing the loss without normalization.
|
||||
"""
|
||||
assert x.size(2) == self.size
|
||||
# batch_size = x.size(0)
|
||||
x = x.view(-1, self.size)
|
||||
target = target.view(-1)
|
||||
with torch.no_grad():
|
||||
true_dist = x.clone()
|
||||
true_dist.fill_(self.smoothing / (self.size - 1))
|
||||
ignore = target == self.padding_idx # (B,)
|
||||
total = len(target) - ignore.sum().item()
|
||||
target = target.masked_fill(ignore, 0) # avoid -1 index
|
||||
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
|
||||
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
|
||||
# denom = total if self.normalize_length else batch_size
|
||||
denom = total if self.normalize_length else 1
|
||||
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
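A toy usage sketch of the loss above; the logits and targets are made up, and positions equal to `padding_idx` are excluded from the loss:

```
import torch

criterion = LabelSmoothingLoss(size=5, padding_idx=-1, smoothing=0.1)
logits = torch.randn(2, 3, 5)             # (batch, length, num_classes)
target = torch.tensor([[1, 2, -1],        # -1 marks a padded position
                       [0, 4, 3]])
loss = criterion(logits, target)          # scalar, summed over non-padded tokens
print(loss.item())
```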
|
||||
|
||||
|
||||
def encoder_padding_mask(
|
||||
max_len: int, supervisions: Optional[Supervisions] = None
|
||||
) -> Optional[torch.Tensor]:
|
||||
"""Make mask tensor containing indexes of padded part.
|
||||
|
||||
TODO::
|
||||
This function **assumes** that the model uses
|
||||
a subsampling factor of 4. We should remove that
|
||||
assumption later.
|
||||
|
||||
Args:
|
||||
max_len:
|
||||
Maximum length of input features.
|
||||
CAUTION: It is the length after subsampling.
|
||||
supervisions:
|
||||
Supervision in lhotse format.
|
||||
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
|
||||
(CAUTION: It contains length information, i.e., start and number of
|
||||
frames, before subsampling)
|
||||
|
||||
Returns:
|
||||
Tensor: Mask tensor of dimension (batch_size, input_length),
|
||||
True denotes the masked indices.
|
||||
"""
|
||||
if supervisions is None:
|
||||
return None
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
supervisions["start_frame"],
|
||||
supervisions["num_frames"],
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
|
||||
|
||||
lengths = [
|
||||
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
|
||||
]
|
||||
for idx in range(supervision_segments.size(0)):
|
||||
# Note: TorchScript doesn't allow to unpack tensors as tuples
|
||||
sequence_idx = supervision_segments[idx, 0].item()
|
||||
start_frame = supervision_segments[idx, 1].item()
|
||||
num_frames = supervision_segments[idx, 2].item()
|
||||
lengths[sequence_idx] = start_frame + num_frames
|
||||
|
||||
lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
|
||||
bs = int(len(lengths))
|
||||
seq_range = torch.arange(0, max_len, dtype=torch.int64)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
|
||||
# Note: TorchScript doesn't implement Tensor.new()
|
||||
seq_length_expand = torch.tensor(
|
||||
lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype
|
||||
).unsqueeze(-1)
|
||||
mask = seq_range_expand >= seq_length_expand
|
||||
|
||||
return mask
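The length arithmetic above mirrors the Conv2dSubsampling frontend: an utterance with L input frames keeps ((L - 1) // 2 - 1) // 2 frames after subsampling. A small sketch with a hypothetical, already-flattened supervision dict:

```
import torch

supervisions = {
    "sequence_idx": torch.tensor([0, 1]),
    "start_frame": torch.tensor([0, 0]),
    "num_frames": torch.tensor([100, 60]),  # lengths before subsampling
}
# After subsampling: ((100-1)//2 - 1)//2 = 24 and ((60-1)//2 - 1)//2 = 14 frames.
mask = encoder_padding_mask(max_len=24, supervisions=supervisions)
print(mask.shape)          # torch.Size([2, 24])
print(mask[1, 14:].all())  # tensor(True): padded frames of the shorter utterance
```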
|
||||
|
||||
|
||||
def decoder_padding_mask(
|
||||
ys_pad: torch.Tensor, ignore_id: int = -1
|
||||
) -> torch.Tensor:
|
||||
"""Generate a length mask for input.
|
||||
|
||||
The masked positions are filled with True.
|
||||
Unmasked positions are filled with False.
|
||||
|
||||
Args:
|
||||
ys_pad:
|
||||
padded tensor of dimension (batch_size, input_length).
|
||||
ignore_id:
|
||||
the ignored number (the padding number) in ys_pad
|
||||
|
||||
Returns:
|
||||
Tensor:
|
||||
a bool tensor of the same shape as the input tensor.
|
||||
"""
|
||||
ys_mask = ys_pad == ignore_id
|
||||
return ys_mask
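A minimal example with the default `ignore_id=-1`:

```
import torch

ys_pad = torch.tensor([[10, 7, -1, -1],
                       [3, 5, 9, -1]])
print(decoder_padding_mask(ys_pad))
# tensor([[False, False,  True,  True],
#         [False, False, False,  True]])
```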
|
||||
|
||||
|
||||
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
|
||||
"""Generate a square mask for the sequence. The masked positions are
|
||||
filled with float('-inf'). Unmasked positions are filled with float(0.0).
|
||||
The mask can be used for masked self-attention.
|
||||
|
||||
For instance, if sz is 3, it returns::
|
||||
|
||||
tensor([[0., -inf, -inf],
|
||||
[0., 0., -inf],
|
||||
[0., 0., 0.]])
|
||||
|
||||
Args:
|
||||
sz: mask size
|
||||
|
||||
Returns:
|
||||
A square mask of dimension (sz, sz)
|
||||
"""
|
||||
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
|
||||
mask = (
|
||||
mask.float()
|
||||
.masked_fill(mask == 0, float("-inf"))
|
||||
.masked_fill(mask == 1, float(0.0))
|
||||
)
|
||||
return mask
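The mask is consumed as `attn_mask` (or `tgt_mask`) by PyTorch attention modules; a tiny, hypothetical usage sketch:

```
import torch

attn = torch.nn.MultiheadAttention(embed_dim=8, num_heads=2)
tgt = torch.randn(4, 2, 8)  # (T, N, E)
out, _ = attn(tgt, tgt, tgt, attn_mask=generate_square_subsequent_mask(4))
# Position t can only attend to positions <= t; the -inf entries block the future.
```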
|
||||
|
||||
|
||||
def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
|
||||
"""Prepend sos_id to each utterance.
|
||||
|
||||
Args:
|
||||
token_ids:
|
||||
A list-of-list of token IDs. Each sublist contains
|
||||
token IDs (e.g., word piece IDs) of an utterance.
|
||||
sos_id:
|
||||
The ID of the SOS token.
|
||||
|
||||
Return:
|
||||
Return a new list-of-list, where each sublist starts
|
||||
with SOS ID.
|
||||
"""
|
||||
ans = []
|
||||
for utt in token_ids:
|
||||
ans.append([sos_id] + utt)
|
||||
return ans
|
||||
|
||||
|
||||
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
|
||||
"""Append eos_id to each utterance.
|
||||
|
||||
Args:
|
||||
token_ids:
|
||||
A list-of-list of token IDs. Each sublist contains
|
||||
token IDs (e.g., word piece IDs) of an utterance.
|
||||
eos_id:
|
||||
The ID of the EOS token.
|
||||
|
||||
Return:
|
||||
Return a new list-of-list, where each sublist ends
|
||||
with EOS ID.
|
||||
"""
|
||||
ans = []
|
||||
for utt in token_ids:
|
||||
ans.append(utt + [eos_id])
|
||||
return ans
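For instance:

```
token_ids = [[5, 9, 3], [7]]
add_sos(token_ids, sos_id=1)  # [[1, 5, 9, 3], [1, 7]]
add_eos(token_ids, eos_id=1)  # [[5, 9, 3, 1], [7, 1]]
```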
|
107
egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py
Executable file
@@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2021 Xiaomi Corporation (Author: Fangjun Kuang)
|
||||
"""
|
||||
Convert a transcript file containing words to a corpus file containing tokens
|
||||
for LM training with the help of a lexicon.
|
||||
|
||||
If the lexicon contains phones, the resulting LM will be a phone LM; If the
|
||||
lexicon contains word pieces, the resulting LM will be a word piece LM.
|
||||
|
||||
If a word has multiple pronunciations, the one that appears first in the lexicon
|
||||
is kept; others are removed.
|
||||
|
||||
If the input transcript is:
|
||||
|
||||
hello zoo world hello
|
||||
world zoo
|
||||
foo zoo world hellO
|
||||
|
||||
and if the lexicon is
|
||||
|
||||
<UNK> SPN
|
||||
hello h e l l o 2
|
||||
hello h e l l o
|
||||
world w o r l d
|
||||
zoo z o o
|
||||
|
||||
Then the output is
|
||||
|
||||
h e l l o 2 z o o w o r l d h e l l o 2
|
||||
w o r l d z o o
|
||||
SPN z o o w o r l d SPN
|
||||
"""
|
||||
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Dict, List
|
||||
|
||||
from generate_unique_lexicon import filter_multiple_pronunications
|
||||
|
||||
from icefall.lexicon import read_lexicon
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--transcript",
|
||||
type=str,
|
||||
help="The input transcript file."
|
||||
"We assume that the transcript file consists of "
|
||||
"lines. Each line consists of space separated words.",
|
||||
)
|
||||
parser.add_argument("--lexicon", type=str, help="The input lexicon file.")
|
||||
parser.add_argument(
|
||||
"--oov", type=str, default="<UNK>", help="The OOV word."
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def process_line(
|
||||
lexicon: Dict[str, List[str]], line: str, oov_token: str
|
||||
) -> None:
|
||||
"""
|
||||
Args:
|
||||
lexicon:
|
||||
A dict containing pronunciations. Its keys are words and values
|
||||
are pronunciations (i.e., tokens).
|
||||
line:
|
||||
A line of transcript consisting of space(s) separated words.
|
||||
oov_token:
|
||||
The pronunciation of the oov word if a word in `line` is not present
|
||||
in the lexicon.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
s = ""
|
||||
words = line.strip().split()
|
||||
for i, w in enumerate(words):
|
||||
tokens = lexicon.get(w, oov_token)
|
||||
s += " ".join(tokens)
|
||||
s += " "
|
||||
print(s.strip())
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert Path(args.lexicon).is_file()
|
||||
assert Path(args.transcript).is_file()
|
||||
assert len(args.oov) > 0
|
||||
|
||||
# Only the first pronunciation of a word is kept
|
||||
lexicon = filter_multiple_pronunications(read_lexicon(args.lexicon))
|
||||
|
||||
lexicon = dict(lexicon)
|
||||
|
||||
assert args.oov in lexicon
|
||||
|
||||
oov_token = lexicon[args.oov]
|
||||
|
||||
with open(args.transcript) as f:
|
||||
for line in f:
|
||||
process_line(lexicon=lexicon, line=line, oov_token=oov_token)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
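As a quick sanity check of the conversion logic, the example from the module docstring above can be reproduced in a few lines. This is an illustrative sketch that mimics `process_line`, not part of the script itself.

```
# Illustrative only: the lexicon below is the docstring's example, with the
# duplicate "hello" entry already removed (first pronunciation wins).
lexicon = {
    "<UNK>": ["SPN"],
    "hello": ["h", "e", "l", "l", "o", "2"],
    "world": ["w", "o", "r", "l", "d"],
    "zoo": ["z", "o", "o"],
}
oov_token = lexicon["<UNK>"]

line = "foo zoo world hellO"
tokens = []
for w in line.split():
    tokens.extend(lexicon.get(w, oov_token))

print(" ".join(tokens))  # SPN z o o w o r l d SPN
```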
100
egs/librispeech/ASR/local/generate_unique_lexicon.py
Executable file
@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This file takes as input a lexicon.txt and outputs a new lexicon
in which each word has a unique pronunciation.

The way to do this is to keep only the first pronunciation of a word
in lexicon.txt.
"""


import argparse
import logging
from pathlib import Path
from typing import List, Tuple

from icefall.lexicon import read_lexicon, write_lexicon


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain a file lexicon.txt.
        This file will generate a new file uniq_lexicon.txt
        in it.
        """,
    )

    return parser.parse_args()


def filter_multiple_pronunications(
    lexicon: List[Tuple[str, List[str]]]
) -> List[Tuple[str, List[str]]]:
    """Remove multiple pronunciations of words from a lexicon.

    If a word has more than one pronunciation in the lexicon, only
    the first one is kept, while other pronunciations are removed
    from the lexicon.

    Args:
      lexicon:
        The input lexicon, containing a list of (word, [p1, p2, ..., pn]),
        where "p1, p2, ..., pn" are the pronunciations of the "word".
    Returns:
      Return a new lexicon where each word has a unique pronunciation.
    """
    seen = set()
    ans = []

    for word, tokens in lexicon:
        if word in seen:
            continue
        seen.add(word)
        ans.append((word, tokens))
    return ans


def main():
    args = get_args()
    lang_dir = Path(args.lang_dir)

    lexicon_filename = lang_dir / "lexicon.txt"

    in_lexicon = read_lexicon(lexicon_filename)

    out_lexicon = filter_multiple_pronunications(in_lexicon)

    write_lexicon(lang_dir / "uniq_lexicon.txt", out_lexicon)

    logging.info(f"Number of entries in lexicon.txt: {len(in_lexicon)}")
    logging.info(f"Number of entries in uniq_lexicon.txt: {len(out_lexicon)}")


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)

    main()
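The behaviour of `filter_multiple_pronunications` is easy to verify in isolation. A small sketch with made-up entries, assuming the function is importable from the script above:

```
from generate_unique_lexicon import filter_multiple_pronunications

lexicon = [
    ("<UNK>", ["SPN"]),
    ("hello", ["h", "e", "l", "l", "o", "2"]),
    ("hello", ["h", "e", "l", "l", "o"]),  # later duplicate, dropped
    ("world", ["w", "o", "r", "l", "d"]),
]

uniq = filter_multiple_pronunications(lexicon)

assert [w for w, _ in uniq] == ["<UNK>", "hello", "world"]
assert dict(uniq)["hello"] == ["h", "e", "l", "l", "o", "2"]
```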
@ -33,6 +33,7 @@ consisting of words and tokens (i.e., phones) and does the following:
|
||||
|
||||
5. Generate L_disambig.pt, in k2 format.
|
||||
"""
|
||||
import argparse
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
@ -42,10 +43,37 @@ import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import read_lexicon, write_lexicon
|
||||
from icefall.utils import str2bool
|
||||
|
||||
Lexicon = List[Tuple[str, List[str]]]
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--lang-dir",
|
||||
type=str,
|
||||
help="""Input and output directory.
|
||||
It should contain a file lexicon.txt.
|
||||
Generated files by this script are saved into this directory.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True for debugging, which will generate
|
||||
a visualization of the lexicon FST.
|
||||
|
||||
Caution: If your lexicon contains hundreds of thousands
|
||||
of lines, please set it to False!
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def write_mapping(filename: str, sym2id: Dict[str, int]) -> None:
|
||||
"""Write a symbol to ID mapping to a file.
|
||||
|
||||
@ -315,8 +343,9 @@ def lexicon_to_fst(
|
||||
|
||||
|
||||
def main():
|
||||
out_dir = Path("data/lang_phone")
|
||||
lexicon_filename = out_dir / "lexicon.txt"
|
||||
args = get_args()
|
||||
lang_dir = Path(args.lang_dir)
|
||||
lexicon_filename = lang_dir / "lexicon.txt"
|
||||
sil_token = "SIL"
|
||||
sil_prob = 0.5
|
||||
|
||||
@ -344,9 +373,9 @@ def main():
|
||||
token2id = generate_id_map(tokens)
|
||||
word2id = generate_id_map(words)
|
||||
|
||||
write_mapping(out_dir / "tokens.txt", token2id)
|
||||
write_mapping(out_dir / "words.txt", word2id)
|
||||
write_lexicon(out_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
write_mapping(lang_dir / "tokens.txt", token2id)
|
||||
write_mapping(lang_dir / "words.txt", word2id)
|
||||
write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig)
|
||||
|
||||
L = lexicon_to_fst(
|
||||
lexicon,
|
||||
@ -364,17 +393,20 @@ def main():
|
||||
sil_prob=sil_prob,
|
||||
need_self_loops=True,
|
||||
)
|
||||
torch.save(L.as_dict(), out_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), out_dir / "L_disambig.pt")
|
||||
torch.save(L.as_dict(), lang_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
|
||||
|
||||
if False:
|
||||
# Just for debugging, will remove it
|
||||
L.labels_sym = k2.SymbolTable.from_file(out_dir / "tokens.txt")
|
||||
L.aux_labels_sym = k2.SymbolTable.from_file(out_dir / "words.txt")
|
||||
L_disambig.labels_sym = L.labels_sym
|
||||
L_disambig.aux_labels_sym = L.aux_labels_sym
|
||||
L.draw(out_dir / "L.png", title="L")
|
||||
L_disambig.draw(out_dir / "L_disambig.png", title="L_disambig")
|
||||
if args.debug:
|
||||
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
|
||||
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
L.labels_sym = labels_sym
|
||||
L.aux_labels_sym = aux_labels_sym
|
||||
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
|
||||
|
||||
L_disambig.labels_sym = labels_sym
|
||||
L_disambig.aux_labels_sym = aux_labels_sym
|
||||
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -49,6 +49,8 @@ from prepare_lang import (
|
||||
write_mapping,
|
||||
)
|
||||
|
||||
from icefall.utils import str2bool
|
||||
|
||||
|
||||
def lexicon_to_fst_no_sil(
|
||||
lexicon: Lexicon,
|
||||
@ -169,6 +171,20 @@ def get_args():
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=str2bool,
|
||||
default=False,
|
||||
help="""True for debugging, which will generate
|
||||
a visualization of the lexicon FST.
|
||||
|
||||
Caution: If your lexicon contains hundreds of thousands
|
||||
of lines, please set it to False!
|
||||
|
||||
See "test/test_bpe_lexicon.py" for usage.
|
||||
""",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -221,6 +237,18 @@ def main():
|
||||
torch.save(L.as_dict(), lang_dir / "L.pt")
|
||||
torch.save(L_disambig.as_dict(), lang_dir / "L_disambig.pt")
|
||||
|
||||
if args.debug:
|
||||
labels_sym = k2.SymbolTable.from_file(lang_dir / "tokens.txt")
|
||||
aux_labels_sym = k2.SymbolTable.from_file(lang_dir / "words.txt")
|
||||
|
||||
L.labels_sym = labels_sym
|
||||
L.aux_labels_sym = aux_labels_sym
|
||||
L.draw(f"{lang_dir / 'L.svg'}", title="L.pt")
|
||||
|
||||
L_disambig.labels_sym = labels_sym
|
||||
L_disambig.aux_labels_sym = aux_labels_sym
|
||||
L_disambig.draw(f"{lang_dir / 'L_disambig.svg'}", title="L_disambig.pt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
@ -38,10 +38,17 @@ def get_args():
        "--lang-dir",
        type=str,
        help="""Input and output directory.
        It should contain the training corpus: train.txt.
        It should contain the training corpus: transcript_words.txt.
        The generated bpe.model is saved to this directory.
        """,
    )

    parser.add_argument(
        "--transcript",
        type=str,
        help="Training transcript.",
    )

    parser.add_argument(
        "--vocab-size",
        type=int,
@ -59,7 +66,7 @@ def main():
    model_type = "unigram"

    model_prefix = f"{lang_dir}/{model_type}_{vocab_size}"
    train_text = f"{lang_dir}/train.txt"
    train_text = args.transcript
    character_coverage = 1.0
    input_sentence_size = 100000000

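The variables in this hunk (`model_prefix`, `train_text`, `character_coverage`, `input_sentence_size`) are presumably handed to sentencepiece's trainer; the call itself is not shown here. A hedged sketch of what it could look like, with the path and vocab size chosen only for illustration:

```
import sentencepiece as spm

# Sketch only: mirrors the variables in the hunk above; the actual
# train_bpe_model.py may pass a different or larger set of options.
spm.SentencePieceTrainer.train(
    input="data/lang_bpe_500/transcript_words.txt",  # i.e. args.transcript
    vocab_size=500,
    model_type="unigram",
    model_prefix="data/lang_bpe_500/unigram_500",
    character_coverage=1.0,
    input_sentence_size=100000000,
)
```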
@ -116,17 +116,19 @@ fi

if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Prepare phone based lang"
  mkdir -p data/lang_phone
  lang_dir=data/lang_phone
  mkdir -p $lang_dir

  (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
    cat - $dl_dir/lm/librispeech-lexicon.txt |
    sort | uniq > data/lang_phone/lexicon.txt
    sort | uniq > $lang_dir/lexicon.txt

  if [ ! -f data/lang_phone/L_disambig.pt ]; then
    ./local/prepare_lang.py
  if [ ! -f $lang_dir/L_disambig.pt ]; then
    ./local/prepare_lang.py --lang-dir $lang_dir
  fi
fi


if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
  log "Stage 6: Prepare BPE based lang"

@ -137,7 +139,7 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
    # so that the two can share G.pt later.
    cp data/lang_phone/words.txt $lang_dir

    if [ ! -f $lang_dir/train.txt ]; then
    if [ ! -f $lang_dir/transcript_words.txt ]; then
      log "Generate data for BPE training"
      files=$(
        find "$dl_dir/LibriSpeech/train-clean-100" -name "*.trans.txt"
@ -146,12 +148,13 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
      )
      for f in ${files[@]}; do
        cat $f | cut -d " " -f 2-
      done > $lang_dir/train.txt
      done > $lang_dir/transcript_words.txt
    fi

    ./local/train_bpe_model.py \
      --lang-dir $lang_dir \
      --vocab-size $vocab_size
      --vocab-size $vocab_size \
      --transcript $lang_dir/transcript_words.txt

    if [ ! -f $lang_dir/L_disambig.pt ]; then
      ./local/prepare_lang_bpe.py --lang-dir $lang_dir
@ -160,7 +163,38 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then
fi

if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
  log "Stage 7: Prepare G"
  log "Stage 7: Prepare bigram P"

  for vocab_size in ${vocab_sizes[@]}; do
    lang_dir=data/lang_bpe_${vocab_size}

    if [ ! -f $lang_dir/transcript_tokens.txt ]; then
      ./local/convert_transcript_words_to_tokens.py \
        --lexicon $lang_dir/lexicon.txt \
        --transcript $lang_dir/transcript_words.txt \
        --oov "<UNK>" \
        > $lang_dir/transcript_tokens.txt
    fi

    if [ ! -f $lang_dir/P.arpa ]; then
      ./shared/make_kn_lm.py \
        -ngram-order 2 \
        -text $lang_dir/transcript_tokens.txt \
        -lm $lang_dir/P.arpa
    fi

    if [ ! -f $lang_dir/P.fst.txt ]; then
      python3 -m kaldilm \
        --read-symbol-table="$lang_dir/tokens.txt" \
        --disambig-symbol='#0' \
        --max-order=2 \
        $lang_dir/P.arpa > $lang_dir/P.fst.txt
    fi
  done
fi

if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
  log "Stage 8: Prepare G"
  # We assume you have installed kaldilm, if not, please install
  # it using: pip install kaldilm

@ -184,8 +218,8 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
  fi
fi

if [ $stage -le 8 ] && [ $stop_stage -ge 8 ]; then
  log "Stage 8: Compile HLG"
if [ $stage -le 9 ] && [ $stop_stage -ge 9 ]; then
  log "Stage 9: Compile HLG"
  ./local/compile_hlg.py --lang-dir data/lang_phone

  for vocab_size in ${vocab_sizes[@]}; do
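The `P.fst.txt` written in Stage 7 is consumed later by the MMI graph compiler (see `icefall/mmi_graph_compiler.py` further down in this commit). A minimal sketch of that loading step, with the path given only as an example:

```
import k2

# P has a back-off state, so it is read as a transducer (acceptor=False).
with open("data/lang_bpe_500/P.fst.txt") as f:
    P = k2.Fsa.from_openfst(f.read(), acceptor=False)

print(P.num_arcs)
```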
142
icefall/ali.py
Normal file
@ -0,0 +1,142 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence


def save_alignments(
    alignments: Dict[str, List[int]],
    subsampling_factor: int,
    filename: str,
) -> None:
    """Save alignments to a file.

    Args:
      alignments:
        A dict containing alignments. Keys of the dict are utterances and
        values are the corresponding framewise alignments after subsampling.
      subsampling_factor:
        The subsampling factor of the model.
      filename:
        Path to save the alignments.
    Returns:
      Return None.
    """
    ali_dict = {
        "subsampling_factor": subsampling_factor,
        "alignments": alignments,
    }
    torch.save(ali_dict, filename)


def load_alignments(filename: str) -> Tuple[int, Dict[str, List[int]]]:
    """Load alignments from a file.

    Args:
      filename:
        Path to the file containing alignment information.
        The file should be saved by :func:`save_alignments`.
    Returns:
      Return a tuple containing:
        - subsampling_factor: The subsampling_factor used to compute
          the alignments.
        - alignments: A dict containing utterances and their corresponding
          framewise alignment, after subsampling.
    """
    ali_dict = torch.load(filename)
    subsampling_factor = ali_dict["subsampling_factor"]
    alignments = ali_dict["alignments"]
    return subsampling_factor, alignments


def convert_alignments_to_tensor(
    alignments: Dict[str, List[int]], device: torch.device
) -> Dict[str, torch.Tensor]:
    """Convert alignments from list of int to a 1-D torch.Tensor.

    Args:
      alignments:
        A dict containing alignments. Keys are utterance IDs and
        values are their corresponding frame-wise alignments.
      device:
        The device to move the alignments to.
    Returns:
      Return a dict using 1-D torch.Tensor to store the alignments.
      The dtype of the tensor is `torch.int64`. We choose `torch.int64`
      because `torch.nn.functional.one_hot` requires that.
    """
    ans = {}
    for utt_id, ali in alignments.items():
        ali = torch.tensor(ali, dtype=torch.int64, device=device)
        ans[utt_id] = ali
    return ans


def lookup_alignments(
    cut_ids: List[str],
    alignments: Dict[str, torch.Tensor],
    num_classes: int,
    log_score: float = -10,
) -> torch.Tensor:
    """Return a mask constructed from alignments by a list of cut IDs.

    The returned mask is a 3-D tensor of shape (N, T, C). For each frame,
    i.e., each row, of the returned mask, positions not corresponding to
    the alignments are filled with `log_score`, while the position
    specified by the alignment is filled with 0. For instance, if the
    alignments of two utterances are:

        [ [1, 3, 2], [1, 0, 4, 2] ]

    num_classes is 5 and log_score is -10, then the returned mask is

        [
            [[-10, 0, -10, -10, -10],
             [-10, -10, -10, 0, -10],
             [-10, -10, 0, -10, -10],
             [0, -10, -10, -10, -10]],
            [[-10, 0, -10, -10, -10],
             [0, -10, -10, -10, -10],
             [-10, -10, -10, -10, 0],
             [-10, -10, 0, -10, -10]]
        ]

    Note: We pad the alignment of the first utterance with 0.

    Args:
      cut_ids:
        A list of utterance IDs.
      alignments:
        A dict containing alignments. The keys are utterance IDs and the
        values are framewise alignments.
      num_classes:
        The max token ID + 1 that appears in the alignments.
      log_score:
        Positions in the returned tensor not corresponding to the alignments
        are filled with this value.
    Returns:
      Return a 3-D torch.float32 tensor of shape (N, T, C).
    """
    # We assume all utterances have their alignments.
    ali = [alignments[cut_id] for cut_id in cut_ids]
    padded_ali = pad_sequence(ali, batch_first=True, padding_value=0)
    padded_one_hot = torch.nn.functional.one_hot(
        padded_ali,
        num_classes=num_classes,
    )
    mask = (1 - padded_one_hot) * float(log_score)
    return mask
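A small round trip of the helpers above, using the alignments from the `lookup_alignments` docstring; the cut IDs and filename are made up for illustration:

```
import torch

from icefall.ali import (
    convert_alignments_to_tensor,
    load_alignments,
    lookup_alignments,
    save_alignments,
)

ali = {"cut1": [1, 3, 2], "cut2": [1, 0, 4, 2]}

save_alignments(ali, subsampling_factor=4, filename="ali.pt")
subsampling_factor, loaded = load_alignments("ali.pt")

loaded = convert_alignments_to_tensor(loaded, device=torch.device("cpu"))
mask = lookup_alignments(
    cut_ids=["cut1", "cut2"],
    alignments=loaded,
    num_classes=5,
)
assert mask.shape == (2, 4, 5)  # (N, T, C) as described in the docstring
```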
@ -84,6 +84,69 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
|
||||
f.write(f"{word} {' '.join(tokens)}\n")
|
||||
|
||||
|
||||
def convert_lexicon_to_ragged(
|
||||
filename: str, word_table: k2.SymbolTable, token_table: k2.SymbolTable
|
||||
) -> k2.RaggedTensor:
|
||||
"""Read a lexicon and convert it to a ragged tensor.
|
||||
|
||||
The ragged tensor has two axes: [word][token].
|
||||
|
||||
Caution:
|
||||
We assume that each word has a unique pronunciation.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the lexicon. It has a format that can be read
|
||||
by :func:`read_lexicon`.
|
||||
word_table:
|
||||
The word symbol table.
|
||||
token_table:
|
||||
The token symbol table.
|
||||
Returns:
|
||||
A k2 ragged tensor with two axes [word][token].
|
||||
"""
|
||||
disambig_id = word_table["#0"]
|
||||
# We reuse the same words.txt from the phone based lexicon
|
||||
# so that we can share the same G.fst. Here, we have to
|
||||
# exclude some words present only in the phone based lexicon.
|
||||
excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
|
||||
|
||||
# epsilon is not a word, but it occupies a position
|
||||
#
|
||||
row_splits = [0]
|
||||
token_ids_list = []
|
||||
|
||||
lexicon_tmp = read_lexicon(filename)
|
||||
lexicon = dict(lexicon_tmp)
|
||||
if len(lexicon_tmp) != len(lexicon):
|
||||
raise RuntimeError(
|
||||
"It's assumed that each word has a unique pronunciation"
|
||||
)
|
||||
|
||||
for i in range(disambig_id):
|
||||
w = word_table[i]
|
||||
if w in excluded_words:
|
||||
row_splits.append(row_splits[-1])
|
||||
continue
|
||||
tokens = lexicon[w]
|
||||
token_ids = [token_table[k] for k in tokens]
|
||||
|
||||
row_splits.append(row_splits[-1] + len(token_ids))
|
||||
token_ids_list.extend(token_ids)
|
||||
|
||||
cached_tot_size = row_splits[-1]
|
||||
row_splits = torch.tensor(row_splits, dtype=torch.int32)
|
||||
|
||||
shape = k2.ragged.create_ragged_shape2(
|
||||
row_splits,
|
||||
None,
|
||||
cached_tot_size,
|
||||
)
|
||||
values = torch.tensor(token_ids_list, dtype=torch.int32)
|
||||
|
||||
return k2.RaggedTensor(shape, values)
|
||||
|
||||
|
||||
class Lexicon(object):
|
||||
"""Phone based lexicon."""
|
||||
|
||||
@ -95,7 +158,7 @@ class Lexicon(object):
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
Path to the lang director. It is expected to contain the following
|
||||
Path to the lang directory. It is expected to contain the following
|
||||
files:
|
||||
- tokens.txt
|
||||
- words.txt
|
||||
@ -119,7 +182,7 @@ class Lexicon(object):
|
||||
torch.save(L_inv.as_dict(), lang_dir / "Linv.pt")
|
||||
|
||||
# We save L_inv instead of L because it will be used to intersect with
|
||||
# transcript, both of whose labels are word IDs.
|
||||
# transcript FSAs, both of whose labels are word IDs.
|
||||
self.L_inv = L_inv
|
||||
self.disambig_pattern = disambig_pattern
|
||||
|
||||
@ -142,69 +205,66 @@ class Lexicon(object):
|
||||
return ans
|
||||
|
||||
|
||||
class BpeLexicon(Lexicon):
|
||||
class UniqLexicon(Lexicon):
|
||||
def __init__(
|
||||
self,
|
||||
lang_dir: Path,
|
||||
uniq_filename: str = "uniq_lexicon.txt",
|
||||
disambig_pattern: str = re.compile(r"^#\d+$"),
|
||||
):
|
||||
"""
|
||||
Refer to the help information in Lexicon.__init__.
|
||||
|
||||
uniq_filename: It is assumed to be inside the given `lang_dir`.
|
||||
|
||||
Each word in the lexicon is assumed to have a unique pronunciation.
|
||||
"""
|
||||
lang_dir = Path(lang_dir)
|
||||
super().__init__(lang_dir=lang_dir, disambig_pattern=disambig_pattern)
|
||||
|
||||
self.ragged_lexicon = self.convert_lexicon_to_ragged(
|
||||
lang_dir / "lexicon.txt"
|
||||
self.ragged_lexicon = convert_lexicon_to_ragged(
|
||||
filename=lang_dir / uniq_filename,
|
||||
word_table=self.word_table,
|
||||
token_table=self.token_table,
|
||||
)
|
||||
# TODO: should we move it to a certain device ?
|
||||
|
||||
def convert_lexicon_to_ragged(self, filename: str) -> k2.RaggedTensor:
|
||||
"""Read a BPE lexicon from file and convert it to a
|
||||
k2 ragged tensor.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
Filename of the BPE lexicon, e.g., data/lang/bpe/lexicon.txt
|
||||
Returns:
|
||||
A k2 ragged tensor with two axes [word_id]
|
||||
def texts_to_token_ids(
|
||||
self, texts: List[str], oov: str = "<UNK>"
|
||||
) -> k2.RaggedTensor:
|
||||
"""
|
||||
disambig_id = self.word_table["#0"]
|
||||
# We reuse the same words.txt from the phone based lexicon
|
||||
# so that we can share the same G.fst. Here, we have to
|
||||
# exclude some words present only in the phone based lexicon.
|
||||
excluded_words = ["<eps>", "!SIL", "<SPOKEN_NOISE>"]
|
||||
Args:
|
||||
texts:
|
||||
A list of transcripts. Each transcript contains space(s)
|
||||
separated words. An example texts is::
|
||||
|
||||
# epsilon is not a word, but it occupies on position
|
||||
#
|
||||
row_splits = [0]
|
||||
token_ids = []
|
||||
['HELLO k2', 'HELLO icefall']
|
||||
oov:
|
||||
The OOV word. If a word in `texts` is not in the lexicon, it is
|
||||
replaced with `oov`.
|
||||
Returns:
|
||||
Return a ragged int tensor with 2 axes [utterance][token_id]
|
||||
"""
|
||||
oov_id = self.word_table[oov]
|
||||
|
||||
lexicon = read_lexicon(filename)
|
||||
lexicon = dict(lexicon)
|
||||
word_ids_list = []
|
||||
for text in texts:
|
||||
word_ids = []
|
||||
for word in text.split():
|
||||
if word in self.word_table:
|
||||
word_ids.append(self.word_table[word])
|
||||
else:
|
||||
word_ids.append(oov_id)
|
||||
word_ids_list.append(word_ids)
|
||||
ragged_indexes = k2.RaggedTensor(word_ids_list, dtype=torch.int32)
|
||||
ans = self.ragged_lexicon.index(ragged_indexes)
|
||||
ans = ans.remove_axis(ans.num_axes - 2)
|
||||
return ans
|
||||
|
||||
for i in range(disambig_id):
|
||||
w = self.word_table[i]
|
||||
if w in excluded_words:
|
||||
row_splits.append(row_splits[-1])
|
||||
continue
|
||||
pieces = lexicon[w]
|
||||
piece_ids = [self.token_table[k] for k in pieces]
|
||||
def words_to_token_ids(self, words: List[str]) -> k2.RaggedTensor:
|
||||
"""Convert a list of words to a ragged tensor containing token IDs.
|
||||
|
||||
row_splits.append(row_splits[-1] + len(piece_ids))
|
||||
token_ids.extend(piece_ids)
|
||||
|
||||
cached_tot_size = row_splits[-1]
|
||||
row_splits = torch.tensor(row_splits, dtype=torch.int32)
|
||||
|
||||
shape = k2.ragged.create_ragged_shape2(
|
||||
row_splits=row_splits, cached_tot_size=cached_tot_size
|
||||
)
|
||||
values = torch.tensor(token_ids, dtype=torch.int32)
|
||||
|
||||
return k2.RaggedTensor(shape, values)
|
||||
|
||||
def words_to_piece_ids(self, words: List[str]) -> k2.RaggedTensor:
|
||||
"""Convert a list of words to a ragged tensor contained
|
||||
word piece IDs.
|
||||
We assume there are no OOVs in "words".
|
||||
"""
|
||||
word_ids = [self.word_table[w] for w in words]
|
||||
word_ids = torch.tensor(word_ids, dtype=torch.int32)
|
||||
|
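For orientation, a hedged sketch of how the new `UniqLexicon` is meant to be used; the lang directory is only an example and must already contain the `uniq_lexicon.txt` produced by the updated `prepare.sh`:

```
from icefall.lexicon import UniqLexicon

lexicon = UniqLexicon("data/lang_bpe_500")

# Returns a k2.RaggedTensor with axes [utterance][token_id].
token_ids = lexicon.texts_to_token_ids(["HELLO WORLD", "HELLO"])
print(token_ids)
```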
232
icefall/mmi.py
Normal file
@ -0,0 +1,232 @@
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
|
||||
|
||||
|
||||
def _compute_mmi_loss_exact_optimized(
|
||||
dense_fsa_vec: k2.DenseFsaVec,
|
||||
texts: List[str],
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
den_scale: float = 1.0,
|
||||
beam_size: float = 8.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
The function name contains `exact`, which means it uses a version of
|
||||
intersection without pruning.
|
||||
|
||||
`optimized` in the function name means this function is optimized
|
||||
in that it calls k2.intersect_dense only once
|
||||
|
||||
Note:
|
||||
It is faster at the cost of using more memory.
|
||||
|
||||
Args:
|
||||
dense_fsa_vec:
|
||||
It contains the neural network output.
|
||||
texts:
|
||||
The transcript. Each element consists of space(s) separated words.
|
||||
graph_compiler:
|
||||
Used to build num_graphs and den_graphs
|
||||
den_scale:
|
||||
The scale applied to the denominator tot_scores.
|
||||
Returns:
|
||||
Return a scalar loss. It is the sum over utterances in a batch,
|
||||
without normalization.
|
||||
"""
|
||||
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
|
||||
|
||||
device = num_graphs.device
|
||||
|
||||
num_fsas = num_graphs.shape[0]
|
||||
assert dense_fsa_vec.dim0() == num_fsas
|
||||
|
||||
assert den_graphs.shape[0] == 1
|
||||
|
||||
# The motivation to concatenate num_graphs and den_graphs
|
||||
# is to reduce the number of calls to k2.intersect_dense.
|
||||
num_den_graphs = k2.cat([num_graphs, den_graphs])
|
||||
|
||||
# NOTE: The a_to_b_map in k2.intersect_dense must be sorted
|
||||
# so the following reorders num_den_graphs.
|
||||
#
|
||||
# The following code computes a_to_b_map
|
||||
|
||||
# [0, 1, 2, ... ]
|
||||
num_graphs_indexes = torch.arange(num_fsas, dtype=torch.int32)
|
||||
|
||||
# [num_fsas, num_fsas, num_fsas, ... ]
|
||||
den_graphs_indexes = torch.tensor([num_fsas] * num_fsas, dtype=torch.int32)
|
||||
|
||||
# [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ]
|
||||
num_den_graphs_indexes = (
|
||||
torch.stack([num_graphs_indexes, den_graphs_indexes])
|
||||
.t()
|
||||
.reshape(-1)
|
||||
.to(device)
|
||||
)
|
||||
|
||||
num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes)
|
||||
|
||||
# [[0, 1, 2, ...]]
|
||||
a_to_b_map = torch.arange(num_fsas, dtype=torch.int32).reshape(1, -1)
|
||||
|
||||
# [[0, 1, 2, ...]] -> [0, 0, 1, 1, 2, 2, ... ]
|
||||
a_to_b_map = a_to_b_map.repeat(2, 1).t().reshape(-1).to(device)
|
||||
|
||||
num_den_lats = k2.intersect_dense(
|
||||
num_den_reordered_graphs,
|
||||
dense_fsa_vec,
|
||||
output_beam=beam_size,
|
||||
a_to_b_map=a_to_b_map,
|
||||
)
|
||||
|
||||
num_den_tot_scores = num_den_lats.get_tot_scores(
|
||||
log_semiring=True, use_double_scores=True
|
||||
)
|
||||
|
||||
num_tot_scores = num_den_tot_scores[::2]
|
||||
den_tot_scores = num_den_tot_scores[1::2]
|
||||
|
||||
tot_scores = num_tot_scores - den_scale * den_tot_scores
|
||||
loss = -1 * tot_scores.sum()
|
||||
return loss
|
||||
|
||||
|
||||
def _compute_mmi_loss_exact_non_optimized(
|
||||
dense_fsa_vec: k2.DenseFsaVec,
|
||||
texts: List[str],
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
den_scale: float = 1.0,
|
||||
beam_size: float = 8.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
See :func:`_compute_mmi_loss_exact_optimized` for the meaning
|
||||
of the arguments.
|
||||
|
||||
It's more readable, though it invokes k2.intersect_dense twice.
|
||||
|
||||
Note:
|
||||
It uses less memory at the cost of speed. It is slower.
|
||||
"""
|
||||
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
|
||||
|
||||
# TODO: pass output_beam as function argument
|
||||
num_lats = k2.intersect_dense(
|
||||
num_graphs, dense_fsa_vec, output_beam=beam_size
|
||||
)
|
||||
den_lats = k2.intersect_dense(
|
||||
den_graphs, dense_fsa_vec, output_beam=beam_size
|
||||
)
|
||||
|
||||
num_tot_scores = num_lats.get_tot_scores(
|
||||
log_semiring=True, use_double_scores=True
|
||||
)
|
||||
|
||||
den_tot_scores = den_lats.get_tot_scores(
|
||||
log_semiring=True, use_double_scores=True
|
||||
)
|
||||
|
||||
tot_scores = num_tot_scores - den_scale * den_tot_scores
|
||||
|
||||
loss = -1 * tot_scores.sum()
|
||||
return loss
|
||||
|
||||
|
||||
def _compute_mmi_loss_pruned(
|
||||
dense_fsa_vec: k2.DenseFsaVec,
|
||||
texts: List[str],
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
den_scale: float = 1.0,
|
||||
beam_size: float = 8.0,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
See :func:`_compute_mmi_loss_exact_optimized` for the meaning
|
||||
of the arguments.
|
||||
|
||||
`pruned` means it uses k2.intersect_dense_pruned
|
||||
|
||||
Note:
|
||||
It uses the least amount of memory, but the loss is not exact due
|
||||
to pruning.
|
||||
"""
|
||||
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=False)
|
||||
|
||||
num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=10.0)
|
||||
|
||||
# the values for search_beam/output_beam/min_active_states/max_active_states
|
||||
# are not tuned. You may want to tune them.
|
||||
den_lats = k2.intersect_dense_pruned(
|
||||
den_graphs,
|
||||
dense_fsa_vec,
|
||||
search_beam=20.0,
|
||||
output_beam=beam_size,
|
||||
min_active_states=30,
|
||||
max_active_states=10000,
|
||||
)
|
||||
|
||||
num_tot_scores = num_lats.get_tot_scores(
|
||||
log_semiring=True, use_double_scores=True
|
||||
)
|
||||
|
||||
den_tot_scores = den_lats.get_tot_scores(
|
||||
log_semiring=True, use_double_scores=True
|
||||
)
|
||||
|
||||
tot_scores = num_tot_scores - den_scale * den_tot_scores
|
||||
|
||||
loss = -1 * tot_scores.sum()
|
||||
return loss
|
||||
|
||||
|
||||
class LFMMILoss(nn.Module):
|
||||
"""
|
||||
Computes Lattice-Free Maximum Mutual Information (LFMMI) loss.
|
||||
|
||||
TODO: more detailed description
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
graph_compiler: MmiTrainingGraphCompiler,
|
||||
use_pruned_intersect: bool = False,
|
||||
den_scale: float = 1.0,
|
||||
beam_size: float = 8.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.graph_compiler = graph_compiler
|
||||
self.den_scale = den_scale
|
||||
self.use_pruned_intersect = use_pruned_intersect
|
||||
self.beam_size = beam_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
dense_fsa_vec: k2.DenseFsaVec,
|
||||
texts: List[str],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
dense_fsa_vec:
|
||||
It contains the neural network output.
|
||||
texts:
|
||||
A list of strings. Each string contains space(s) separated words.
|
||||
Returns:
|
||||
Return a scalar loss. It is the sum over utterances in a batch,
|
||||
without normalization.
|
||||
"""
|
||||
if self.use_pruned_intersect:
|
||||
func = _compute_mmi_loss_pruned
|
||||
else:
|
||||
func = _compute_mmi_loss_exact_non_optimized
|
||||
# func = _compute_mmi_loss_exact_optimized
|
||||
|
||||
return func(
|
||||
dense_fsa_vec=dense_fsa_vec,
|
||||
texts=texts,
|
||||
graph_compiler=self.graph_compiler,
|
||||
den_scale=self.den_scale,
|
||||
beam_size=self.beam_size,
|
||||
)
|
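A hedged sketch of how `LFMMILoss` might be wired up during training; the shapes, supervision layout and lang directory below are illustrative only and are not taken from this diff:

```
import k2
import torch

from icefall.mmi import LFMMILoss
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler

# Sketch: requires a lang dir prepared by prepare.sh
# (tokens.txt, words.txt, uniq_lexicon.txt, P.fst.txt).
graph_compiler = MmiTrainingGraphCompiler("data/lang_bpe_500", device="cpu")
mmi_loss = LFMMILoss(graph_compiler=graph_compiler, use_pruned_intersect=False)

# nnet_output: (N, T, C) log-probs; supervision_segments rows are
# [sequence_index, start_frame, num_frames] -- values here are made up.
nnet_output = torch.randn(2, 100, 500).log_softmax(dim=-1)
supervision_segments = torch.tensor(
    [[0, 0, 100], [1, 0, 80]], dtype=torch.int32
)
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)

loss = mmi_loss(dense_fsa_vec=dense_fsa_vec, texts=["HELLO WORLD", "FOO BAR"])
```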
221
icefall/mmi_graph_compiler.py
Normal file
@ -0,0 +1,221 @@
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Iterable, List, Tuple, Union
|
||||
|
||||
import k2
|
||||
import torch
|
||||
|
||||
from icefall.lexicon import UniqLexicon
|
||||
|
||||
|
||||
class MmiTrainingGraphCompiler(object):
|
||||
def __init__(
|
||||
self,
|
||||
lang_dir: Path,
|
||||
uniq_filename: str = "uniq_lexicon.txt",
|
||||
device: Union[str, torch.device] = "cpu",
|
||||
oov: str = "<UNK>",
|
||||
sos_id: int = 1,
|
||||
eos_id: int = 1,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
lang_dir:
|
||||
Path to the lang directory. It is expected to contain the
|
||||
following files::
|
||||
|
||||
- tokens.txt
|
||||
- words.txt
|
||||
- P.fst.txt
|
||||
|
||||
The above files are generated by the script `prepare.sh`. You
|
||||
should have run it before running the training code.
|
||||
uniq_filename:
|
||||
File name to the lexicon in which every word has exactly one
|
||||
pronunciation. We assume this file is inside the given `lang_dir`.
|
||||
|
||||
device:
|
||||
It indicates CPU or CUDA.
|
||||
oov:
|
||||
Out of vocabulary word. When a word in the transcript
|
||||
does not exist in the lexicon, it is replaced with `oov`.
|
||||
"""
|
||||
self.lang_dir = Path(lang_dir)
|
||||
self.lexicon = UniqLexicon(lang_dir, uniq_filename=uniq_filename)
|
||||
self.device = torch.device(device)
|
||||
|
||||
self.L_inv = self.lexicon.L_inv.to(self.device)
|
||||
|
||||
self.oov_id = self.lexicon.word_table[oov]
|
||||
self.sos_id = sos_id
|
||||
self.eos_id = eos_id
|
||||
|
||||
self.build_ctc_topo_P()
|
||||
|
||||
def build_ctc_topo_P(self):
|
||||
"""Built ctc_topo_P, the composition result of
|
||||
ctc_topo and P, where P is a pre-trained bigram
|
||||
word piece LM.
|
||||
"""
|
||||
# Note: there is no need to save a pre-compiled P and ctc_topo
|
||||
# as it is very fast to generate them.
|
||||
logging.info(f"Loading P from {self.lang_dir/'P.fst.txt'}")
|
||||
with open(self.lang_dir / "P.fst.txt") as f:
|
||||
# P is not an acceptor because there is
|
||||
# a back-off state, whose incoming arcs
|
||||
# have label #0 and aux_label 0 (i.e., <eps>).
|
||||
P = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
|
||||
first_token_disambig_id = self.lexicon.token_table["#0"]
|
||||
|
||||
# P.aux_labels is not needed in later computations, so
|
||||
# remove it here.
|
||||
del P.aux_labels
|
||||
# CAUTION: The following line is crucial.
|
||||
# Arcs entering the back-off state have label equal to #0.
|
||||
# We have to change it to 0 here.
|
||||
P.labels[P.labels >= first_token_disambig_id] = 0
|
||||
|
||||
P = k2.remove_epsilon(P)
|
||||
P = k2.arc_sort(P)
|
||||
P = P.to(self.device)
|
||||
# Add epsilon self-loops to P because we want the
|
||||
# following operation "k2.intersect" to run on GPU.
|
||||
P_with_self_loops = k2.add_epsilon_self_loops(P)
|
||||
|
||||
max_token_id = max(self.lexicon.tokens)
|
||||
logging.info(
|
||||
f"Building ctc_topo (modified=False). max_token_id: {max_token_id}"
|
||||
)
|
||||
ctc_topo = k2.ctc_topo(max_token_id, modified=False, device=self.device)
|
||||
|
||||
ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
|
||||
|
||||
logging.info("Building ctc_topo_P")
|
||||
ctc_topo_P = k2.intersect(
|
||||
ctc_topo_inv, P_with_self_loops, treat_epsilons_specially=False
|
||||
).invert()
|
||||
|
||||
self.ctc_topo_P = k2.arc_sort(ctc_topo_P)
|
||||
logging.info(f"ctc_topo_P num_arcs: {self.ctc_topo_P.num_arcs}")
|
||||
|
||||
def compile(
|
||||
self, texts: Iterable[str], replicate_den: bool = True
|
||||
) -> Tuple[k2.Fsa, k2.Fsa]:
|
||||
"""Create numerator and denominator graphs from transcripts
|
||||
and the bigram phone LM.
|
||||
|
||||
Args:
|
||||
texts:
|
||||
A list of transcripts. Within a transcript, words are
|
||||
separated by spaces. An example `texts` is given below::
|
||||
|
||||
["Hello icefall", "LF-MMI training with icefall using k2"]
|
||||
|
||||
replicate_den:
|
||||
If True, the returned den_graph is replicated to match the number
|
||||
of FSAs in the returned num_graph; if False, the returned den_graph
|
||||
contains only a single FSA
|
||||
Returns:
|
||||
A tuple (num_graph, den_graph), where
|
||||
|
||||
- `num_graph` is the numerator graph. It is an FsaVec with
|
||||
shape `(len(texts), None, None)`.
|
||||
|
||||
- `den_graph` is the denominator graph. It is an FsaVec
|
||||
with the same shape of the `num_graph` if replicate_den is
|
||||
True; otherwise, it is an FsaVec containing only a single FSA.
|
||||
"""
|
||||
transcript_fsa = self.build_transcript_fsa(texts)
|
||||
|
||||
# remove word IDs from transcript_fsa since it is not needed
|
||||
del transcript_fsa.aux_labels
|
||||
# NOTE: You can comment out the above statement
|
||||
# if you want to run test/test_mmi_graph_compiler.py
|
||||
|
||||
transcript_fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(
|
||||
transcript_fsa
|
||||
)
|
||||
|
||||
transcript_fsa_with_self_loops = k2.arc_sort(
|
||||
transcript_fsa_with_self_loops
|
||||
)
|
||||
|
||||
num = k2.compose(
|
||||
self.ctc_topo_P,
|
||||
transcript_fsa_with_self_loops,
|
||||
treat_epsilons_specially=False,
|
||||
)
|
||||
|
||||
# CAUTION: Due to the presence of P,
|
||||
# the resulting `num` may not be connected
|
||||
num = k2.connect(num)
|
||||
|
||||
num = k2.arc_sort(num)
|
||||
|
||||
ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P])
|
||||
if replicate_den:
|
||||
indexes = torch.zeros(
|
||||
len(texts), dtype=torch.int32, device=self.device
|
||||
)
|
||||
den = k2.index_fsa(ctc_topo_P_vec, indexes)
|
||||
else:
|
||||
den = ctc_topo_P_vec
|
||||
|
||||
return num, den
|
||||
|
||||
def build_transcript_fsa(self, texts: List[str]) -> k2.Fsa:
|
||||
"""Convert transcripts to an FsaVec with the help of a lexicon
|
||||
and word symbol table.
|
||||
|
||||
Args:
|
||||
texts:
|
||||
Each element is a transcript containing words separated by space(s).
|
||||
For instance, it may be 'HELLO icefall', which contains
|
||||
two words.
|
||||
|
||||
Returns:
|
||||
Return an FST (FsaVec) corresponding to the transcript.
|
||||
Its `labels` is token IDs and `aux_labels` is word IDs.
|
||||
"""
|
||||
word_ids_list = []
|
||||
for text in texts:
|
||||
word_ids = []
|
||||
for word in text.split():
|
||||
if word in self.lexicon.word_table:
|
||||
word_ids.append(self.lexicon.word_table[word])
|
||||
else:
|
||||
word_ids.append(self.oov_id)
|
||||
word_ids_list.append(word_ids)
|
||||
|
||||
fsa = k2.linear_fsa(word_ids_list, self.device)
|
||||
fsa = k2.add_epsilon_self_loops(fsa)
|
||||
|
||||
# The reason to use `invert_()` at the end is as follows:
|
||||
#
|
||||
# (1) The `labels` of L_inv is word IDs and `aux_labels` is token IDs
|
||||
# (2) `fsa.labels` is word IDs
|
||||
# (3) after intersection, the `labels` is still word IDs
|
||||
# (4) after `invert_()`, the `labels` is token IDs
|
||||
# and `aux_labels` is word IDs
|
||||
transcript_fsa = k2.intersect(
|
||||
self.L_inv, fsa, treat_epsilons_specially=False
|
||||
).invert_()
|
||||
transcript_fsa = k2.arc_sort(transcript_fsa)
|
||||
return transcript_fsa
|
||||
|
||||
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
|
||||
"""Convert a list of texts to a list-of-list of piece IDs.
|
||||
|
||||
Args:
|
||||
texts:
|
||||
It is a list of strings. Each string consists of space(s)
|
||||
separated words. An example containing two strings is given below:
|
||||
|
||||
['HELLO ICEFALL', 'HELLO k2']
|
||||
We assume it contains no OOVs. Otherwise, it will raise an
|
||||
exception.
|
||||
Returns:
|
||||
Return a list-of-list of token IDs.
|
||||
"""
|
||||
return self.lexicon.texts_to_token_ids(texts).tolist()
|
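To make the numerator/denominator graph shapes concrete, a hedged sketch of `compile()`; again the lang directory is an example and must already be prepared:

```
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler

graph_compiler = MmiTrainingGraphCompiler("data/lang_bpe_500", device="cpu")

num_graphs, den_graphs = graph_compiler.compile(
    ["HELLO WORLD", "HELLO"], replicate_den=True
)

print(num_graphs.shape[0])  # 2: one numerator graph per transcript
print(den_graphs.shape[0])  # 2 with replicate_den=True, 1 otherwise
```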
377
icefall/shared/make_kn_lm.py
Executable file
@ -0,0 +1,377 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2016 Johns Hopkins University (Author: Daniel Povey)
|
||||
# 2018 Ruizhe Huang
|
||||
# Apache 2.0.
|
||||
|
||||
# This is an implementation of computing Kneser-Ney smoothed language model
|
||||
# in the same way as srilm. This is a back-off, unmodified version of
|
||||
# Kneser-Ney smoothing, which produces the same results as the following
|
||||
# command (as an example) of srilm:
|
||||
#
|
||||
# $ ngram-count -order 4 -kn-modify-counts-at-end -ukndiscount -gt1min 0 -gt2min 0 -gt3min 0 -gt4min 0 \
|
||||
# -text corpus.txt -lm lm.arpa
|
||||
#
|
||||
# The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py
|
||||
# The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html
|
||||
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
import io
|
||||
import math
|
||||
import argparse
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description="""
|
||||
Generate a Kneser-Ney language model in ARPA format. By default,
|
||||
it will read the corpus from standard input, and output to standard output.
|
||||
""")
|
||||
parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram")
|
||||
parser.add_argument("-text", type=str, default=None, help="Path to the corpus file")
|
||||
parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models")
|
||||
parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level")
|
||||
args = parser.parse_args()
|
||||
|
||||
default_encoding = "latin-1" # For encoding-agnostic scripts, we assume byte stream as input.
|
||||
# Need to be very careful about the use of strip() and split()
|
||||
# in this case, because there is a latin-1 whitespace character
|
||||
# (nbsp) which is part of the unicode encoding range.
|
||||
# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717
|
||||
strip_chars = " \t\r\n"
|
||||
whitespace = re.compile("[ \t]+")
|
||||
|
||||
|
||||
class CountsForHistory:
|
||||
# This class (which is more like a struct) stores the counts seen in a
|
||||
# particular history-state. It is used inside class NgramCounts.
|
||||
# It really does the job of a dict from int to float, but it also
|
||||
# keeps track of the total count.
|
||||
def __init__(self):
|
||||
# The 'lambda: defaultdict(float)' is an anonymous function taking no
|
||||
# arguments that returns a new defaultdict(float).
|
||||
self.word_to_count = defaultdict(int)
|
||||
self.word_to_context = defaultdict(set) # using a set to count the number of unique contexts
|
||||
self.word_to_f = dict() # discounted probability
|
||||
self.word_to_bow = dict() # back-off weight
|
||||
self.total_count = 0
|
||||
|
||||
def words(self):
|
||||
return self.word_to_count.keys()
|
||||
|
||||
def __str__(self):
|
||||
# e.g. returns ' total=12: 3->4, 4->6, -1->2'
|
||||
return ' total={0}: {1}'.format(
|
||||
str(self.total_count),
|
||||
', '.join(['{0} -> {1}'.format(word, count)
|
||||
for word, count in self.word_to_count.items()]))
|
||||
|
||||
def add_count(self, predicted_word, context_word, count):
|
||||
assert count >= 0
|
||||
|
||||
self.total_count += count
|
||||
self.word_to_count[predicted_word] += count
|
||||
if context_word is not None:
|
||||
self.word_to_context[predicted_word].add(context_word)
|
||||
|
||||
|
||||
class NgramCounts:
|
||||
# A note on data-structure. Firstly, all words are represented as
|
||||
# integers. We store n-gram counts as an array, indexed by (history-length
|
||||
# == n-gram order minus one) (note: python calls arrays "lists") of dicts
|
||||
# from histories to counts, where histories are arrays of integers and
|
||||
# "counts" are dicts from integer to float. For instance, when
|
||||
# accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd
|
||||
# do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an
|
||||
# array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict.
|
||||
def __init__(self, ngram_order, bos_symbol='<s>', eos_symbol='</s>'):
|
||||
assert ngram_order >= 2
|
||||
|
||||
self.ngram_order = ngram_order
|
||||
self.bos_symbol = bos_symbol
|
||||
self.eos_symbol = eos_symbol
|
||||
|
||||
self.counts = []
|
||||
for n in range(ngram_order):
|
||||
self.counts.append(defaultdict(lambda: CountsForHistory()))
|
||||
|
||||
self.d = [] # list of discounting factor for each order of ngram
|
||||
|
||||
# adds a raw count (called while processing input data).
|
||||
# Suppose we see the sequence '6 7 8 9' and ngram_order=4, 'history'
|
||||
# would be (6,7,8) and 'predicted_word' would be 9; 'count' would be
|
||||
# 1.
|
||||
def add_count(self, history, predicted_word, context_word, count):
|
||||
self.counts[len(history)][history].add_count(predicted_word, context_word, count)
|
||||
|
||||
# 'line' is a string containing a sequence of integer word-ids.
|
||||
# This function adds the un-smoothed counts from this line of text.
|
||||
def add_raw_counts_from_line(self, line):
|
||||
if line == '':
|
||||
words = [self.bos_symbol, self.eos_symbol]
|
||||
else:
|
||||
words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol]
|
||||
|
||||
for i in range(len(words)):
|
||||
for n in range(1, self.ngram_order+1):
|
||||
if i + n > len(words):
|
||||
break
|
||||
ngram = words[i: i + n]
|
||||
predicted_word = ngram[-1]
|
||||
history = tuple(ngram[: -1])
|
||||
if i == 0 or n == self.ngram_order:
|
||||
context_word = None
|
||||
else:
|
||||
context_word = words[i-1]
|
||||
|
||||
self.add_count(history, predicted_word, context_word, 1)
|
||||
|
||||
def add_raw_counts_from_standard_input(self):
|
||||
lines_processed = 0
|
||||
infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding) # byte stream as input
|
||||
for line in infile:
|
||||
line = line.strip(strip_chars)
|
||||
self.add_raw_counts_from_line(line)
|
||||
lines_processed += 1
|
||||
if lines_processed == 0 or args.verbose > 0:
|
||||
print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
|
||||
|
||||
def add_raw_counts_from_file(self, filename):
|
||||
lines_processed = 0
|
||||
with open(filename, encoding=default_encoding) as fp:
|
||||
for line in fp:
|
||||
line = line.strip(strip_chars)
|
||||
self.add_raw_counts_from_line(line)
|
||||
lines_processed += 1
|
||||
if lines_processed == 0 or args.verbose > 0:
|
||||
print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr)
|
||||
|
||||
def cal_discounting_constants(self):
|
||||
# For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N),
|
||||
# where n1_N is the number of unique N-grams with count = 1 (counts-of-counts).
|
||||
# This constant is used similarly to absolute discounting.
|
||||
# Return value: d is a list of floats, where d[N+1] = D_N
|
||||
|
||||
self.d = [0] # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0
|
||||
# This is a special case: as we currently assumed having seen all vocabularies in the dictionary,
|
||||
# but perhaps this is not the case for some other scenarios.
|
||||
for n in range(1, self.ngram_order):
|
||||
this_order_counts = self.counts[n]
|
||||
n1 = 0
|
||||
n2 = 0
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
stat = Counter(counts_for_hist.word_to_count.values())
|
||||
n1 += stat[1]
|
||||
n2 += stat[2]
|
||||
assert n1 + 2 * n2 > 0
|
||||
self.d.append(n1 * 1.0 / (n1 + 2 * n2))
|
||||
|
||||
def cal_f(self):
|
||||
# f(a_z) is a probability distribution of word sequence a_z.
|
||||
# Typically f(a_z) is discounted to be less than the ML estimate so we have
|
||||
# some leftover probability for the z words unseen in the context (a_).
|
||||
#
|
||||
# f(a_z) = (c(a_z) - D0) / c(a_) ;; for highest order N-grams
|
||||
# f(_z) = (n(*_z) - D1) / n(*_*) ;; for lower order N-grams
|
||||
|
||||
# highest order N-grams
|
||||
n = self.ngram_order - 1
|
||||
this_order_counts = self.counts[n]
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w, c in counts_for_hist.word_to_count.items():
|
||||
counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
|
||||
|
||||
# lower order N-grams
|
||||
for n in range(0, self.ngram_order - 1):
|
||||
this_order_counts = self.counts[n]
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
|
||||
n_star_star = 0
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
n_star_star += len(counts_for_hist.word_to_context[w])
|
||||
|
||||
if n_star_star != 0:
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
n_star_z = len(counts_for_hist.word_to_context[w])
|
||||
counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star
|
||||
else: # patterns begin with <s>, they do not have "modified count", so use raw count instead
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
n_star_z = counts_for_hist.word_to_count[w]
|
||||
counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count
|
||||
|
||||
def cal_bow(self):
|
||||
# Backoff weights are only necessary for ngrams which form a prefix of a longer ngram.
|
||||
# Thus, two sorts of ngrams do not have a bow:
|
||||
# 1) highest order ngram
|
||||
# 2) ngrams ending in </s>
|
||||
#
|
||||
# bow(a_) = (1 - Sum_Z1 f(a_z)) / (1 - Sum_Z1 f(_z))
|
||||
# Note that Z1 is the set of all words with c(a_z) > 0
|
||||
|
||||
# highest order N-grams
|
||||
n = self.ngram_order - 1
|
||||
this_order_counts = self.counts[n]
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
counts_for_hist.word_to_bow[w] = None
|
||||
|
||||
# lower order N-grams
|
||||
for n in range(0, self.ngram_order - 1):
|
||||
this_order_counts = self.counts[n]
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
if w == self.eos_symbol:
|
||||
counts_for_hist.word_to_bow[w] = None
|
||||
else:
|
||||
a_ = hist + (w,)
|
||||
|
||||
assert len(a_) < self.ngram_order
|
||||
assert a_ in self.counts[len(a_)].keys()
|
||||
|
||||
a_counts_for_hist = self.counts[len(a_)][a_]
|
||||
|
||||
sum_z1_f_a_z = 0
|
||||
for u in a_counts_for_hist.word_to_count.keys():
|
||||
sum_z1_f_a_z += a_counts_for_hist.word_to_f[u]
|
||||
|
||||
sum_z1_f_z = 0
|
||||
_ = a_[1:]
|
||||
_counts_for_hist = self.counts[len(_)][_]
|
||||
for u in a_counts_for_hist.word_to_count.keys(): # Should be careful here: what is Z1
|
||||
sum_z1_f_z += _counts_for_hist.word_to_f[u]
|
||||
|
||||
counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z)
|
||||
|
||||
def print_raw_counts(self, info_string):
|
||||
# these are useful for debug.
|
||||
print(info_string)
|
||||
res = []
|
||||
for this_order_counts in self.counts:
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
ngram = " ".join(hist) + " " + w
|
||||
ngram = ngram.strip(strip_chars)
|
||||
|
||||
res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]))
|
||||
res.sort(reverse=True)
|
||||
for r in res:
|
||||
print(r)
|
||||
|
||||
def print_modified_counts(self, info_string):
|
||||
# these are useful for debug.
|
||||
print(info_string)
|
||||
res = []
|
||||
for this_order_counts in self.counts:
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
ngram = " ".join(hist) + " " + w
|
||||
ngram = ngram.strip(strip_chars)
|
||||
|
||||
modified_count = len(counts_for_hist.word_to_context[w])
|
||||
raw_count = counts_for_hist.word_to_count[w]
|
||||
|
||||
if modified_count == 0:
|
||||
res.append("{0}\t{1}".format(ngram, raw_count))
|
||||
else:
|
||||
res.append("{0}\t{1}".format(ngram, modified_count))
|
||||
res.sort(reverse=True)
|
||||
for r in res:
|
||||
print(r)
|
||||
|
||||
def print_f(self, info_string):
|
||||
# these are useful for debug.
|
||||
print(info_string)
|
||||
res = []
|
||||
for this_order_counts in self.counts:
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
ngram = " ".join(hist) + " " + w
|
||||
ngram = ngram.strip(strip_chars)
|
||||
|
||||
f = counts_for_hist.word_to_f[w]
|
||||
if f == 0: # f(<s>) is always 0
|
||||
f = 1e-99
|
||||
|
||||
res.append("{0}\t{1}".format(ngram, math.log(f, 10)))
|
||||
res.sort(reverse=True)
|
||||
for r in res:
|
||||
print(r)
|
||||
|
||||
def print_f_and_bow(self, info_string):
|
||||
# these are useful for debug.
|
||||
print(info_string)
|
||||
res = []
|
||||
for this_order_counts in self.counts:
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for w in counts_for_hist.word_to_count.keys():
|
||||
ngram = " ".join(hist) + " " + w
|
||||
ngram = ngram.strip(strip_chars)
|
||||
|
||||
f = counts_for_hist.word_to_f[w]
|
||||
if f == 0: # f(<s>) is always 0
|
||||
f = 1e-99
|
||||
|
||||
bow = counts_for_hist.word_to_bow[w]
|
||||
if bow is None:
|
||||
res.append("{1}\t{0}".format(ngram, math.log(f, 10)))
|
||||
else:
|
||||
res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10)))
|
||||
res.sort(reverse=True)
|
||||
for r in res:
|
||||
print(r)
|
||||
|
||||
def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')):
|
||||
# print as ARPA format.
|
||||
|
||||
print('\\data\\', file=fout)
|
||||
for hist_len in range(self.ngram_order):
|
||||
# print the number of n-grams.
|
||||
print('ngram {0}={1}'.format(
|
||||
hist_len + 1,
|
||||
sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])),
|
||||
file=fout
|
||||
)
|
||||
|
||||
print('', file=fout)
|
||||
|
||||
for hist_len in range(self.ngram_order):
|
||||
print('\\{0}-grams:'.format(hist_len + 1), file=fout)
|
||||
|
||||
this_order_counts = self.counts[hist_len]
|
||||
for hist, counts_for_hist in this_order_counts.items():
|
||||
for word in counts_for_hist.word_to_count.keys():
|
||||
ngram = hist + (word,)
|
||||
prob = counts_for_hist.word_to_f[word]
|
||||
bow = counts_for_hist.word_to_bow[word]
|
||||
|
||||
if prob == 0: # f(<s>) is always 0
|
||||
prob = 1e-99
|
||||
|
||||
line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram))
|
||||
if bow is not None:
|
||||
line += '\t{0}'.format('%.7f' % math.log10(bow))
|
||||
print(line, file=fout)
|
||||
print('', file=fout)
|
||||
print('\\end\\', file=fout)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
ngram_counts = NgramCounts(args.ngram_order)
|
||||
|
||||
if args.text is None:
|
||||
ngram_counts.add_raw_counts_from_standard_input()
|
||||
else:
|
||||
assert os.path.isfile(args.text)
|
||||
ngram_counts.add_raw_counts_from_file(args.text)
|
||||
|
||||
ngram_counts.cal_discounting_constants()
|
||||
ngram_counts.cal_f()
|
||||
ngram_counts.cal_bow()
|
||||
|
||||
if args.lm is None:
|
||||
ngram_counts.print_as_arpa()
|
||||
else:
|
||||
with open(args.lm, 'w', encoding=default_encoding) as f:
|
||||
ngram_counts.print_as_arpa(fout=f)
|
@ -8,4 +8,5 @@ exclude = '''
|
||||
\.git
|
||||
| \.github
|
||||
)/
|
||||
| make_kn_lm.py
|
||||
'''
|
||||
|
223
test/test_ali.py
Executable file
@ -0,0 +1,223 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Run this file using one of the following two ways:
|
||||
# (1) python3 ./test/test_ali.py
|
||||
# (2) pytest ./test/test_ali.py
|
||||
|
||||
# The purpose of this file is to show that if we build a mask
|
||||
# from alignments and add it to a randomly generated nnet_output,
|
||||
# we can decode the correct transcript.
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import torch
|
||||
from lhotse import load_manifest
|
||||
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from icefall.ali import (
|
||||
convert_alignments_to_tensor,
|
||||
load_alignments,
|
||||
lookup_alignments,
|
||||
)
|
||||
from icefall.decode import get_lattice, one_best_decoding
|
||||
from icefall.lexicon import Lexicon
|
||||
from icefall.utils import get_texts
|
||||
|
||||
ICEFALL_DIR = Path(__file__).resolve().parent.parent
|
||||
egs_dir = ICEFALL_DIR / "egs/librispeech/ASR"
|
||||
lang_dir = egs_dir / "data/lang_bpe_500"
|
||||
# cut_json = egs_dir / "data/fbank/cuts_train-clean-100.json.gz"
|
||||
# cut_json = egs_dir / "data/fbank/cuts_train-clean-360.json.gz"
|
||||
# cut_json = egs_dir / "data/fbank/cuts_train-other-500.json.gz"
|
||||
# ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/train-960.pt"
|
||||
|
||||
cut_json = egs_dir / "data/fbank/cuts_test-clean.json.gz"
|
||||
ali_filename = ICEFALL_DIR / "egs/librispeech/ASR/data/ali_500/test_clean.pt"
|
||||
|
||||
|
||||
def data_exists():
|
||||
return ali_filename.exists() and cut_json.exists() and lang_dir.exists()
|
||||
|
||||
|
||||
def get_dataloader():
|
||||
cuts_train = load_manifest(cut_json)
|
||||
cuts_train = cuts_train.with_features_path_prefix(egs_dir)
|
||||
train_sampler = SingleCutSampler(
|
||||
cuts_train,
|
||||
max_duration=40,
|
||||
shuffle=False,
|
||||
)
|
||||
|
||||
train = K2SpeechRecognitionDataset(return_cuts=True)
|
||||
|
||||
train_dl = DataLoader(
|
||||
train,
|
||||
sampler=train_sampler,
|
||||
batch_size=None,
|
||||
num_workers=1,
|
||||
persistent_workers=False,
|
||||
)
|
||||
return train_dl
|
||||
|
||||
|
||||
def test_one_hot():
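# Pad two token-id sequences to the same length, one-hot encode them, and turn
# the result into an additive mask: the class given by each token id keeps a
# score of 0 while every other class is pushed down by -10. This mirrors how
# the alignment mask is used to bias nnet_output in test() below.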
|
||||
a = [1, 3, 2]
|
||||
b = [1, 0, 4, 2]
|
||||
c = [torch.tensor(a), torch.tensor(b)]
|
||||
d = pad_sequence(c, batch_first=True, padding_value=0)
|
||||
f = torch.nn.functional.one_hot(d, num_classes=5)
|
||||
e = (1 - f) * -10.0
|
||||
expected = torch.tensor(
|
||||
[
|
||||
[
|
||||
[-10, 0, -10, -10, -10],
|
||||
[-10, -10, -10, 0, -10],
|
||||
[-10, -10, 0, -10, -10],
|
||||
[0, -10, -10, -10, -10],
|
||||
],
|
||||
[
|
||||
[-10, 0, -10, -10, -10],
|
||||
[0, -10, -10, -10, -10],
|
||||
[-10, -10, -10, -10, 0],
|
||||
[-10, -10, 0, -10, -10],
|
||||
],
|
||||
]
|
||||
).to(e.dtype)
|
||||
assert torch.all(torch.eq(e, expected))
|
||||
|
||||
|
||||
def test():
|
||||
"""
|
||||
The purpose of this test is to show that we can use pre-computed
|
||||
alignments to construct a mask, add it to a randomly generated
|
||||
nnet_output, and decode the correct transcript from the resulting
|
||||
nnet_output.
|
||||
"""
|
||||
if not data_exists():
|
||||
return
|
||||
device = torch.device("cpu")
|
||||
if torch.cuda.is_available():
|
||||
device = torch.device("cuda", 0)
|
||||
dl = get_dataloader()
|
||||
|
||||
subsampling_factor, ali = load_alignments(ali_filename)
|
||||
ali = convert_alignments_to_tensor(ali, device=device)
|
||||
|
||||
lexicon = Lexicon(lang_dir)
|
||||
max_token_id = max(lexicon.tokens)
|
||||
num_classes = max_token_id + 1 # +1 for the blank
|
||||
word_table = lexicon.word_table
|
||||
|
||||
HLG = k2.Fsa.from_dict(
|
||||
torch.load(f"{lang_dir}/HLG.pt", map_location=device)
|
||||
)
|
||||
|
||||
for batch in dl:
|
||||
features = batch["inputs"]
|
||||
supervisions = batch["supervisions"]
|
||||
N = features.shape[0]
|
||||
T = features.shape[1] // subsampling_factor
|
||||
nnet_output = (
|
||||
torch.rand(N, T, num_classes, dtype=torch.float32, device=device)
|
||||
.softmax(dim=-1)
|
||||
.log()
|
||||
)
|
||||
cut_ids = [cut.id for cut in supervisions["cut"]]
|
||||
mask = lookup_alignments(
|
||||
cut_ids=cut_ids, alignments=ali, num_classes=num_classes
|
||||
)
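# "mask" is an (N, T', num_classes) tensor built from the pre-computed
# alignments; adding a scaled copy of it below biases each frame's scores
# towards the aligned token, which is what lets a random nnet_output decode
# to the reference transcript.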
|
||||
min_len = min(nnet_output.shape[1], mask.shape[1])
|
||||
ali_model_scale = 0.8
|
||||
|
||||
nnet_output[:, :min_len, :] += ali_model_scale * mask[:, :min_len, :]
|
||||
|
||||
supervisions = batch["supervisions"]
|
||||
|
||||
supervision_segments = torch.stack(
|
||||
(
|
||||
supervisions["sequence_idx"],
|
||||
supervisions["start_frame"] // subsampling_factor,
|
||||
supervisions["num_frames"] // subsampling_factor,
|
||||
),
|
||||
1,
|
||||
).to(torch.int32)
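# Each row of supervision_segments is (sequence_idx, start_frame, num_frames),
# with the frame columns converted to subsampled (output) frames.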
|
||||
|
||||
lattice = get_lattice(
|
||||
nnet_output=nnet_output,
|
||||
decoding_graph=HLG,
|
||||
supervision_segments=supervision_segments,
|
||||
search_beam=20,
|
||||
output_beam=8,
|
||||
min_active_states=30,
|
||||
max_active_states=10000,
|
||||
subsampling_factor=subsampling_factor,
|
||||
)
|
||||
|
||||
best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
|
||||
hyps = get_texts(best_path)
|
||||
hyps = [[word_table[i] for i in ids] for ids in hyps]
|
||||
hyps = [" ".join(s) for s in hyps]
|
||||
print(hyps)
|
||||
print(supervisions["text"])
|
||||
break
|
||||
|
||||
|
||||
def show_cut_ids():
|
||||
# The purpose of this function is to check that
|
||||
# for each utterance in the training set, there is
|
||||
# a corresponding alignment.
|
||||
#
|
||||
# After generating a1.txt and b1.txt
|
||||
# You can use
|
||||
# wc -l a1.txt b1.txt
|
||||
# which should show the same number of lines.
|
||||
#
|
||||
# cat a1.txt | sort | uniq > a11.txt
|
||||
# cat b1.txt | sort | uniq > b11.txt
|
||||
#
|
||||
# md5sum a11.txt b11.txt
|
||||
# which should show identical hashes for the two files
|
||||
#
|
||||
# diff a11.txt b11.txt
|
||||
# should print nothing
|
||||
|
||||
subsampling_factor, ali = load_alignments(ali_filename)
|
||||
with open("a1.txt", "w") as f:
|
||||
for key in ali:
|
||||
f.write(f"{key}\n")
|
||||
|
||||
# dl = get_dataloader()
|
||||
cuts_train = (
|
||||
load_manifest(egs_dir / "data/fbank/cuts_train-clean-100.json.gz")
|
||||
+ load_manifest(egs_dir / "data/fbank/cuts_train-clean-360.json.gz")
|
||||
+ load_manifest(egs_dir / "data/fbank/cuts_train-other-500.json.gz")
|
||||
)
|
||||
|
||||
ans = []
|
||||
for cut in cuts_train:
|
||||
ans.append(cut.id)
|
||||
with open("b1.txt", "w") as f:
|
||||
for line in ans:
|
||||
f.write(f"{line}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test()
|
@ -19,20 +19,21 @@
|
||||
from pathlib import Path
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.lexicon import BpeLexicon
|
||||
from icefall.lexicon import UniqLexicon
|
||||
|
||||
ICEFALL_DIR = Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
def test():
|
||||
lang_dir = Path("data/lang/bpe")
|
||||
lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe"
|
||||
if not lang_dir.is_dir():
|
||||
return
|
||||
# TODO: generate data for testing
|
||||
|
||||
compiler = BpeCtcTrainingGraphCompiler(lang_dir)
|
||||
ids = compiler.texts_to_ids(["HELLO", "WORLD ZZZ"])
|
||||
compiler.compile(ids)
|
||||
|
||||
lexicon = BpeLexicon(lang_dir)
|
||||
lexicon = UniqLexicon(lang_dir, uniq_filename="lexicon.txt")
|
||||
ids0 = lexicon.words_to_piece_ids(["HELLO"])
|
||||
assert ids[0] == ids0.values().tolist()
|
||||
|
||||
|
175
test/test_lexicon.py
Normal file → Executable file
@ -14,80 +14,135 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
You can run this file in one of the two ways:
|
||||
|
||||
(1) cd icefall; pytest test/test_lexicon.py
|
||||
(2) cd icefall; ./test/test_lexicon.py
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import k2
|
||||
import pytest
|
||||
import torch
|
||||
import sentencepiece as spm
|
||||
|
||||
from icefall.lexicon import BpeLexicon, Lexicon
|
||||
from icefall.lexicon import UniqLexicon
|
||||
|
||||
TMP_DIR = "/tmp/icefall-test-lexicon"
|
||||
USING_PYTEST = "pytest" in sys.modules
|
||||
ICEFALL_DIR = Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def lang_dir(tmp_path):
|
||||
phone2id = """
|
||||
<eps> 0
|
||||
a 1
|
||||
b 2
|
||||
f 3
|
||||
o 4
|
||||
r 5
|
||||
z 6
|
||||
SPN 7
|
||||
#0 8
|
||||
"""
|
||||
word2id = """
|
||||
<eps> 0
|
||||
foo 1
|
||||
bar 2
|
||||
baz 3
|
||||
<UNK> 4
|
||||
#0 5
|
||||
def generate_test_data():
|
||||
Path(TMP_DIR).mkdir(exist_ok=True)
|
||||
sentences = """
|
||||
cat tac cat cat
|
||||
at
|
||||
tac at ta at at
|
||||
at cat ct ct ta
|
||||
cat cat cat cat
|
||||
at at at at at at at
|
||||
"""
|
||||
|
||||
L = k2.Fsa.from_str(
|
||||
"""
|
||||
0 0 7 4 0
|
||||
0 7 -1 -1 0
|
||||
0 1 3 1 0
|
||||
0 3 2 2 0
|
||||
0 5 2 3 0
|
||||
1 2 4 0 0
|
||||
2 0 4 0 0
|
||||
3 4 1 0 0
|
||||
4 0 5 0 0
|
||||
5 6 1 0 0
|
||||
6 0 6 0 0
|
||||
7
|
||||
""",
|
||||
num_aux_labels=1,
|
||||
transcript = Path(TMP_DIR) / "transcript_words.txt"
|
||||
with open(transcript, "w") as f:
|
||||
for line in sentences.strip().split("\n"):
|
||||
f.write(f"{line}\n")
|
||||
|
||||
words = """
|
||||
<eps> 0
|
||||
<UNK> 1
|
||||
at 2
|
||||
cat 3
|
||||
ct 4
|
||||
ta 5
|
||||
tac 6
|
||||
#0 7
|
||||
<s> 8
|
||||
</s> 9
|
||||
"""
|
||||
word_txt = Path(TMP_DIR) / "words.txt"
|
||||
with open(word_txt, "w") as f:
|
||||
for line in words.strip().split("\n"):
|
||||
f.write(f"{line}\n")
|
||||
|
||||
vocab_size = 8
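# The shell snippet below trains a small BPE model on the toy transcript and
# builds the BPE lang dir (bpe.model, tokens.txt, lexicon.txt, etc.) in TMP_DIR,
# which is what UniqLexicon is constructed from in the test.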
|
||||
|
||||
os.system(
|
||||
f"""
|
||||
cd {ICEFALL_DIR}/egs/librispeech/ASR
|
||||
|
||||
./local/train_bpe_model.py \
|
||||
--lang-dir {TMP_DIR} \
|
||||
--vocab-size {vocab_size} \
|
||||
--transcript {transcript}
|
||||
|
||||
./local/prepare_lang_bpe.py --lang-dir {TMP_DIR} --debug 1
|
||||
"""
|
||||
)
|
||||
|
||||
with open(tmp_path / "tokens.txt", "w") as f:
|
||||
f.write(phone2id)
|
||||
with open(tmp_path / "words.txt", "w") as f:
|
||||
f.write(word2id)
|
||||
|
||||
torch.save(L.as_dict(), tmp_path / "L.pt")
|
||||
|
||||
return tmp_path
|
||||
def delete_test_data():
|
||||
shutil.rmtree(TMP_DIR)
|
||||
|
||||
|
||||
def test_lexicon(lang_dir):
|
||||
lexicon = Lexicon(lang_dir)
|
||||
assert lexicon.tokens == list(range(1, 8))
|
||||
def uniq_lexicon_test():
|
||||
lexicon = UniqLexicon(lang_dir=TMP_DIR, uniq_filename="lexicon.txt")
|
||||
|
||||
# case 1: No OOV
|
||||
texts = ["cat cat", "at ct", "at tac cat"]
|
||||
token_ids = lexicon.texts_to_token_ids(texts)
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(f"{TMP_DIR}/bpe.model")
|
||||
|
||||
expected_token_ids: List[List[int]] = sp.encode(texts, out_type=int)
|
||||
assert token_ids.tolist() == expected_token_ids
|
||||
|
||||
# case 2: With OOV
|
||||
texts = ["ca"]
|
||||
token_ids = lexicon.texts_to_token_ids(texts)
|
||||
expected_token_ids = sp.encode(texts, out_type=int)
|
||||
assert token_ids.tolist() != expected_token_ids
|
||||
# Note: sentencepiece breaks "ca" into "_ c a"
|
||||
# But there is no word "ca" in the lexicon, so our
|
||||
# implementation returns the id of "<UNK>"
|
||||
print(token_ids, expected_token_ids)
|
||||
assert token_ids.tolist() == [[sp.unk_id()]]
|
||||
|
||||
# case 3: With OOV
|
||||
texts = ["foo"]
|
||||
token_ids = lexicon.texts_to_token_ids(texts)
|
||||
expected_token_ids = sp.encode(texts, out_type=int)
|
||||
print(token_ids)
|
||||
print(expected_token_ids)
|
||||
|
||||
# test ragged lexicon
|
||||
ragged_lexicon = lexicon.ragged_lexicon.tolist()
|
||||
word_disambig_id = lexicon.word_table["#0"]
|
||||
for i in range(2, word_disambig_id):
|
||||
piece_id = ragged_lexicon[i]
|
||||
word = lexicon.word_table[i]
|
||||
assert word == sp.decode(piece_id)
|
||||
assert piece_id == sp.encode(word)
|
||||
|
||||
|
||||
def test_bpe_lexicon():
|
||||
lang_dir = Path("data/lang/bpe")
|
||||
if not lang_dir.is_dir():
|
||||
return
|
||||
# TODO: Generate test data for BpeLexicon
|
||||
def test_main():
|
||||
generate_test_data()
|
||||
|
||||
lexicon = BpeLexicon(lang_dir)
|
||||
words = ["<UNK>", "HELLO", "ZZZZ", "WORLD"]
|
||||
ids = lexicon.words_to_piece_ids(words)
|
||||
print(ids)
|
||||
print([lexicon.token_table[i] for i in ids.values().tolist()])
|
||||
uniq_lexicon_test()
|
||||
|
||||
if USING_PYTEST:
|
||||
delete_test_data()
|
||||
|
||||
|
||||
def main():
|
||||
test_main()
|
||||
|
||||
|
||||
if __name__ == "__main__" and not USING_PYTEST:
|
||||
main()
|
||||
|
196
test/test_mmi_graph_compiler.py
Executable file
@ -0,0 +1,196 @@
|
||||
#!/usr/bin/env python3
|
||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||
#
|
||||
# See ../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
You can run this file in one of the two ways:
|
||||
|
||||
(1) cd icefall; pytest test/test_mmi_graph_compiler.py
|
||||
(2) cd icefall; ./test/test_mmi_graph_compiler.py
|
||||
"""
|
||||
|
||||
import copy
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import k2
|
||||
import sentencepiece as spm
|
||||
|
||||
from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler
|
||||
|
||||
TMP_DIR = "/tmp/icefall-test-mmi-graph-compiler"
|
||||
USING_PYTEST = "pytest" in sys.modules
|
||||
ICEFALL_DIR = Path(__file__).resolve().parent.parent
|
||||
|
||||
|
||||
def generate_test_data():
|
||||
Path(TMP_DIR).mkdir(exist_ok=True)
|
||||
sentences = """
|
||||
cat tac cat cat
|
||||
at at cat at cat cat
|
||||
tac at ta at at
|
||||
at cat ct ct ta ct ct cat tac
|
||||
cat cat cat cat
|
||||
at at at at at at at
|
||||
"""
|
||||
|
||||
transcript = Path(TMP_DIR) / "transcript_words.txt"
|
||||
with open(transcript, "w") as f:
|
||||
for line in sentences.strip().split("\n"):
|
||||
f.write(f"{line}\n")
|
||||
|
||||
words = """
|
||||
<eps> 0
|
||||
<UNK> 1
|
||||
at 2
|
||||
cat 3
|
||||
ct 4
|
||||
ta 5
|
||||
tac 6
|
||||
#0 7
|
||||
<s> 8
|
||||
</s> 9
|
||||
"""
|
||||
word_txt = Path(TMP_DIR) / "words.txt"
|
||||
with open(word_txt, "w") as f:
|
||||
for line in words.strip().split("\n"):
|
||||
f.write(f"{line}\n")
|
||||
|
||||
vocab_size = 8
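# The shell snippet below prepares the data the MMI graph compiler test needs:
# it trains a toy BPE model and BPE lang dir, converts the word transcript to
# token sequences, trains a bigram token LM with make_kn_lm.py, and compiles
# the resulting P.arpa into P.fst.txt with kaldilm.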
|
||||
|
||||
os.system(
|
||||
f"""
|
||||
cd {ICEFALL_DIR}/egs/librispeech/ASR
|
||||
|
||||
./local/train_bpe_model.py \
|
||||
--lang-dir {TMP_DIR} \
|
||||
--vocab-size {vocab_size} \
|
||||
--transcript {transcript}
|
||||
|
||||
./local/prepare_lang_bpe.py --lang-dir {TMP_DIR} --debug 0
|
||||
|
||||
./local/convert_transcript_words_to_tokens.py \
|
||||
--lexicon {TMP_DIR}/lexicon.txt \
|
||||
--transcript {transcript} \
|
||||
--oov "<UNK>" \
|
||||
> {TMP_DIR}/transcript_tokens.txt
|
||||
|
||||
./shared/make_kn_lm.py \
|
||||
-ngram-order 2 \
|
||||
-text {TMP_DIR}/transcript_tokens.txt \
|
||||
-lm {TMP_DIR}/P.arpa
|
||||
|
||||
python3 -m kaldilm \
|
||||
--read-symbol-table="{TMP_DIR}/tokens.txt" \
|
||||
--disambig-symbol='#0' \
|
||||
--max-order=2 \
|
||||
{TMP_DIR}/P.arpa > {TMP_DIR}/P.fst.txt
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def delete_test_data():
|
||||
shutil.rmtree(TMP_DIR)
|
||||
|
||||
|
||||
def mmi_graph_compiler_test():
|
||||
# Caution:
|
||||
# You have to uncomment
|
||||
# del transcript_fsa.aux_labels
|
||||
# in mmi_graph_compiler.py
|
||||
# to see the correct aux_labels in *.svg
|
||||
graph_compiler = MmiTrainingGraphCompiler(
|
||||
lang_dir=TMP_DIR, uniq_filename="lexicon.txt"
|
||||
)
|
||||
print(graph_compiler.device)
|
||||
L_inv = graph_compiler.L_inv
|
||||
L = k2.invert(L_inv)
|
||||
|
||||
L.labels_sym = graph_compiler.lexicon.token_table
|
||||
L.aux_labels_sym = graph_compiler.lexicon.word_table
|
||||
L.draw(f"{TMP_DIR}/L.svg", title="L")
|
||||
|
||||
L_inv.labels_sym = graph_compiler.lexicon.word_table
|
||||
L_inv.aux_labels_sym = graph_compiler.lexicon.token_table
|
||||
L_inv.draw(f"{TMP_DIR}/L_inv.svg", title="L")
|
||||
|
||||
ctc_topo_P = graph_compiler.ctc_topo_P
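# ctc_topo_P is (as the name suggests) the CTC topology composed with the
# token-level bigram P prepared above; drawing P, ctc_topo and ctc_topo_P
# side by side makes the effect of that composition easy to inspect.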
|
||||
ctc_topo_P.labels_sym = copy.deepcopy(graph_compiler.lexicon.token_table)
|
||||
ctc_topo_P.labels_sym._id2sym[0] = "<blk>"
|
||||
ctc_topo_P.labels_sym._sym2id["<blk>"] = 0
|
||||
ctc_topo_P.aux_labels_sym = graph_compiler.lexicon.token_table
|
||||
ctc_topo_P.draw(f"{TMP_DIR}/ctc_topo_P.svg", title="ctc_topo_P")
|
||||
|
||||
print(ctc_topo_P.num_arcs)
|
||||
print(k2.connect(ctc_topo_P).num_arcs)
|
||||
|
||||
with open(str(TMP_DIR) + "/P.fst.txt") as f:
|
||||
# P is not an acceptor because there is
|
||||
# a back-off state, whose incoming arcs
|
||||
# have label #0 and aux_label 0 (i.e., <eps>).
|
||||
P = k2.Fsa.from_openfst(f.read(), acceptor=False)
|
||||
P.labels_sym = graph_compiler.lexicon.token_table
|
||||
P.aux_labels_sym = graph_compiler.lexicon.token_table
|
||||
P.draw(f"{TMP_DIR}/P.svg", title="P")
|
||||
|
||||
ctc_topo = k2.ctc_topo(max(graph_compiler.lexicon.tokens), False)
|
||||
ctc_topo.labels_sym = ctc_topo_P.labels_sym
|
||||
ctc_topo.aux_labels_sym = graph_compiler.lexicon.token_table
|
||||
ctc_topo.draw(f"{TMP_DIR}/ctc_topo.svg", title="ctc_topo")
|
||||
print("p num arcs", P.num_arcs)
|
||||
print("ctc_topo num arcs", ctc_topo.num_arcs)
|
||||
print("ctc_topo_P num arcs", ctc_topo_P.num_arcs)
|
||||
|
||||
texts = ["cat at ct", "at ta", "cat tac"]
|
||||
transcript_fsa = graph_compiler.build_transcript_fsa(texts)
|
||||
transcript_fsa[0].draw(f"{TMP_DIR}/cat_at_ct.svg", title="cat_at_ct")
|
||||
transcript_fsa[1].draw(f"{TMP_DIR}/at_ta.svg", title="at_ta")
|
||||
transcript_fsa[2].draw(f"{TMP_DIR}/cat_tac.svg", title="cat_tac")
|
||||
|
||||
num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True)
|
||||
num_graphs[0].draw(f"{TMP_DIR}/num_cat_at_ct.svg", title="num_cat_at_ct")
|
||||
num_graphs[1].draw(f"{TMP_DIR}/num_at_ta.svg", title="num_at_ta")
|
||||
num_graphs[2].draw(f"{TMP_DIR}/num_cat_tac.svg", title="num_cat_tac")
|
||||
|
||||
den_graphs[0].draw(f"{TMP_DIR}/den_cat_at_ct.svg", title="den_cat_at_ct")
|
||||
den_graphs[2].draw(f"{TMP_DIR}/den_cat_tac.svg", title="den_cat_tac")
|
||||
|
||||
sp = spm.SentencePieceProcessor()
|
||||
sp.load(f"{TMP_DIR}/bpe.model")
|
||||
|
||||
texts = ["cat at cat", "at tac"]
|
||||
token_ids = graph_compiler.texts_to_ids(texts)
|
||||
expected_token_ids = sp.encode(texts)
|
||||
assert token_ids == expected_token_ids
|
||||
|
||||
|
||||
def test_main():
|
||||
generate_test_data()
|
||||
|
||||
mmi_graph_compiler_test()
|
||||
|
||||
if USING_PYTEST:
|
||||
delete_test_data()
|
||||
|
||||
|
||||
def main():
|
||||
test_main()
|
||||
|
||||
|
||||
if __name__ == "__main__" and not USING_PYTEST:
|
||||
main()
|