Draft streaming decoding (#89)

* reusable parts from conformer_ctc

* streaming conformer code

* a trained model
This commit is contained in:
LIyong.Guo 2021-11-24 19:35:18 +08:00 committed by GitHub
parent 769a9791ec
commit 00e2f0ade8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 3680 additions and 0 deletions

View File

@ -0,0 +1,90 @@
## Train and Decode
Commands of data preparation/train/decode steps are almost the same with
../conformer_ctc experiment except some options.
Please read the code and understand following new added options before running this experiment:
For data preparation:
Nothing new.
For streaming_conformer_ctc/train.py:
--dynamic-chunk-training
--short-chunk-proportion
For streaming_conformer_ctc/streaming_decode.py:
--chunk-size
--tailing-num-frames
--simulate-streaming
## Performence and a trained model.
The latest results with this streaming code is shown in following table:
chunk size | wer on test-clean | wer on test-other
-- | -- | --
full | 3.53 | 8.52
40(1.96s) | 3.78 | 9.38
32(1.28s) | 3.82 | 9.44
24(0.96s) | 3.95 | 9.76
16(0.64s) | 4.06 | 9.98
8(0.32s) | 4.30 | 10.55
4(0.16s) | 5.88 | 12.01
A trained model is also provided.
By run
```
git clone https://huggingface.co/GuoLiyong/streaming_conformer
# You may want to manually check md5sum values of downloaded files
# 8e633bc1de37f5ae57a2694ceee32a93 trained_streaming_conformer.pt
# 4c0aeefe26c784ec64873cc9b95420f1 L.pt
# d1f91d81005fb8ce4d65953a4a984ee7 Linv.pt
# e1c1902feb7b9fc69cd8d26e663c2608 bpe.model
# 8617e67159b0ff9118baa54f04db24cc tokens.txt
# 72b075ab5e851005cd854e666c82c3bb words.txt
```
If there is any different md5sum values, please run
```
cd streaming_models
git lfs pull
```
And check md5sum values again.
Finally, following files will be downloaded:
<pre>
streaming_models/
|-- lang_bpe
| |-- L.pt
| |-- Linv.pt
| |-- bpe.model
| |-- tokens.txt
| `-- words.txt
`-- trained_streaming_conformer.pt
</pre>
And run commands you will get the same results of previous table:
```
trained_models=/path/to/downloaded/streaming_models/
for chunk_size in 4 8 16 24 36 40 -1; do
./streaming_conformer_ctc/streaming_decode.py \
--chunk-size=${chunk_size} \
--trained-dir=${trained_models}
done
```
Results of following command is indentical to previous one,
but model consumes features chunk_by_chunk, i.e. a streaming way.
```
trained_models=/path/to/downloaded/streaming_models/
for chunk_size in 4 8 16 24 36 40 -1; do
./streaming_conformer_ctc/streaming_decode.py \
--simulate-streaming=True \
--chunk-size=${chunk_size} \
--trained-dir=${trained_models}
done
```

View File

@ -0,0 +1 @@
../conformer_ctc/asr_datamodule.py

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1 @@
../conformer_ctc/label_smoothing.py

View File

@ -0,0 +1,521 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
# from https://github.com/wenet-e2e/wenet/blob/main/wenet/utils/common.py#L166
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
new_hyp: List[int] = []
cur = 0
while cur < len(hyp):
if hyp[cur] != 0:
new_hyp.append(hyp[cur])
prev = cur
while cur < len(hyp) and hyp[cur] == hyp[prev]:
cur += 1
return new_hyp
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=34,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
)
parser.add_argument(
"--avg",
type=int,
default=20,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
)
parser.add_argument(
"--chunk-size",
type=int,
default=8,
help="Frames of right context"
"-1 for whole right context, i.e. non-streaming decoding",
)
parser.add_argument(
"--tailing-num-frames",
type=int,
default=20,
help="tailing dummy frames padded to the right,"
"only used during decoding",
)
parser.add_argument(
"--simulate-streaming",
type=str2bool,
default=False,
help="simulate chunk by chunk decoding",
)
parser.add_argument(
"--method",
type=str,
default="ctc-greedy-search",
help="Streaming Decoding method",
)
parser.add_argument(
"--export",
type=str2bool,
default=False,
help="""When enabled, the averaged model is saved to
conformer_ctc/exp/pretrained.pt. Note: only model.state_dict() is saved.
pretrained.pt contains a dict {"model": model.state_dict()},
which can be loaded by `icefall.checkpoint.load_checkpoint()`.
""",
)
parser.add_argument(
"--exp-dir",
type=Path,
default="streaming_conformer_ctc/exp",
help="The experiment dir",
)
parser.add_argument(
"--trained-dir",
type=Path,
default=None,
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_bpe",
help="The lang dir",
)
parser.add_argument(
"--avg-models",
type=str,
default=None,
help="Manually select models to average, seperated by comma;"
"e.g. 60,62,63,72",
)
return parser
def get_params() -> AttributeDict:
params = AttributeDict(
{
"exp_dir": Path("conformer_ctc/exp"),
"lang_dir": Path("data/lang_bpe"),
# parameters for conformer
"causal": True,
"subsampling_factor": 4,
"vgg_frontend": False,
"use_feat_batchnorm": True,
"feature_dim": 80,
"nhead": 8,
"attention_dim": 512,
"num_decoder_layers": 6,
# parameters for decoding
"search_beam": 20,
"output_beam": 8,
"min_active_states": 30,
"max_active_states": 10000,
"use_double_scores": True,
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
bpe_model: Optional[spm.SentencePieceProcessor],
batch: dict,
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
chunk_size: int = -1,
simulate_streaming=False,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if no rescoring is used, the key is the string `no_rescore`.
If LM rescoring is used, the key is the string `lm_scale_xxx`,
where `xxx` is the value of `lm_scale`. An example key is
`lm_scale_0.7`
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
word_table:
The word symbol table.
sos_id:
The token ID of the SOS.
eos_id:
The token ID of the EOS.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
feature = batch["inputs"]
device = torch.device("cuda")
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
# Extra dummy tailing frames my reduce deletion error
# example WITHOUT padding:
# CHAPTER SEVEN ON THE RACES OF MAN
# example WITH padding:
# CHAPTER SEVEN ON THE RACES OF (MAN->*)
tailing_frames = (
torch.tensor([-23.0259])
.expand([feature.size(0), params.tailing_num_frames, 80])
.to(feature.device)
)
feature = torch.cat([feature, tailing_frames], dim=1)
supervisions["num_frames"] += params.tailing_num_frames
nnet_output, memory, memory_key_padding_mask = model(
feature,
supervisions,
chunk_size=chunk_size,
simulate_streaming=simulate_streaming,
)
assert params.method == "ctc-greedy-search"
key = "ctc-greedy-search"
batch_size = nnet_output.size(0)
maxlen = nnet_output.size(1)
topk_prob, topk_index = nnet_output.topk(1, dim=2) # (B, maxlen, 1)
topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen)
topk_index = topk_index.masked_fill_(
memory_key_padding_mask, 0
) # (B, maxlen)
token_ids = [token_id.tolist() for token_id in topk_index]
token_ids = [
remove_duplicates_and_blank(token_id) for token_id in token_ids
]
hyps = bpe_model.decode(token_ids)
hyps = [s.split() for s in hyps]
return {key: hyps}
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
bpe_model: Optional[spm.SentencePieceProcessor],
word_table: k2.SymbolTable,
sos_id: int,
eos_id: int,
chunk_size: int = -1,
simulate_streaming=False,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
bpe_model:
The BPE model. Used only when params.method is ctc-decoding.
word_table:
It is the word symbol table.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
chunk_size:
right context to simulate streaming decoding
-1 for whole right context, i.e. non-stream decoding
Returns:
Return a dict, whose key may be "no-rescore" if no LM rescoring
is used, or it may be "lm_scale_0.7" if LM rescoring is used.
Its value is a list of tuples. Each tuple contains two elements:
The first is the reference transcript, and the second is the
predicted result.
"""
results = []
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
hyps_dict = decode_one_batch(
params=params,
model=model,
bpe_model=bpe_model,
batch=batch,
word_table=word_table,
sos_id=sos_id,
eos_id=eos_id,
chunk_size=chunk_size,
simulate_streaming=simulate_streaming,
)
for lm_scale, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for hyp_words, ref_text in zip(hyps, texts):
ref_words = ref_text.split()
this_batch.append((ref_words, hyp_words))
results[lm_scale].extend(this_batch)
num_cuts += len(batch["supervisions"]["text"])
if batch_idx % 100 == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(
f"batch {batch_str}, cuts processed until now is {num_cuts}"
)
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
if params.method == "attention-decoder":
# Set it to False since there are too many logs.
enable_log = False
else:
enable_log = True
test_set_wers = dict()
if params.avg_models is not None:
avg_models = params.avg_models.replace(",", "_")
result_file_prefix = f"epoch-avg-{avg_models}-chunksize \
-{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
else:
result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \
-{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-"
for key, results in results_dict.items():
recog_path = (
params.exp_dir
/ f"{result_file_prefix}recogs-{test_set_name}-{key}.txt"
)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.exp_dir
/ f"{result_file_prefix}-errs-{test_set_name}-{key}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
)
test_set_wers[key] = wer
if enable_log:
logging.info(
"Wrote detailed error stats to {}".format(errs_filename)
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
params = get_params()
params.update(vars(args))
setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode")
logging.info("Decoding started")
logging.info(params)
if params.trained_dir is not None:
params.lang_dir = Path(params.trained_dir) / "lang_bpe"
# used naming result files
params.epoch = "trained_model"
params.avg = 1
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
sos_id = graph_compiler.sos_id
eos_id = graph_compiler.eos_id
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=params.vgg_frontend,
use_feat_batchnorm=params.use_feat_batchnorm,
causal=params.causal,
)
if params.trained_dir is not None:
model_name = f"{params.trained_dir}/trained_streaming_conformer.pt"
load_checkpoint(model_name, model)
elif params.avg == 1 and params.avg_models is not None:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
filenames = []
if params.avg_models is not None:
model_ids = params.avg_models.split(",")
for i in model_ids:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
else:
start = params.epoch - params.avg + 1
for i in range(start, params.epoch + 1):
if start >= 0:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.load_state_dict(average_checkpoints(filenames))
if params.export:
logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt")
torch.save(
{"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt"
)
return
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
librispeech = LibriSpeechAsrDataModule(args)
# CAUTION: `test_sets` is for displaying only.
# If you want to skip test-clean, you have to skip
# it inside the for loop. That is, use
#
# if test_set == 'test-clean': continue
#
bpe_model = spm.SentencePieceProcessor()
bpe_model.load(str(params.lang_dir / "bpe.model"))
test_sets = ["test-clean", "test-other"]
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
bpe_model=bpe_model,
word_table=lexicon.word_table,
sos_id=sos_id,
eos_id=eos_id,
chunk_size=params.chunk_size,
simulate_streaming=params.simulate_streaming,
)
save_results(
params=params, test_set_name=test_set, results_dict=results_dict
)
logging.info("Done!")
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1 @@
../conformer_ctc/subsampling.py

View File

@ -0,0 +1,745 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
import k2
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from transformer import Noam
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import load_checkpoint
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
encode_supervisions,
setup_logger,
str2bool,
)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--tensorboard",
type=str2bool,
default=True,
help="Should various information be logged in tensorboard.",
)
parser.add_argument(
"--num-epochs",
type=int,
default=78,
help="Number of epochs to train.",
)
parser.add_argument(
"--start-epoch",
type=int,
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
conformer_ctc/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--exp-dir",
type=str,
default="conformer_ctc/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
""",
)
parser.add_argument(
"--lang-dir",
type=str,
default="data/lang_bpe_500",
help="""The lang dir
It contains language related input files such as
"lexicon.txt"
""",
)
parser.add_argument(
"--att-rate",
type=float,
default=0.8,
help="""The attention rate.
The total loss is (1 - att_rate) * ctc_loss + att_rate * att_loss
""",
)
parser.add_argument(
"--dynamic-chunk-training",
type=str2bool,
default=True,
help="Whether to use dynamic right context during training.",
)
parser.add_argument(
"--short-chunk-proportion",
type=float,
default=0.7,
help="Proportion of samples trained with short right context",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing training parameters.
All training related parameters that are not passed from the commandline
are saved in the variable `params`.
Commandline options are merged into `params` after they are parsed, so
you can also access them via `params`.
Explanation of options saved in `params`:
- best_train_loss: Best training loss so far. It is used to select
the model that has the lowest training loss. It is
updated during the training.
- best_valid_loss: Best validation loss so far. It is used to select
the model that has the lowest validation loss. It is
updated during the training.
- best_train_epoch: It is the epoch that has the best training loss.
- best_valid_epoch: It is the epoch that has the best validation loss.
- batch_idx_train: Used to writing statistics to tensorboard. It
contains number of batches trained so far across
epochs.
- log_interval: Print training loss if batch_idx % log_interval` is 0
- reset_interval: Reset statistics if batch_idx % reset_interval is 0
- valid_interval: Run validation if batch_idx % valid_interval is 0
- feature_dim: The model input dim. It has to match the one used
in computing features.
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model.
- head: Number of heads of multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- beam_size: It is used in k2.ctc_loss
- reduction: It is used in k2.ctc_loss
- use_double_scores: It is used in k2.ctc_loss
- weight_decay: The weight_decay for the optimizer.
- lr_factor: The lr_factor for Noam optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
{
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000,
# parameters for conformer
"feature_dim": 80,
"subsampling_factor": 4,
"use_feat_batchnorm": True,
"attention_dim": 512,
"nhead": 8,
"num_decoder_layers": 6,
# parameters for loss
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
# parameters for Noam
"weight_decay": 1e-6,
"lr_factor": 5.0,
"warm_step": 80000,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
params:
The return value of :func:`get_params`.
model:
The training model.
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
Returns:
Return None.
"""
if params.start_epoch <= 0:
return
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
optimizer=optimizer,
scheduler=scheduler,
)
keys = [
"best_train_epoch",
"best_valid_epoch",
"batch_idx_train",
"best_train_loss",
"best_valid_loss",
]
for k in keys:
params[k] = saved_params[k]
return saved_params
def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
Args:
params:
It is returned by :func:`get_params`.
model:
The training model.
"""
if rank != 0:
return
filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
save_checkpoint_impl(
filename=filename,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
rank=rank,
)
if params.best_train_epoch == params.cur_epoch:
best_train_filename = params.exp_dir / "best-train-loss.pt"
copyfile(src=filename, dst=best_train_filename)
if params.best_valid_epoch == params.cur_epoch:
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
model: nn.Module,
batch: dict,
graph_compiler: BpeCtcTrainingGraphCompiler,
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
Args:
params:
Parameters for training. See :func:`get_params`.
model:
The model for training. It is an instance of Conformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
graph_compiler:
It is used to build a decoding graph from a ctc topo and training
transcript. The training transcript is contained in the given `batch`,
while the ctc topo is built when this compiler is instantiated.
is_training:
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
"""
device = graph_compiler.device
feature = batch["inputs"]
# at entry, feature is (N, T, C)
assert feature.ndim == 3
feature = feature.to(device)
supervisions = batch["supervisions"]
with torch.set_grad_enabled(is_training):
nnet_output, encoder_memory, memory_mask = model(
feature,
supervisions,
dynamic_chunk_training=params.dynamic_chunk_training,
short_chunk_proportion=params.short_chunk_proportion,
)
# nnet_output is (N, T, C)
# NOTE: We need `encode_supervisions` to sort sequences with
# different duration in decreasing order, required by
# `k2.intersect_dense` called in `k2.ctc_loss`
supervision_segments, texts = encode_supervisions(
supervisions, subsampling_factor=params.subsampling_factor
)
token_ids = graph_compiler.texts_to_ids(texts)
decoding_graph = graph_compiler.compile(token_ids)
dense_fsa_vec = k2.DenseFsaVec(
nnet_output,
supervision_segments,
allow_truncate=params.subsampling_factor - 1,
)
ctc_loss = k2.ctc_loss(
decoding_graph=decoding_graph,
dense_fsa_vec=dense_fsa_vec,
output_beam=params.beam_size,
reduction=params.reduction,
use_double_scores=params.use_double_scores,
)
if params.att_rate != 0.0:
with torch.set_grad_enabled(is_training):
mmodel = model.module if hasattr(model, "module") else model
# Note: We need to generate an unsorted version of token_ids
# `encode_supervisions()` called above sorts text, but
# encoder_memory and memory_mask are not sorted, so we
# use an unsorted version `supervisions["text"]` to regenerate
# the token_ids
#
# See https://github.com/k2-fsa/icefall/issues/97
# for more details
unsorted_token_ids = graph_compiler.texts_to_ids(
supervisions["text"]
)
att_loss = mmodel.decoder_forward(
encoder_memory,
memory_mask,
token_ids=unsorted_token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
else:
loss = ctc_loss
att_loss = torch.tensor([0])
assert loss.requires_grad == is_training
info = MetricsTracker()
info["frames"] = supervision_segments[:, 2].sum().item()
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if params.att_rate != 0.0:
info["att_loss"] = att_loss.detach().cpu().item()
info["loss"] = loss.detach().cpu().item()
return loss, info
def compute_validation_loss(
params: AttributeDict,
model: nn.Module,
graph_compiler: BpeCtcTrainingGraphCompiler,
valid_dl: torch.utils.data.DataLoader,
world_size: int = 1,
) -> MetricsTracker:
"""Run the validation process."""
model.eval()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(valid_dl):
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
if world_size > 1:
tot_loss.reduce(loss.device)
loss_value = tot_loss["loss"] / tot_loss["frames"]
if loss_value < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
params.best_valid_loss = loss_value
return tot_loss
def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
graph_compiler: BpeCtcTrainingGraphCompiler,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
) -> None:
"""Train the model for one epoch.
The training loss from the mean of all frames is saved in
`params.train_loss`. It runs the validation process every
`params.valid_interval` batches.
Args:
params:
It is returned by :func:`get_params`.
model:
The model for training.
optimizer:
The optimizer we are using.
graph_compiler:
It is used to convert transcripts to FSAs.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
"""
model.train()
tot_loss = MetricsTracker()
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
)
if batch_idx % params.log_interval == 0:
if tb_writer is not None:
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(
tb_writer, "train/tot_", params.batch_idx_train
)
if batch_idx > 0 and batch_idx % params.valid_interval == 0:
logging.info("Computing validation loss")
valid_info = compute_validation_loss(
params=params,
model=model,
graph_compiler=graph_compiler,
valid_dl=valid_dl,
world_size=world_size,
)
model.train()
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
if tb_writer is not None:
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
params.best_train_epoch = params.cur_epoch
params.best_train_loss = params.train_loss
def run(rank, world_size, args):
"""
Args:
rank:
It is a value between 0 and `world_size-1`, which is
passed automatically by `mp.spawn()` in :func:`main`.
The node with rank 0 is responsible for saving checkpoint.
world_size:
Number of GPUs for DDP training.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
fix_random_seed(42)
if world_size > 1:
setup_dist(rank, world_size, params.master_port)
setup_logger(f"{params.exp_dir}/log/log-train")
logging.info("Training started")
logging.info(params)
if args.tensorboard and rank == 0:
tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
else:
tb_writer = None
lexicon = Lexicon(params.lang_dir)
max_token_id = max(lexicon.tokens)
num_classes = max_token_id + 1 # +1 for the blank
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
graph_compiler = BpeCtcTrainingGraphCompiler(
params.lang_dir,
device=device,
sos_token="<sos/eos>",
eos_token="<sos/eos>",
)
logging.info("About to create model")
model = Conformer(
num_features=params.feature_dim,
nhead=params.nhead,
d_model=params.attention_dim,
num_classes=num_classes,
subsampling_factor=params.subsampling_factor,
num_decoder_layers=params.num_decoder_layers,
vgg_frontend=False,
use_feat_batchnorm=params.use_feat_batchnorm,
)
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
model = DDP(model, device_ids=[rank])
optimizer = Noam(
model.parameters(),
model_size=params.attention_dim,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints:
optimizer.load_state_dict(checkpoints["optimizer"])
librispeech = LibriSpeechAsrDataModule(args)
train_dl = librispeech.train_dataloaders()
valid_dl = librispeech.valid_dataloaders()
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
graph_compiler=graph_compiler,
params=params,
)
for epoch in range(params.start_epoch, params.num_epochs):
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
graph_compiler=graph_compiler,
train_dl=train_dl,
valid_dl=valid_dl,
tb_writer=tb_writer,
world_size=world_size,
)
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
rank=rank,
)
logging.info("Done!")
if world_size > 1:
torch.distributed.barrier()
cleanup_dist()
def scan_pessimistic_batches_for_oom(
model: nn.Module,
train_dl: torch.utils.data.DataLoader,
optimizer: torch.optim.Optimizer,
graph_compiler: BpeCtcTrainingGraphCompiler,
params: AttributeDict,
):
from lhotse.dataset import find_pessimistic_batches
logging.info(
"Sanity check -- see if any of the batches in epoch 0 would cause OOM."
)
batches, crit_values = find_pessimistic_batches(train_dl.sampler)
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=True,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(
"Your GPU ran out of memory with the current "
"max_duration setting. We recommend decreasing "
"max_duration and trying again.\n"
f"Failing criterion: {criterion} "
f"(={crit_values[criterion]}) ..."
)
raise
def main():
parser = get_parser()
LibriSpeechAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
else:
run(rank=0, world_size=1, args=args)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,966 @@
# Copyright 2021 University of Chinese Academy of Sciences (author: Han Zhu)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
from label_smoothing import LabelSmoothingLoss
from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence
# Note: TorchScript requires Dict/List/etc. to be fully typed.
Supervisions = Dict[str, torch.Tensor]
class Transformer(nn.Module):
def __init__(
self,
num_features: int,
num_classes: int,
subsampling_factor: int = 4,
d_model: int = 256,
nhead: int = 4,
dim_feedforward: int = 2048,
num_encoder_layers: int = 12,
num_decoder_layers: int = 6,
dropout: float = 0.1,
normalize_before: bool = True,
vgg_frontend: bool = False,
use_feat_batchnorm: bool = False,
) -> None:
"""
Args:
num_features:
The input dimension of the model.
num_classes:
The output dimension of the model.
subsampling_factor:
Number of output frames is num_in_frames // subsampling_factor.
Currently, subsampling_factor MUST be 4.
d_model:
Attention dimension.
nhead:
Number of heads in multi-head attention.
Must satisfy d_model // nhead == 0.
dim_feedforward:
The output dimension of the feedforward layers in encoder/decoder.
num_encoder_layers:
Number of encoder layers.
num_decoder_layers:
Number of decoder layers.
dropout:
Dropout in encoder/decoder.
normalize_before:
If True, use pre-layer norm; False to use post-layer norm.
vgg_frontend:
True to use vgg style frontend for subsampling.
use_feat_batchnorm:
True to use batchnorm for the input layer.
"""
super().__init__()
self.use_feat_batchnorm = use_feat_batchnorm
if use_feat_batchnorm:
self.feat_batchnorm = nn.BatchNorm1d(num_features)
self.num_features = num_features
self.num_classes = num_classes
self.subsampling_factor = subsampling_factor
if subsampling_factor != 4:
raise NotImplementedError("Support only 'subsampling_factor=4'.")
# self.encoder_embed converts the input of shape (N, T, num_classes)
# to the shape (N, T//subsampling_factor, d_model).
# That is, it does two things simultaneously:
# (1) subsampling: T -> T//subsampling_factor
# (2) embedding: num_classes -> d_model
if vgg_frontend:
self.encoder_embed = VggSubsampling(num_features, d_model)
else:
self.encoder_embed = Conv2dSubsampling(num_features, d_model)
self.encoder_pos = PositionalEncoding(d_model, dropout)
encoder_layer = TransformerEncoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
normalize_before=normalize_before,
)
if normalize_before:
encoder_norm = nn.LayerNorm(d_model)
else:
encoder_norm = None
self.encoder = nn.TransformerEncoder(
encoder_layer=encoder_layer,
num_layers=num_encoder_layers,
norm=encoder_norm,
)
# TODO(fangjun): remove dropout
self.encoder_output_layer = nn.Sequential(
nn.Dropout(p=dropout), nn.Linear(d_model, num_classes)
)
if num_decoder_layers > 0:
self.decoder_num_class = (
self.num_classes
) # bpe model already has sos/eos symbol
self.decoder_embed = nn.Embedding(
num_embeddings=self.decoder_num_class, embedding_dim=d_model
)
self.decoder_pos = PositionalEncoding(d_model, dropout)
decoder_layer = TransformerDecoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
normalize_before=normalize_before,
)
if normalize_before:
decoder_norm = nn.LayerNorm(d_model)
else:
decoder_norm = None
self.decoder = nn.TransformerDecoder(
decoder_layer=decoder_layer,
num_layers=num_decoder_layers,
norm=decoder_norm,
)
self.decoder_output_layer = torch.nn.Linear(
d_model, self.decoder_num_class
)
self.decoder_criterion = LabelSmoothingLoss()
else:
self.decoder_criterion = None
def forward(
self,
x: torch.Tensor,
supervision: Optional[Supervisions] = None,
dynamic_chunk_training: bool = False,
short_chunk_proportion: float = 0.5,
chunk_size: int = -1,
simulate_streaming=False,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
x:
The input tensor. Its shape is (N, T, C).
supervision:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
(CAUTION: It contains length information, i.e., start and number of
frames, before subsampling)
Returns:
Return a tuple containing 3 tensors:
- CTC output for ctc decoding. Its shape is (N, T, C)
- Encoder output with shape (T, N, C). It can be used as key and
value for the decoder.
- Encoder output padding mask. It can be used as
memory_key_padding_mask for the decoder. Its shape is (N, T).
It is None if `supervision` is None.
"""
if self.use_feat_batchnorm:
x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T)
x = self.feat_batchnorm(x)
x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C)
encoder_memory, memory_key_padding_mask = self.run_encoder(
x,
supervision,
dynamic_chunk_training=dynamic_chunk_training,
short_chunk_proportion=short_chunk_proportion,
chunk_size=chunk_size,
simulate_streaming=simulate_streaming,
)
x = self.ctc_output(encoder_memory)
return x, encoder_memory, memory_key_padding_mask
def run_encoder(
self,
x: torch.Tensor,
supervisions: Optional[Supervisions] = None,
chunk_size: int = -1,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Run the transformer encoder.
Args:
x:
The model input. Its shape is (N, T, C).
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
CAUTION: It contains length information, i.e., start and number of
frames, before subsampling
It is read directly from the batch, without any sorting. It is used
to compute the encoder padding mask, which is used as memory key
padding mask for the decoder.
chunk_size: right chunk_size to simulate streaming decoding
-1 for whole right context
Returns:
Return a tuple with two tensors:
- The encoder output, with shape (T, N, C)
- encoder padding mask, with shape (N, T).
The mask is None if `supervisions` is None.
It is used as memory key padding mask in the decoder.
"""
# streaming decoding(chunk_size >= 0) is only verified with Conformer
assert chunk_size == -1
x = self.encoder_embed(x)
x = self.encoder_pos(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
mask = encoder_padding_mask(x.size(0), supervisions)
mask = mask.to(x.device) if mask is not None else None
x = self.encoder(
x, src_key_padding_mask=mask, chunk_size=chunk_size
) # (T, N, C)
return x, mask
def ctc_output(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
The output tensor from the transformer encoder.
Its shape is (T, N, C)
Returns:
Return a tensor that can be used for CTC decoding.
Its shape is (N, T, C)
"""
x = self.encoder_output_layer(x)
x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
x = nn.functional.log_softmax(x, dim=-1) # (N, T, C)
return x
@torch.jit.export
def decoder_forward(
self,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
token_ids: List[List[int]],
sos_id: int,
eos_id: int,
) -> torch.Tensor:
"""
Args:
memory:
It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
A list-of-list IDs. Each sublist contains IDs for an utterance.
The IDs can be either phone IDs or word piece IDs.
sos_id:
sos token id
eos_id:
eos token id
Returns:
A scalar, the **sum** of label smoothing loss over utterances
in the batch without any normalization.
"""
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device
ys_in_pad = ys_in_pad.to(device)
ys_out_pad = ys_out_pad.to(device)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
# We set the first column to False since the first column in ys_in_pad
# contains sos_id, which is the same as eos_id in our current setting.
tgt_key_padding_mask[:, 0] = False
tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C)
tgt = self.decoder_pos(tgt)
tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
pred_pad = self.decoder(
tgt=tgt,
memory=memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # (T, N, C)
pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C)
decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad)
return decoder_loss
@torch.jit.export
def decoder_nll(
self,
memory: torch.Tensor,
memory_key_padding_mask: torch.Tensor,
token_ids: List[torch.Tensor],
sos_id: int,
eos_id: int,
) -> torch.Tensor:
"""
Args:
memory:
It's the output of the encoder with shape (T, N, C)
memory_key_padding_mask:
The padding mask from the encoder.
token_ids:
A list-of-list IDs (e.g., word piece IDs).
Each sublist represents an utterance.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
Returns:
A 2-D tensor of shape (len(token_ids), max_token_length)
representing the cross entropy loss (i.e., negative log-likelihood).
"""
# The common part between this function and decoder_forward could be
# extracted as a separate function.
if isinstance(token_ids[0], torch.Tensor):
# This branch is executed by torchscript in C++.
# See https://github.com/k2-fsa/k2/pull/870
# https://github.com/k2-fsa/k2/blob/3c1c18400060415b141ccea0115fd4bf0ad6234e/k2/torch/bin/attention_rescore.cu#L286
token_ids = [tolist(t) for t in token_ids]
ys_in = add_sos(token_ids, sos_id=sos_id)
ys_in = [torch.tensor(y) for y in ys_in]
ys_in_pad = pad_sequence(
ys_in, batch_first=True, padding_value=float(eos_id)
)
ys_out = add_eos(token_ids, eos_id=eos_id)
ys_out = [torch.tensor(y) for y in ys_out]
ys_out_pad = pad_sequence(
ys_out, batch_first=True, padding_value=float(-1)
)
device = memory.device
ys_in_pad = ys_in_pad.to(device, dtype=torch.int64)
ys_out_pad = ys_out_pad.to(device, dtype=torch.int64)
tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(
device
)
tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id)
# TODO: Use length information to create the decoder padding mask
# We set the first column to False since the first column in ys_in_pad
# contains sos_id, which is the same as eos_id in our current setting.
tgt_key_padding_mask[:, 0] = False
tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F)
tgt = self.decoder_pos(tgt)
tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F)
pred_pad = self.decoder(
tgt=tgt,
memory=memory,
tgt_mask=tgt_mask,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
) # (T, B, F)
pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F)
pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F)
# nll: negative log-likelihood
nll = torch.nn.functional.cross_entropy(
pred_pad.view(-1, self.decoder_num_class),
ys_out_pad.view(-1),
ignore_index=-1,
reduction="none",
)
nll = nll.view(pred_pad.shape[0], -1)
return nll
class TransformerEncoderLayer(nn.Module):
"""
Modified from torch.nn.TransformerEncoderLayer.
Add support of normalize_before,
i.e., use layer_norm before the first block.
Args:
d_model:
the number of expected features in the input (required).
nhead:
the number of heads in the multiheadattention models (required).
dim_feedforward:
the dimension of the feedforward network model (default=2048).
dropout:
the dropout value (default=0.1).
activation:
the activation function of intermediate layer, relu or
gelu (default=relu).
normalize_before:
whether to use layer_norm before the first block.
Examples::
>>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(10, 32, 512)
>>> out = encoder_layer(src)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: str = "relu",
normalize_before: bool = True,
) -> None:
super(TransformerEncoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = nn.functional.relu
super(TransformerEncoderLayer, self).__setstate__(state)
def forward(
self,
src: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Pass the input through the encoder layer.
Args:
src: the sequence to the encoder layer (required).
src_mask: the mask for the src sequence (optional).
src_key_padding_mask: the mask for the src keys per batch (optional)
Shape:
src: (S, N, E).
src_mask: (S, S).
src_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length,
N is the batch size, E is the feature number
"""
residual = src
if self.normalize_before:
src = self.norm1(src)
src2 = self.self_attn(
src,
src,
src,
attn_mask=src_mask,
key_padding_mask=src_key_padding_mask,
)[0]
src = residual + self.dropout1(src2)
if not self.normalize_before:
src = self.norm1(src)
residual = src
if self.normalize_before:
src = self.norm2(src)
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
src = residual + self.dropout2(src2)
if not self.normalize_before:
src = self.norm2(src)
return src
class TransformerDecoderLayer(nn.Module):
"""
Modified from torch.nn.TransformerDecoderLayer.
Add support of normalize_before,
i.e., use layer_norm before the first block.
Args:
d_model:
the number of expected features in the input (required).
nhead:
the number of heads in the multiheadattention models (required).
dim_feedforward:
the dimension of the feedforward network model (default=2048).
dropout:
the dropout value (default=0.1).
activation:
the activation function of intermediate layer, relu or
gelu (default=relu).
Examples::
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> memory = torch.rand(10, 32, 512)
>>> tgt = torch.rand(20, 32, 512)
>>> out = decoder_layer(tgt, memory)
"""
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: str = "relu",
normalize_before: bool = True,
) -> None:
super(TransformerDecoderLayer, self).__init__()
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
self.src_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def __setstate__(self, state):
if "activation" not in state:
state["activation"] = nn.functional.relu
super(TransformerDecoderLayer, self).__setstate__(state)
def forward(
self,
tgt: torch.Tensor,
memory: torch.Tensor,
tgt_mask: Optional[torch.Tensor] = None,
memory_mask: Optional[torch.Tensor] = None,
tgt_key_padding_mask: Optional[torch.Tensor] = None,
memory_key_padding_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Pass the inputs (and mask) through the decoder layer.
Args:
tgt:
the sequence to the decoder layer (required).
memory:
the sequence from the last layer of the encoder (required).
tgt_mask:
the mask for the tgt sequence (optional).
memory_mask:
the mask for the memory sequence (optional).
tgt_key_padding_mask:
the mask for the tgt keys per batch (optional).
memory_key_padding_mask:
the mask for the memory keys per batch (optional).
Shape:
tgt: (T, N, E).
memory: (S, N, E).
tgt_mask: (T, T).
memory_mask: (T, S).
tgt_key_padding_mask: (N, T).
memory_key_padding_mask: (N, S).
S is the source sequence length, T is the target sequence length,
N is the batch size, E is the feature number
"""
residual = tgt
if self.normalize_before:
tgt = self.norm1(tgt)
tgt2 = self.self_attn(
tgt,
tgt,
tgt,
attn_mask=tgt_mask,
key_padding_mask=tgt_key_padding_mask,
)[0]
tgt = residual + self.dropout1(tgt2)
if not self.normalize_before:
tgt = self.norm1(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm2(tgt)
tgt2 = self.src_attn(
tgt,
memory,
memory,
attn_mask=memory_mask,
key_padding_mask=memory_key_padding_mask,
)[0]
tgt = residual + self.dropout2(tgt2)
if not self.normalize_before:
tgt = self.norm2(tgt)
residual = tgt
if self.normalize_before:
tgt = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = residual + self.dropout3(tgt2)
if not self.normalize_before:
tgt = self.norm3(tgt)
return tgt
def _get_activation_fn(activation: str):
if activation == "relu":
return nn.functional.relu
elif activation == "gelu":
return nn.functional.gelu
raise RuntimeError(
"activation should be relu/gelu, not {}".format(activation)
)
class PositionalEncoding(nn.Module):
"""This class implements the positional encoding
proposed in the following paper:
- Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf
PE(pos, 2i) = sin(pos / (10000^(2i/d_modle))
PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle))
Note::
1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model)))
= exp(-1* 2i / d_model * log(100000))
= exp(2i * -(log(10000) / d_model))
"""
def __init__(self, d_model: int, dropout: float = 0.1) -> None:
"""
Args:
d_model:
Embedding dimension.
dropout:
Dropout probability to be applied to the output of this module.
"""
super().__init__()
self.d_model = d_model
self.xscale = math.sqrt(self.d_model)
self.dropout = nn.Dropout(p=dropout)
# not doing: self.pe = None because of errors thrown by torchscript
self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)
def extend_pe(self, x: torch.Tensor) -> None:
"""Extend the time t in the positional encoding if required.
The shape of `self.pe` is (1, T1, d_model). The shape of the input x
is (N, T, d_model). If T > T1, then we change the shape of self.pe
to (N, T, d_model). Otherwise, nothing is done.
Args:
x:
It is a tensor of shape (N, T, C).
Returns:
Return None.
"""
if self.pe is not None:
if self.pe.size(1) >= x.size(1):
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
return
pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.d_model, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0)
# Now pe is of shape (1, T, d_model), where T is x.size(1)
self.pe = pe.to(device=x.device, dtype=x.dtype)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Add positional encoding.
Args:
x:
Its shape is (N, T, C)
Returns:
Return a tensor of shape (N, T, C)
"""
self.extend_pe(x)
x = x * self.xscale + self.pe[:, : x.size(1), :]
return self.dropout(x)
class Noam(object):
"""
Implements Noam optimizer.
Proposed in
"Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf
Modified from
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa
Args:
params:
iterable of parameters to optimize or dicts defining parameter groups
model_size:
attention dimension of the transformer model
factor:
learning rate factor
warm_step:
warmup steps
"""
def __init__(
self,
params,
model_size: int = 256,
factor: float = 10.0,
warm_step: int = 25000,
weight_decay=0,
) -> None:
"""Construct an Noam object."""
self.optimizer = torch.optim.Adam(
params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
)
self._step = 0
self.warmup = warm_step
self.factor = factor
self.model_size = model_size
self._rate = 0
@property
def param_groups(self):
"""Return param_groups."""
return self.optimizer.param_groups
def step(self):
"""Update parameters and rate."""
self._step += 1
rate = self.rate()
for p in self.optimizer.param_groups:
p["lr"] = rate
self._rate = rate
self.optimizer.step()
def rate(self, step=None):
"""Implement `lrate` above."""
if step is None:
step = self._step
return (
self.factor
* self.model_size ** (-0.5)
* min(step ** (-0.5), step * self.warmup ** (-1.5))
)
def zero_grad(self):
"""Reset gradient."""
self.optimizer.zero_grad()
def state_dict(self):
"""Return state_dict."""
return {
"_step": self._step,
"warmup": self.warmup,
"factor": self.factor,
"model_size": self.model_size,
"_rate": self._rate,
"optimizer": self.optimizer.state_dict(),
}
def load_state_dict(self, state_dict):
"""Load state_dict."""
for key, value in state_dict.items():
if key == "optimizer":
self.optimizer.load_state_dict(state_dict["optimizer"])
else:
setattr(self, key, value)
def encoder_padding_mask(
max_len: int, supervisions: Optional[Supervisions] = None
) -> Optional[torch.Tensor]:
"""Make mask tensor containing indexes of padded part.
TODO::
This function **assumes** that the model uses
a subsampling factor of 4. We should remove that
assumption later.
Args:
max_len:
Maximum length of input features.
CAUTION: It is the length after subsampling.
supervisions:
Supervision in lhotse format.
See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa
(CAUTION: It contains length information, i.e., start and number of
frames, before subsampling)
Returns:
Tensor: Mask tensor of dimension (batch_size, input_length),
True denote the masked indices.
"""
if supervisions is None:
return None
supervision_segments = torch.stack(
(
supervisions["sequence_idx"],
supervisions["start_frame"],
supervisions["num_frames"],
),
1,
).to(torch.int32)
lengths = [
0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)
]
for idx in range(supervision_segments.size(0)):
# Note: TorchScript doesn't allow to unpack tensors as tuples
sequence_idx = supervision_segments[idx, 0].item()
start_frame = supervision_segments[idx, 1].item()
num_frames = supervision_segments[idx, 2].item()
lengths[sequence_idx] = start_frame + num_frames
lengths = [((i - 1) // 2 - 1) // 2 for i in lengths]
bs = int(len(lengths))
seq_range = torch.arange(0, max_len, dtype=torch.int64)
seq_range_expand = seq_range.unsqueeze(0).expand(bs, max_len)
# Note: TorchScript doesn't implement Tensor.new()
seq_length_expand = torch.tensor(
lengths, device=seq_range_expand.device, dtype=seq_range_expand.dtype
).unsqueeze(-1)
mask = seq_range_expand >= seq_length_expand
return mask
def decoder_padding_mask(
ys_pad: torch.Tensor, ignore_id: int = -1
) -> torch.Tensor:
"""Generate a length mask for input.
The masked position are filled with True,
Unmasked positions are filled with False.
Args:
ys_pad:
padded tensor of dimension (batch_size, input_length).
ignore_id:
the ignored number (the padding number) in ys_pad
Returns:
Tensor:
a bool tensor of the same shape as the input tensor.
"""
ys_mask = ys_pad == ignore_id
return ys_mask
def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
"""Generate a square mask for the sequence. The masked positions are
filled with float('-inf'). Unmasked positions are filled with float(0.0).
The mask can be used for masked self-attention.
For instance, if sz is 3, it returns::
tensor([[0., -inf, -inf],
[0., 0., -inf],
[0., 0., 0]])
Args:
sz: mask size
Returns:
A square mask of dimension (sz, sz)
"""
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = (
mask.float()
.masked_fill(mask == 0, float("-inf"))
.masked_fill(mask == 1, float(0.0))
)
return mask
def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]:
"""Prepend sos_id to each utterance.
Args:
token_ids:
A list-of-list of token IDs. Each sublist contains
token IDs (e.g., word piece IDs) of an utterance.
sos_id:
The ID of the SOS token.
Return:
Return a new list-of-list, where each sublist starts
with SOS ID.
"""
return [[sos_id] + utt for utt in token_ids]
def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]:
"""Append eos_id to each utterance.
Args:
token_ids:
A list-of-list of token IDs. Each sublist contains
token IDs (e.g., word piece IDs) of an utterance.
eos_id:
The ID of the EOS token.
Return:
Return a new list-of-list, where each sublist ends
with EOS ID.
"""
return [utt + [eos_id] for utt in token_ids]
def tolist(t: torch.Tensor) -> List[int]:
"""Used by jit"""
return torch.jit.annotate(List[int], t.tolist())