Use k2 pruned RNN-T.

This commit is contained in:
Fangjun Kuang 2022-04-28 14:13:26 +08:00
parent ea969e9b84
commit 026f446a4d
9 changed files with 529 additions and 455 deletions

View File

@ -1 +1 @@
../transducer_stateless/beam_search.py
../pruned_transducer_stateless2/beam_search.py

View File

@ -19,16 +19,16 @@
Usage:
(1) greedy search
./transducer_lstm/decode.py \
--epoch 14 \
--avg 7 \
--epoch 28 \
--avg 15 \
--exp-dir ./transducer_lstm/exp \
--max-duration 100 \
--decoding-method greedy_search
(2) beam search
./transducer_lstm/decode.py \
--epoch 14 \
--avg 7 \
--epoch 28 \
--avg 15 \
--exp-dir ./transducer_lstm/exp \
--max-duration 100 \
--decoding-method beam_search \
@ -36,12 +36,23 @@ Usage:
(3) modified beam search
./transducer_lstm/decode.py \
--epoch 14 \
--avg 7 \
--epoch 28 \
--avg 15 \
--exp-dir ./transducer_lstm/exp \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search
./transducer_lstm/decode.py \
--epoch 28 \
--avg 15 \
--exp-dir ./transducer_lstm/exp \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
"""
@ -49,21 +60,27 @@ import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
import k2
import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from train import get_params, get_transducer_model
from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.checkpoint import (
average_checkpoints,
find_checkpoints,
load_checkpoint,
)
from icefall.utils import (
AttributeDict,
setup_logger,
@ -80,17 +97,29 @@ def get_parser():
parser.add_argument(
"--epoch",
type=int,
default=29,
help="It specifies the checkpoint to use for decoding."
"Note: Epoch counts from 0.",
default=28,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 0.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=13,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch'. ",
"'--epoch' and '--iter'",
)
parser.add_argument(
@ -115,6 +144,7 @@ def get_parser():
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
@ -122,8 +152,35 @@ def get_parser():
"--beam-size",
type=int,
default=4,
help="""An integer indicating how many candidates we will keep for each
frame. Used only when --decoding-method is beam_search or
modified_beam_search.""",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
beam_search or modified_beam_search""",
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
@ -149,6 +206,7 @@ def decode_one_batch(
model: nn.Module,
sp: spm.SentencePieceProcessor,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[str]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
@ -171,6 +229,9 @@ def decode_one_batch(
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
@ -188,24 +249,41 @@ def decode_one_batch(
encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens
)
hyp_list: List[List[int]] = []
hyps = []
if (
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_list = greedy_search_batch(
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
elif params.decoding_method == "modified_beam_search":
hyp_list = modified_beam_search(
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
beam=params.beam_size,
)
for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split())
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@ -226,14 +304,20 @@ def decode_one_batch(
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyp_list.append(hyp)
hyps = [sp.decode(hyp).split() for hyp in hyp_list]
hyps.append(sp.decode(hyp).split())
if params.decoding_method == "greedy_search":
return {"greedy_search": hyps}
elif params.decoding_method == "fast_beam_search":
return {
(
f"beam_{params.beam}_"
f"max_contexts_{params.max_contexts}_"
f"max_states_{params.max_states}"
): hyps
}
else:
return {f"beam_{params.beam_size}": hyps}
return {f"beam_size_{params.beam_size}": hyps}
def decode_dataset(
@ -241,6 +325,7 @@ def decode_dataset(
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
decoding_graph: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[str], List[str]]]]:
"""Decode dataset.
@ -253,6 +338,9 @@ def decode_dataset(
The neural model.
sp:
The BPE model.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
only when --decoding_method is fast_beam_search.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
@ -280,6 +368,7 @@ def decode_dataset(
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
batch=batch,
)
@ -360,13 +449,24 @@ def main():
assert params.decoding_method in (
"greedy_search",
"beam_search",
"fast_beam_search",
"modified_beam_search",
)
params.res_dir = params.exp_dir / params.decoding_method
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam_size}"
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if "fast_beam_search" in params.decoding_method:
params.suffix += f"-beam-{params.beam}"
params.suffix += f"-max-contexts-{params.max_contexts}"
params.suffix += f"-max-states-{params.max_states}"
elif "beam_search" in params.decoding_method:
params.suffix += (
f"-{params.decoding_method}-beam-size-{params.beam_size}"
)
else:
params.suffix += f"-context-{params.context_size}"
params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@ -383,8 +483,9 @@ def main():
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
# <blk> and <unk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(params)
@ -392,7 +493,24 @@ def main():
logging.info("About to create model")
model = get_transducer_model(params)
if params.avg == 1:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
@ -408,6 +526,11 @@ def main():
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
@ -428,6 +551,7 @@ def main():
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
)
save_results(

View File

@ -1,98 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
embedding_dim: int,
blank_id: int,
context_size: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
embedding_dim:
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
padding_idx=blank_id,
)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=embedding_dim,
out_channels=embedding_dim,
kernel_size=context_size,
padding=0,
groups=embedding_dim,
bias=False,
)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
embedding_out = self.embedding(y)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
return embedding_out

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/decoder.py

View File

@ -1,57 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
super().__init__()
self.output_linear = nn.Linear(input_dim, output_dim)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, *unused
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, C).
decoder_out:
Output from the decoder. Its shape is (N, U, C).
Returns:
Return a tensor of shape (N, T, U, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == decoder_out.size(2)
encoder_out = encoder_out.unsqueeze(2)
# Now encoder_out is (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1)
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out
logit = F.relu(logit)
output = self.output_linear(logit)
if not self.training:
output = output.squeeze(2).squeeze(1)
return output

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/joiner.py

View File

@ -1,126 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note we use `rnnt_loss` from torchaudio, which exists only in
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
"""
import k2
import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
Returns:
Return the transducer loss.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out = self.decoder(sos_y_padded)
logits = self.joiner(encoder_out, decoder_out)
# rnnt_loss requires 0 padded targets
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
assert hasattr(torchaudio.functional, "rnnt_loss"), (
f"Current torchaudio version: {torchaudio.__version__}\n"
"Please install a version >= 0.10.0"
)
loss = torchaudio.functional.rnnt_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
)
return loss

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/model.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/optim.py

View File

@ -0,0 +1 @@
../pruned_transducer_stateless2/scaling.py

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python3
#
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -15,33 +14,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
To run this file, do:
cd icefall/egs/librispeech/ASR
python ./transducer_lstm/test_encoder.py
python ./pruned_transducer_stateless4/test_model.py
"""
from encoder import LstmEncoder
from train import get_params, get_transducer_model
def test_encoder():
encoder = LstmEncoder(
num_features=80,
hidden_size=1024,
proj_size=512,
output_dim=512,
subsampling_factor=4,
num_encoder_layers=12,
)
num_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
print(num_params)
# 93979284
# 66427392
def test_model():
params = get_params()
params.vocab_size = 500
params.blank_id = 0
params.context_size = 2
model = get_transducer_model(params)
num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
def main():
test_encoder()
test_model()
if __name__ == "__main__":

View File

@ -16,20 +16,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
export CUDA_VISIBLE_DEVICES="0,1,2"
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./transducer_lstm/train.py \
--world-size 3 \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir transducer_lstm/exp \
--full-libri 1 \
--max-duration 400 \
--lr-factor 3
--max-duration 300
# For mix precision training:
./transducer_lstm/train.py \
--world-size 4 \
--num-epochs 30 \
--start-epoch 0 \
--use-fp16 1 \
--exp-dir transducer_lstm/exp \
--full-libri 1 \
--max-duration 550
"""
@ -38,32 +48,40 @@ import logging
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Optional, Tuple
from typing import Any, Dict, Optional, Tuple, Union
import k2
import optim
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from encoder import LstmEncoder
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
from noam import Noam
from optim import Eden, Eve
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def get_parser():
parser = argparse.ArgumentParser(
@ -104,7 +122,16 @@ def get_parser():
default=0,
help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from
transducer_lstm/exp/epoch-{start_epoch-1}.pt
transducer_stateless2/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--start-batch",
type=int,
default=0,
help="""If positive, --start-epoch is ignored and
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
""",
)
@ -126,10 +153,68 @@ def get_parser():
)
parser.add_argument(
"--lr-factor",
"--initial-lr",
type=float,
default=3.0,
help="The lr_factor for Noam optimizer",
default=0.003,
help="The initial learning rate. This value should not need to be changed.",
)
parser.add_argument(
"--lr-batches",
type=float,
default=5000,
help="""Number of steps that affects how rapidly the learning rate decreases.
We suggest not to change this.""",
)
parser.add_argument(
"--lr-epochs",
type=float,
default=6,
help="""Number of epochs that affects how rapidly the learning rate decreases.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=5,
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.25,
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
parser.add_argument(
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)
parser.add_argument(
"--simple-loss-scale",
type=float,
default=0.5,
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.",
)
parser.add_argument(
@ -140,11 +225,41 @@ def get_parser():
)
parser.add_argument(
"--context-size",
"--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--save-every-n",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
default=8000,
help="""Save checkpoint after processing this number of batches"
periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
""",
)
parser.add_argument(
"--keep-last-k",
type=int,
default=20,
help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `checkpoint-xxx.pt`.
It does not affect checkpoints with name `epoch-xxx.pt`.
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
)
return parser
@ -188,15 +303,10 @@ def get_params() -> AttributeDict:
- subsampling_factor: The subsampling factor for the model.
- use_feat_batchnorm: Whether to do batch normalization for the
input features.
- attention_dim: Hidden dim for multi-head attention model.
- encoder_dim: Hidden dim for multi-head attention model.
- num_decoder_layers: Number of decoder layer of transformer decoder.
- weight_decay: The weight_decay for the optimizer.
- warm_step: The warm_step for Noam optimizer.
"""
params = AttributeDict(
@ -209,21 +319,20 @@ def get_params() -> AttributeDict:
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
# parameters for conformer
# parameters for encoder
"feature_dim": 80,
"encoder_out_dim": 512,
"subsampling_factor": 4,
"encoder_dim": 512,
"encoder_hidden_size": 1024,
"num_encoder_layers": 4,
"proj_size": 512,
"vgg_frontend": False,
# decoder params
"decoder_embedding_dim": 1024,
"num_decoder_layers": 4,
"decoder_hidden_dim": 512,
# parameters for decoder
"decoder_dim": 512,
# parameters for joiner
"joiner_dim": 512,
# parameters for Noam
"weight_decay": 1e-6,
"warm_step": 80000, # For the 100h subset, use 8k
"model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(),
}
)
@ -231,11 +340,11 @@ def get_params() -> AttributeDict:
return params
def get_encoder_model(params: AttributeDict):
def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = LstmEncoder(
num_features=params.feature_dim,
hidden_size=params.encoder_hidden_size,
output_dim=params.encoder_out_dim,
output_dim=params.encoder_dim,
subsampling_factor=params.subsampling_factor,
num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend,
@ -246,22 +355,24 @@ def get_encoder_model(params: AttributeDict):
def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder(
vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim,
decoder_dim=params.decoder_dim,
blank_id=params.blank_id,
context_size=params.context_size,
)
return decoder
def get_joiner_model(params: AttributeDict):
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.encoder_out_dim,
output_dim=params.vocab_size,
encoder_dim=params.encoder_dim,
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
)
return joiner
def get_transducer_model(params: AttributeDict):
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
@ -270,6 +381,10 @@ def get_transducer_model(params: AttributeDict):
encoder=encoder,
decoder=decoder,
joiner=joiner,
encoder_dim=params.encoder_dim,
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
)
return model
@ -278,15 +393,17 @@ def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
) -> None:
scheduler: Optional[LRSchedulerType] = None,
) -> Optional[Dict[str, Any]]:
"""Load checkpoint from file.
If params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`. Otherwise, this function does nothing.
If params.start_batch is positive, it will load the checkpoint from
`params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
params.start_epoch is positive, it will load the checkpoint from
`params.start_epoch - 1`.
Apart from loading state dict for `model`, `optimizer` and `scheduler`,
it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
Apart from loading state dict for `model` and `optimizer` it also updates
`best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
and `best_valid_loss` in `params`.
Args:
@ -297,14 +414,19 @@ def load_checkpoint_if_available(
optimizer:
The optimizer that we are using.
scheduler:
The learning rate scheduler we are using.
The scheduler that we are using.
Returns:
Return None.
Return a dict containing previously saved training info.
"""
if params.start_epoch <= 0:
return
if params.start_batch > 0:
filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
elif params.start_epoch > 0:
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
else:
return None
assert filename.is_file(), f"{filename} does not exist!"
filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
saved_params = load_checkpoint(
filename,
model=model,
@ -322,6 +444,13 @@ def load_checkpoint_if_available(
for k in keys:
params[k] = saved_params[k]
if params.start_batch > 0:
if "cur_epoch" in saved_params:
params["start_epoch"] = saved_params["cur_epoch"]
if "cur_batch_idx" in saved_params:
params["cur_batch_idx"] = saved_params["cur_batch_idx"]
return saved_params
@ -329,7 +458,9 @@ def save_checkpoint(
params: AttributeDict,
model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
rank: int = 0,
) -> None:
"""Save model, optimizer, scheduler and training stats to file.
@ -339,6 +470,12 @@ def save_checkpoint(
It is returned by :func:`get_params`.
model:
The training model.
optimizer:
The optimizer used in the training.
sampler:
The sampler for the training dataset.
scaler:
The scaler used for mix precision training.
"""
if rank != 0:
return
@ -349,6 +486,8 @@ def save_checkpoint(
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=sampler,
scaler=scaler,
rank=rank,
)
@ -367,6 +506,7 @@ def compute_loss(
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
warmup: float = 1.0,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
@ -383,6 +523,8 @@ def compute_loss(
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
device = model.device
feature = batch["inputs"]
@ -398,21 +540,42 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=y)
simple_loss, pruned_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
)
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
# overwhelming the simple_loss and causing it to diverge,
# in case it had not fully learned the alignment yet.
pruned_loss_scale = (
0.0
if warmup < 1.0
else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
)
loss = (
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
assert loss.requires_grad == is_training
info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
info["frames"] = (
(feature_lens // params.subsampling_factor).sum().item()
)
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
return loss, info
@ -455,11 +618,14 @@ def train_one_epoch(
params: AttributeDict,
model: nn.Module,
optimizer: torch.optim.Optimizer,
scheduler: LRSchedulerType,
sp: spm.SentencePieceProcessor,
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
rank: int = 0,
) -> None:
"""Train the model for one epoch.
@ -474,51 +640,96 @@ def train_one_epoch(
The model for training.
optimizer:
The optimizer we are using.
scheduler:
The learning rate scheduler, we call step() every step.
train_dl:
Dataloader for the training dataset.
valid_dl:
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
tb_writer:
Writer to write log messages to tensorboard.
world_size:
Number of nodes in DDP training. If it is 1, DDP is disabled.
rank:
The rank of the node in DDP training. If no DDP is used, it should
be set to 0.
"""
model.train()
tot_loss = MetricsTracker()
cur_batch_idx = params.get("cur_batch_idx", 0)
for batch_idx, batch in enumerate(train_dl):
if batch_idx < cur_batch_idx:
continue
cur_batch_idx = batch_idx
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
warmup=(params.batch_idx_train / params.model_warm_step),
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
scheduler.step_batch(params.batch_idx_train)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}"
if params.print_diagnostics and batch_idx == 5:
return
if (
params.batch_idx_train > 0
and params.batch_idx_train % params.save_every_n == 0
):
params.cur_batch_idx = batch_idx
save_checkpoint_with_global_batch_idx(
out_dir=params.exp_dir,
global_batch_idx=params.batch_idx_train,
model=model,
params=params,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
del params.cur_batch_idx
remove_checkpoints(
out_dir=params.exp_dir,
topk=params.keep_last_k,
rank=rank,
)
if batch_idx % params.log_interval == 0:
cur_lr = scheduler.get_last_lr()[0]
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}"
)
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
loss_info.write_summary(
tb_writer, "train/current_", params.batch_idx_train
)
@ -564,8 +775,7 @@ def run(rank, world_size, args):
params = get_params()
params.update(vars(args))
if params.full_libri is False:
params.valid_interval = 800
params.warm_step = 8000
params.valid_interval = 1600
fix_random_seed(params.seed)
if world_size > 1:
@ -596,29 +806,39 @@ def run(rank, world_size, args):
logging.info("About to create model")
model = get_transducer_model(params)
checkpoints = load_checkpoint_if_available(params=params, model=model)
num_param = sum([p.numel() for p in model.parameters() if p.requires_grad])
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device)
if world_size > 1:
logging.info("Using DDP")
model = DDP(model, device_ids=[rank])
model.device = device
optimizer = Noam(
model.parameters(),
model_size=params.encoder_hidden_size,
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
optimizer = Eve(model.parameters(), lr=params.initial_lr)
scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"])
if (
checkpoints
and "scheduler" in checkpoints
and checkpoints["scheduler"] is not None
):
logging.info("Loading scheduler state dict")
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts()
@ -628,75 +848,81 @@ def run(rank, world_size, args):
def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0
num_in_total = len(train_cuts)
train_cuts = train_cuts.filter(remove_short_and_long_utt)
try:
num_left = len(train_cuts)
num_removed = num_in_total - num_left
removed_percent = num_removed / num_in_total * 100
if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
# We only load the sampler's state dict when it loads a checkpoint
# saved in the middle of an epoch
sampler_state_dict = checkpoints["sampler"]
else:
sampler_state_dict = None
logging.info(
f"Before removing short and long utterances: {num_in_total}"
)
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(
f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
)
except TypeError as e:
# You can ignore this error as previous versions of Lhotse work fine
# for the above code. In recent versions of Lhotse, it uses
# lazy filter, producing cutsets that don't have the __len__ method
logging.info(str(e))
train_dl = librispeech.train_dataloaders(train_cuts)
train_dl = librispeech.train_dataloaders(
train_cuts, sampler_state_dict=sampler_state_dict
)
valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
if not params.print_diagnostics:
scan_pessimistic_batches_for_oom(
model=model,
train_dl=train_dl,
optimizer=optimizer,
sp=sp,
params=params,
)
scaler = GradScaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs):
scheduler.step_epoch(epoch)
fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch
train_one_epoch(
params=params,
model=model,
optimizer=optimizer,
scheduler=scheduler,
sp=sp,
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
)
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint(
params=params,
model=model,
optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank,
)
@ -723,17 +949,21 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
optimizer.zero_grad()
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
)
# warmup = 0.0 is so that the derivs for the pruned loss stay zero
# (i.e. are not remembered by the decaying-average in adam), because
# we want to avoid these params being subject to shrinkage in adam.
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, _ = compute_loss(
params=params,
model=model,
sp=sp,
batch=batch,
is_training=True,
warmup=0.0,
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)
optimizer.step()
optimizer.zero_grad()
except RuntimeError as e:
if "CUDA out of memory" in str(e):
logging.error(