Use k2 pruned RNN-T.

Fangjun Kuang 2022-04-28 14:13:26 +08:00
parent ea969e9b84
commit 026f446a4d
9 changed files with 529 additions and 455 deletions

View File

@@ -1 +1 @@
- ../transducer_stateless/beam_search.py
+ ../pruned_transducer_stateless2/beam_search.py

View File

@@ -19,16 +19,16 @@
  Usage:
  (1) greedy search
  ./transducer_lstm/decode.py \
- --epoch 14 \
- --avg 7 \
+ --epoch 28 \
+ --avg 15 \
  --exp-dir ./transducer_lstm/exp \
  --max-duration 100 \
  --decoding-method greedy_search
  (2) beam search
  ./transducer_lstm/decode.py \
- --epoch 14 \
- --avg 7 \
+ --epoch 28 \
+ --avg 15 \
  --exp-dir ./transducer_lstm/exp \
  --max-duration 100 \
  --decoding-method beam_search \
@@ -36,12 +36,23 @@ Usage:
  (3) modified beam search
  ./transducer_lstm/decode.py \
- --epoch 14 \
- --avg 7 \
+ --epoch 28 \
+ --avg 15 \
  --exp-dir ./transducer_lstm/exp \
  --max-duration 100 \
  --decoding-method modified_beam_search \
  --beam-size 4
+ (4) fast beam search
+ ./transducer_lstm/decode.py \
+ --epoch 28 \
+ --avg 15 \
+ --exp-dir ./transducer_lstm/exp \
+ --max-duration 1500 \
+ --decoding-method fast_beam_search \
+ --beam 4 \
+ --max-contexts 4 \
+ --max-states 8
  """
@@ -49,21 +60,27 @@ import argparse
  import logging
  from collections import defaultdict
  from pathlib import Path
- from typing import Dict, List, Tuple
+ from typing import Dict, List, Optional, Tuple
+ import k2
  import sentencepiece as spm
  import torch
  import torch.nn as nn
  from asr_datamodule import LibriSpeechAsrDataModule
  from beam_search import (
  beam_search,
+ fast_beam_search,
  greedy_search,
  greedy_search_batch,
  modified_beam_search,
  )
  from train import get_params, get_transducer_model
- from icefall.checkpoint import average_checkpoints, load_checkpoint
+ from icefall.checkpoint import (
+ average_checkpoints,
+ find_checkpoints,
+ load_checkpoint,
+ )
  from icefall.utils import (
  AttributeDict,
  setup_logger,
@@ -80,17 +97,29 @@ def get_parser():
  parser.add_argument(
  "--epoch",
  type=int,
- default=29,
- help="It specifies the checkpoint to use for decoding."
- "Note: Epoch counts from 0.",
+ default=28,
+ help="""It specifies the checkpoint to use for decoding.
+ Note: Epoch counts from 0.
+ You can specify --avg to use more checkpoints for model averaging.""",
  )
+ parser.add_argument(
+ "--iter",
+ type=int,
+ default=0,
+ help="""If positive, --epoch is ignored and it
+ will use the checkpoint exp_dir/checkpoint-iter.pt.
+ You can specify --avg to use more checkpoints for model averaging.
+ """,
+ )
  parser.add_argument(
  "--avg",
  type=int,
- default=13,
+ default=15,
  help="Number of checkpoints to average. Automatically select "
  "consecutive checkpoints before the checkpoint specified by "
- "'--epoch'. ",
+ "'--epoch' and '--iter'",
  )
  parser.add_argument(
@@ -115,6 +144,7 @@ def get_parser():
  - greedy_search
  - beam_search
  - modified_beam_search
+ - fast_beam_search
  """,
  )
@@ -122,8 +152,35 @@ def get_parser():
  "--beam-size",
  type=int,
  default=4,
- help="""Used only when --decoding-method is
- beam_search or modified_beam_search""",
+ help="""An integer indicating how many candidates we will keep for each
+ frame. Used only when --decoding-method is beam_search or
+ modified_beam_search.""",
+ )
+ parser.add_argument(
+ "--beam",
+ type=float,
+ default=4,
+ help="""A floating point value to calculate the cutoff score during beam
+ search (i.e., `cutoff = max-score - beam`), which is the same as the
+ `beam` in Kaldi.
+ Used only when --decoding-method is fast_beam_search""",
+ )
+ parser.add_argument(
+ "--max-contexts",
+ type=int,
+ default=4,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
+ )
+ parser.add_argument(
+ "--max-states",
+ type=int,
+ default=8,
+ help="""Used only when --decoding-method is
+ fast_beam_search""",
  )
  parser.add_argument(
@@ -149,6 +206,7 @@ def decode_one_batch(
  model: nn.Module,
  sp: spm.SentencePieceProcessor,
  batch: dict,
+ decoding_graph: Optional[k2.Fsa] = None,
  ) -> Dict[str, List[List[str]]]:
  """Decode one batch and return the result in a dict. The dict has the
  following format:
@@ -171,6 +229,9 @@ def decode_one_batch(
  It is the return value from iterating
  `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
  for the format of the `batch`.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search.
  Returns:
  Return the decoding result. See above description for the format of
  the returned dict.
@@ -188,24 +249,41 @@ def decode_one_batch(
  encoder_out, encoder_out_lens = model.encoder(
  x=feature, x_lens=feature_lens
  )
- hyp_list: List[List[int]] = []
- if (
+ hyps = []
+ if params.decoding_method == "fast_beam_search":
+ hyp_tokens = fast_beam_search(
+ model=model,
+ decoding_graph=decoding_graph,
+ encoder_out=encoder_out,
+ encoder_out_lens=encoder_out_lens,
+ beam=params.beam,
+ max_contexts=params.max_contexts,
+ max_states=params.max_states,
+ )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
+ elif (
  params.decoding_method == "greedy_search"
  and params.max_sym_per_frame == 1
  ):
- hyp_list = greedy_search_batch(
+ hyp_tokens = greedy_search_batch(
  model=model,
  encoder_out=encoder_out,
  )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
  elif params.decoding_method == "modified_beam_search":
- hyp_list = modified_beam_search(
+ hyp_tokens = modified_beam_search(
  model=model,
  encoder_out=encoder_out,
  beam=params.beam_size,
  )
+ for hyp in sp.decode(hyp_tokens):
+ hyps.append(hyp.split())
  else:
  batch_size = encoder_out.size(0)
  for i in range(batch_size):
  # fmt: off
  encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
@@ -226,14 +304,20 @@ def decode_one_batch(
  raise ValueError(
  f"Unsupported decoding method: {params.decoding_method}"
  )
- hyp_list.append(hyp)
- hyps = [sp.decode(hyp).split() for hyp in hyp_list]
+ hyps.append(sp.decode(hyp).split())
  if params.decoding_method == "greedy_search":
  return {"greedy_search": hyps}
+ elif params.decoding_method == "fast_beam_search":
+ return {
+ (
+ f"beam_{params.beam}_"
+ f"max_contexts_{params.max_contexts}_"
+ f"max_states_{params.max_states}"
+ ): hyps
+ }
  else:
- return {f"beam_{params.beam_size}": hyps}
+ return {f"beam_size_{params.beam_size}": hyps}
  def decode_dataset(
@@ -241,6 +325,7 @@ def decode_dataset(
  params: AttributeDict,
  model: nn.Module,
  sp: spm.SentencePieceProcessor,
+ decoding_graph: Optional[k2.Fsa] = None,
  ) -> Dict[str, List[Tuple[List[str], List[str]]]]:
  """Decode dataset.
@@ -253,6 +338,9 @@ def decode_dataset(
  The neural model.
  sp:
  The BPE model.
+ decoding_graph:
+ The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used
+ only when --decoding_method is fast_beam_search.
  Returns:
  Return a dict, whose key may be "greedy_search" if greedy search
  is used, or it may be "beam_7" if beam size of 7 is used.
@@ -280,6 +368,7 @@ def decode_dataset(
  params=params,
  model=model,
  sp=sp,
+ decoding_graph=decoding_graph,
  batch=batch,
  )
@@ -360,13 +449,24 @@ def main():
  assert params.decoding_method in (
  "greedy_search",
  "beam_search",
+ "fast_beam_search",
  "modified_beam_search",
  )
  params.res_dir = params.exp_dir / params.decoding_method
- params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
- if "beam_search" in params.decoding_method:
- params.suffix += f"-beam-{params.beam_size}"
+ if params.iter > 0:
+ params.suffix = f"iter-{params.iter}-avg-{params.avg}"
+ else:
+ params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
+ if "fast_beam_search" in params.decoding_method:
+ params.suffix += f"-beam-{params.beam}"
+ params.suffix += f"-max-contexts-{params.max_contexts}"
+ params.suffix += f"-max-states-{params.max_states}"
+ elif "beam_search" in params.decoding_method:
+ params.suffix += (
+ f"-{params.decoding_method}-beam-size-{params.beam_size}"
+ )
  else:
  params.suffix += f"-context-{params.context_size}"
  params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
@@ -383,8 +483,9 @@ def main():
  sp = spm.SentencePieceProcessor()
  sp.load(params.bpe_model)
- # <blk> is defined in local/train_bpe_model.py
+ # <blk> and <unk> is defined in local/train_bpe_model.py
  params.blank_id = sp.piece_to_id("<blk>")
+ params.unk_id = sp.piece_to_id("<unk>")
  params.vocab_size = sp.get_piece_size()
  logging.info(params)
@@ -392,7 +493,24 @@ def main():
  logging.info("About to create model")
  model = get_transducer_model(params)
- if params.avg == 1:
+ if params.iter > 0:
+ filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
+ : params.avg
+ ]
+ if len(filenames) == 0:
+ raise ValueError(
+ f"No checkpoints found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ elif len(filenames) < params.avg:
+ raise ValueError(
+ f"Not enough checkpoints ({len(filenames)}) found for"
+ f" --iter {params.iter}, --avg {params.avg}"
+ )
+ logging.info(f"averaging {filenames}")
+ model.to(device)
+ model.load_state_dict(average_checkpoints(filenames, device=device))
+ elif params.avg == 1:
  load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
  else:
  start = params.epoch - params.avg + 1
@@ -408,6 +526,11 @@ def main():
  model.eval()
  model.device = device
+ if params.decoding_method == "fast_beam_search":
+ decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
+ else:
+ decoding_graph = None
  num_param = sum([p.numel() for p in model.parameters()])
  logging.info(f"Number of model parameters: {num_param}")
@@ -428,6 +551,7 @@ def main():
  params=params,
  model=model,
  sp=sp,
+ decoding_graph=decoding_graph,
  )
  save_results(

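For orientation, here is a minimal sketch of how the pieces added to decode.py above fit together at decoding time. The helper name fast_beam_search_one_batch and its defaults are mine (they simply mirror the new command-line arguments); only the k2.trivial_graph construction and the fast_beam_search call signature come from the diff itself.

from typing import List

import k2
import torch


def fast_beam_search_one_batch(
    model: torch.nn.Module,
    sp,  # spm.SentencePieceProcessor
    encoder_out: torch.Tensor,  # (N, T, C), returned by model.encoder()
    encoder_out_lens: torch.Tensor,  # (N,)
    vocab_size: int,
    beam: float = 4.0,
    max_contexts: int = 4,
    max_states: int = 8,
) -> List[List[str]]:
    """Decode one batch with fast_beam_search, mirroring decode_one_batch above."""
    from beam_search import fast_beam_search  # provided by the recipe

    device = encoder_out.device
    # The trivial graph accepts any token sequence; the largest symbol id is
    # vocab_size - 1, exactly as in main() above.
    decoding_graph = k2.trivial_graph(vocab_size - 1, device=device)

    hyp_tokens = fast_beam_search(
        model=model,
        decoding_graph=decoding_graph,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        beam=beam,  # score cutoff: cutoff = max_score - beam
        max_contexts=max_contexts,
        max_states=max_states,
    )
    # sp.decode() maps token ids back to text; split into words for scoring.
    return [hyp.split() for hyp in sp.decode(hyp_tokens)]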
View File

@@ -1,98 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Decoder(nn.Module):
"""This class modifies the stateless decoder from the following paper:
RNN-transducer with stateless prediction network
https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=9054419
It removes the recurrent connection from the decoder, i.e., the prediction
network. Different from the above paper, it adds an extra Conv1d
right after the embedding layer.
TODO: Implement https://arxiv.org/pdf/2109.07513.pdf
"""
def __init__(
self,
vocab_size: int,
embedding_dim: int,
blank_id: int,
context_size: int,
):
"""
Args:
vocab_size:
Number of tokens of the modeling unit including blank.
embedding_dim:
Dimension of the input embedding.
blank_id:
The ID of the blank symbol.
context_size:
Number of previous words to use to predict the next word.
1 means bigram; 2 means trigram. n means (n+1)-gram.
"""
super().__init__()
self.embedding = nn.Embedding(
num_embeddings=vocab_size,
embedding_dim=embedding_dim,
padding_idx=blank_id,
)
self.blank_id = blank_id
assert context_size >= 1, context_size
self.context_size = context_size
if context_size > 1:
self.conv = nn.Conv1d(
in_channels=embedding_dim,
out_channels=embedding_dim,
kernel_size=context_size,
padding=0,
groups=embedding_dim,
bias=False,
)
def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
"""
Args:
y:
A 2-D tensor of shape (N, U).
need_pad:
True to left pad the input. Should be True during training.
False to not pad the input. Should be False during inference.
Returns:
Return a tensor of shape (N, U, embedding_dim).
"""
embedding_out = self.embedding(y)
if self.context_size > 1:
embedding_out = embedding_out.permute(0, 2, 1)
if need_pad is True:
embedding_out = F.pad(
embedding_out, pad=(self.context_size - 1, 0)
)
else:
# During inference time, there is no need to do extra padding
# as we only need one output
assert embedding_out.size(-1) == self.context_size
embedding_out = self.conv(embedding_out)
embedding_out = embedding_out.permute(0, 2, 1)
return embedding_out
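As a quick sanity check of the stateless decoder removed above (its replacement now comes in via the symlink below, with a different constructor), this sketch exercises the embedding-plus-causal-Conv1d idea with arbitrary sizes; it assumes the Decoder class defined above is in scope.

import torch

# Arbitrary sizes for illustration only.
vocab_size, embedding_dim, blank_id, context_size = 500, 512, 0, 2
decoder = Decoder(  # the class shown above
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    blank_id=blank_id,
    context_size=context_size,
)

# Training: the input is left-padded by context_size - 1, so every output
# position sees only the current token and the previous one.
y = torch.randint(low=1, high=vocab_size, size=(8, 10))  # (N, U)
assert decoder(y, need_pad=True).shape == (8, 10, embedding_dim)

# Inference: feed exactly context_size tokens and get a single output frame.
y_last = torch.randint(low=1, high=vocab_size, size=(8, context_size))
assert decoder(y_last, need_pad=False).shape == (8, 1, embedding_dim)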

View File

@@ -0,0 +1 @@
+ ../pruned_transducer_stateless2/decoder.py

View File

@@ -1,57 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
class Joiner(nn.Module):
def __init__(self, input_dim: int, output_dim: int):
super().__init__()
self.output_linear = nn.Linear(input_dim, output_dim)
def forward(
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor, *unused
) -> torch.Tensor:
"""
Args:
encoder_out:
Output from the encoder. Its shape is (N, T, C).
decoder_out:
Output from the decoder. Its shape is (N, U, C).
Returns:
Return a tensor of shape (N, T, U, C).
"""
assert encoder_out.ndim == decoder_out.ndim == 3
assert encoder_out.size(0) == decoder_out.size(0)
assert encoder_out.size(2) == decoder_out.size(2)
encoder_out = encoder_out.unsqueeze(2)
# Now encoder_out is (N, T, 1, C)
decoder_out = decoder_out.unsqueeze(1)
# Now decoder_out is (N, 1, U, C)
logit = encoder_out + decoder_out
logit = F.relu(logit)
output = self.output_linear(logit)
if not self.training:
output = output.squeeze(2).squeeze(1)
return output
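The broadcasting in the removed joiner above is easy to get wrong, so here is a small shape check with arbitrary sizes; it assumes the Joiner class defined above is in scope (the replacement in pruned_transducer_stateless2 has a different constructor).

import torch

N, T, U, C, vocab = 2, 50, 10, 512, 500
joiner = Joiner(input_dim=C, output_dim=vocab)  # the class shown above
joiner.train()  # in eval mode the singleton dims are squeezed away

encoder_out = torch.randn(N, T, C)  # (N, T, C)
decoder_out = torch.randn(N, U, C)  # (N, U, C)

# (N, T, 1, C) + (N, 1, U, C) broadcasts to (N, T, U, C) before ReLU + Linear.
logits = joiner(encoder_out, decoder_out)
assert logits.shape == (N, T, U, vocab)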

View File

@@ -0,0 +1 @@
+ ../pruned_transducer_stateless2/joiner.py

View File

@@ -1,126 +0,0 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note we use `rnnt_loss` from torchaudio, which exists only in
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
"""
import k2
import torch
import torch.nn as nn
import torchaudio
import torchaudio.functional
from encoder_interface import EncoderInterface
from icefall.utils import add_sos
class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
"Sequence Transduction with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
decoder: nn.Module,
joiner: nn.Module,
):
"""
Args:
encoder:
It is the transcription network in the paper. Its accepts
two inputs: `x` of (N, T, C) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, C) and
`logit_lens` of shape (N,).
decoder:
It is the prediction network in the paper. Its input shape
is (N, U) and its output shape is (N, U, C). It should contain
one attribute: `blank_id`.
joiner:
It has two inputs with shapes: (N, T, C) and (N, U, C). Its
output shape is (N, T, U, C). Note that its output contains
unnormalized probs, i.e., not processed by log-softmax.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface)
assert hasattr(decoder, "blank_id")
self.encoder = encoder
self.decoder = decoder
self.joiner = joiner
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
) -> torch.Tensor:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with 2 axes [utt][label]. It contains labels of each
utterance.
Returns:
Return the transducer loss.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
blank_id = self.decoder.blank_id
sos_y = add_sos(y, sos_id=blank_id)
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
sos_y_padded = sos_y_padded.to(torch.int64)
decoder_out = self.decoder(sos_y_padded)
logits = self.joiner(encoder_out, decoder_out)
# rnnt_loss requires 0 padded targets
# Note: y does not start with SOS
y_padded = y.pad(mode="constant", padding_value=0)
assert hasattr(torchaudio.functional, "rnnt_loss"), (
f"Current torchaudio version: {torchaudio.__version__}\n"
"Please install a version >= 0.10.0"
)
loss = torchaudio.functional.rnnt_loss(
logits=logits,
targets=y_padded,
logit_lengths=x_lens,
target_lengths=y_lens,
blank=blank_id,
reduction="sum",
)
return loss
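The note at the top of the removed model.py is the key constraint: rnnt_loss comes from torchaudio >= 0.10. A tiny standalone check of that call, with random tensors and arbitrary sizes, looks like this (shapes follow the torchaudio documentation: logits are (N, T, U+1, V) and targets exclude the blank):

import torch
import torchaudio.functional

N, T, U, V, blank_id = 2, 20, 5, 10, 0
logits = torch.randn(N, T, U + 1, V, requires_grad=True)  # raw joiner output
targets = torch.randint(1, V, (N, U), dtype=torch.int32)  # no blank/SOS
logit_lengths = torch.full((N,), T, dtype=torch.int32)
target_lengths = torch.full((N,), U, dtype=torch.int32)

loss = torchaudio.functional.rnnt_loss(
    logits=logits,
    targets=targets,
    logit_lengths=logit_lengths,
    target_lengths=target_lengths,
    blank=blank_id,
    reduction="sum",
)
loss.backward()
print(loss.item())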

View File

@@ -0,0 +1 @@
+ ../pruned_transducer_stateless2/model.py

View File

@@ -0,0 +1 @@
+ ../pruned_transducer_stateless2/optim.py

View File

@@ -0,0 +1 @@
+ ../pruned_transducer_stateless2/scaling.py

View File

@@ -1,6 +1,5 @@
  #!/usr/bin/env python3
- #
- # Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+ # Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
  #
  # See ../../../../LICENSE for clarification regarding multiple authors
  #
@@ -15,33 +14,30 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
  """
  To run this file, do:
  cd icefall/egs/librispeech/ASR
- python ./transducer_lstm/test_encoder.py
+ python ./pruned_transducer_stateless4/test_model.py
  """
- from encoder import LstmEncoder
- def test_encoder():
- encoder = LstmEncoder(
- num_features=80,
- hidden_size=1024,
- proj_size=512,
- output_dim=512,
- subsampling_factor=4,
- num_encoder_layers=12,
- )
- num_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
- print(num_params)
- # 93979284
- # 66427392
+ from train import get_params, get_transducer_model
+ def test_model():
+ params = get_params()
+ params.vocab_size = 500
+ params.blank_id = 0
+ params.context_size = 2
+ model = get_transducer_model(params)
+ num_param = sum([p.numel() for p in model.parameters()])
+ print(f"Number of model parameters: {num_param}")
  def main():
- test_encoder()
+ test_model()
  if __name__ == "__main__":

View File

@@ -16,20 +16,30 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
  """
  Usage:
- export CUDA_VISIBLE_DEVICES="0,1,2"
+ export CUDA_VISIBLE_DEVICES="0,1,2,3"
  ./transducer_lstm/train.py \
- --world-size 3 \
+ --world-size 4 \
  --num-epochs 30 \
  --start-epoch 0 \
  --exp-dir transducer_lstm/exp \
  --full-libri 1 \
- --max-duration 400 \
- --lr-factor 3
+ --max-duration 300
+ # For mix precision training:
+ ./transducer_lstm/train.py \
+ --world-size 4 \
+ --num-epochs 30 \
+ --start-epoch 0 \
+ --use-fp16 1 \
+ --exp-dir transducer_lstm/exp \
+ --full-libri 1 \
+ --max-duration 550
  """
@ -38,32 +48,40 @@ import logging
import warnings import warnings
from pathlib import Path from pathlib import Path
from shutil import copyfile from shutil import copyfile
from typing import Optional, Tuple from typing import Any, Dict, Optional, Tuple, Union
import k2 import k2
import optim
import sentencepiece as spm import sentencepiece as spm
import torch import torch
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from decoder import Decoder
from encoder import LstmEncoder from encoder import LstmEncoder
from decoder import Decoder
from joiner import Joiner from joiner import Joiner
from lhotse.cut import Cut from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from model import Transducer from model import Transducer
from noam import Noam from optim import Eden, Eve
from torch import Tensor from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from icefall.checkpoint import load_checkpoint from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import save_checkpoint_with_global_batch_idx
from icefall.dist import cleanup_dist, setup_dist from icefall.dist import cleanup_dist, setup_dist
from icefall.env import get_env_info from icefall.env import get_env_info
from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool
LRSchedulerType = Union[
torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler
]
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -104,7 +122,16 @@ def get_parser():
default=0, default=0,
help="""Resume training from from this epoch. help="""Resume training from from this epoch.
If it is positive, it will load checkpoint from If it is positive, it will load checkpoint from
transducer_lstm/exp/epoch-{start_epoch-1}.pt transducer_stateless2/exp/epoch-{start_epoch-1}.pt
""",
)
parser.add_argument(
"--start-batch",
type=int,
default=0,
help="""If positive, --start-epoch is ignored and
it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
""", """,
) )
@ -126,10 +153,68 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--lr-factor", "--initial-lr",
type=float, type=float,
default=3.0, default=0.003,
help="The lr_factor for Noam optimizer", help="The initial learning rate. This value should not need to be changed.",
)
parser.add_argument(
"--lr-batches",
type=float,
default=5000,
help="""Number of steps that affects how rapidly the learning rate decreases.
We suggest not to change this.""",
)
parser.add_argument(
"--lr-epochs",
type=float,
default=6,
help="""Number of epochs that affects how rapidly the learning rate decreases.
""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--prune-range",
type=int,
default=5,
help="The prune range for rnnt loss, it means how many symbols(context)"
"we are using to compute the loss",
)
parser.add_argument(
"--lm-scale",
type=float,
default=0.25,
help="The scale to smooth the loss with lm "
"(output of prediction network) part.",
)
parser.add_argument(
"--am-scale",
type=float,
default=0.0,
help="The scale to smooth the loss with am (output of encoder network)"
"part.",
)
parser.add_argument(
"--simple-loss-scale",
type=float,
default=0.5,
help="To get pruning ranges, we will calculate a simple version"
"loss(joiner is just addition), this simple loss also uses for"
"training (as a regularization item). We will scale the simple loss"
"with this parameter before adding to the final loss.",
) )
parser.add_argument( parser.add_argument(
@ -140,11 +225,41 @@ def get_parser():
) )
parser.add_argument( parser.add_argument(
"--context-size", "--print-diagnostics",
type=str2bool,
default=False,
help="Accumulate stats on activations, print them and exit.",
)
parser.add_argument(
"--save-every-n",
type=int, type=int,
default=2, default=8000,
help="The context size in the decoder. 1 means bigram; " help="""Save checkpoint after processing this number of batches"
"2 means tri-gram", periodically. We save checkpoint to exp-dir/ whenever
params.batch_idx_train % save_every_n == 0. The checkpoint filename
has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt'
Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the
end of each epoch where `xxx` is the epoch number counting from 0.
""",
)
parser.add_argument(
"--keep-last-k",
type=int,
default=20,
help="""Only keep this number of checkpoints on disk.
For instance, if it is 3, there are only 3 checkpoints
in the exp-dir with filenames `checkpoint-xxx.pt`.
It does not affect checkpoints with name `epoch-xxx.pt`.
""",
)
parser.add_argument(
"--use-fp16",
type=str2bool,
default=False,
help="Whether to use half precision training.",
) )
return parser return parser
@@ -188,15 +303,10 @@ def get_params() -> AttributeDict:
  - subsampling_factor: The subsampling factor for the model.
- - use_feat_batchnorm: Whether to do batch normalization for the
- input features.
- - attention_dim: Hidden dim for multi-head attention model.
+ - encoder_dim: Hidden dim for multi-head attention model.
  - num_decoder_layers: Number of decoder layer of transformer decoder.
- - weight_decay: The weight_decay for the optimizer.
  - warm_step: The warm_step for Noam optimizer.
  """
  params = AttributeDict(
@@ -209,21 +319,20 @@ def get_params() -> AttributeDict:
  "log_interval": 50,
  "reset_interval": 200,
  "valid_interval": 3000,  # For the 100h subset, use 800
- # parameters for conformer
+ # parameters for encoder
  "feature_dim": 80,
- "encoder_out_dim": 512,
  "subsampling_factor": 4,
+ "encoder_dim": 512,
  "encoder_hidden_size": 1024,
  "num_encoder_layers": 4,
  "proj_size": 512,
  "vgg_frontend": False,
- # decoder params
- "decoder_embedding_dim": 1024,
- "num_decoder_layers": 4,
- "decoder_hidden_dim": 512,
+ # parameters for decoder
+ "decoder_dim": 512,
+ # parameters for joiner
+ "joiner_dim": 512,
  # parameters for Noam
- "weight_decay": 1e-6,
- "warm_step": 80000,  # For the 100h subset, use 8k
+ "model_warm_step": 3000,  # arg given to model, not for lrate
  "env_info": get_env_info(),
  }
  )
@ -231,11 +340,11 @@ def get_params() -> AttributeDict:
return params return params
def get_encoder_model(params: AttributeDict): def get_encoder_model(params: AttributeDict) -> nn.Module:
encoder = LstmEncoder( encoder = LstmEncoder(
num_features=params.feature_dim, num_features=params.feature_dim,
hidden_size=params.encoder_hidden_size, hidden_size=params.encoder_hidden_size,
output_dim=params.encoder_out_dim, output_dim=params.encoder_dim,
subsampling_factor=params.subsampling_factor, subsampling_factor=params.subsampling_factor,
num_encoder_layers=params.num_encoder_layers, num_encoder_layers=params.num_encoder_layers,
vgg_frontend=params.vgg_frontend, vgg_frontend=params.vgg_frontend,
@ -246,22 +355,24 @@ def get_encoder_model(params: AttributeDict):
def get_decoder_model(params: AttributeDict) -> nn.Module: def get_decoder_model(params: AttributeDict) -> nn.Module:
decoder = Decoder( decoder = Decoder(
vocab_size=params.vocab_size, vocab_size=params.vocab_size,
embedding_dim=params.encoder_out_dim, decoder_dim=params.decoder_dim,
blank_id=params.blank_id, blank_id=params.blank_id,
context_size=params.context_size, context_size=params.context_size,
) )
return decoder return decoder
def get_joiner_model(params: AttributeDict): def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner( joiner = Joiner(
input_dim=params.encoder_out_dim, encoder_dim=params.encoder_dim,
output_dim=params.vocab_size, decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
) )
return joiner return joiner
def get_transducer_model(params: AttributeDict): def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params) encoder = get_encoder_model(params)
decoder = get_decoder_model(params) decoder = get_decoder_model(params)
joiner = get_joiner_model(params) joiner = get_joiner_model(params)
@ -270,6 +381,10 @@ def get_transducer_model(params: AttributeDict):
encoder=encoder, encoder=encoder,
decoder=decoder, decoder=decoder,
joiner=joiner, joiner=joiner,
encoder_dim=params.encoder_dim,
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
) )
return model return model
@@ -278,15 +393,17 @@ def load_checkpoint_if_available(
  params: AttributeDict,
  model: nn.Module,
  optimizer: Optional[torch.optim.Optimizer] = None,
- scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
- ) -> None:
+ scheduler: Optional[LRSchedulerType] = None,
+ ) -> Optional[Dict[str, Any]]:
  """Load checkpoint from file.
- If params.start_epoch is positive, it will load the checkpoint from
- `params.start_epoch - 1`. Otherwise, this function does nothing.
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is positive, it will load the checkpoint from
+ `params.start_epoch - 1`.
- Apart from loading state dict for `model`, `optimizer` and `scheduler`,
- it also updates `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
  and `best_valid_loss` in `params`.
  Args:
@@ -297,14 +414,19 @@ def load_checkpoint_if_available(
  optimizer:
  The optimizer that we are using.
  scheduler:
- The learning rate scheduler we are using.
+ The scheduler that we are using.
  Returns:
- Return None.
+ Return a dict containing previously saved training info.
  """
- if params.start_epoch <= 0:
- return
- filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 0:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+ assert filename.is_file(), f"{filename} does not exist!"
  saved_params = load_checkpoint(
  filename,
  model=model,
@@ -322,6 +444,13 @@
  for k in keys:
  params[k] = saved_params[k]
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+ if "cur_batch_idx" in saved_params:
+ params["cur_batch_idx"] = saved_params["cur_batch_idx"]
  return saved_params
@ -329,7 +458,9 @@ def save_checkpoint(
params: AttributeDict, params: AttributeDict,
model: nn.Module, model: nn.Module,
optimizer: Optional[torch.optim.Optimizer] = None, optimizer: Optional[torch.optim.Optimizer] = None,
scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, scheduler: Optional[LRSchedulerType] = None,
sampler: Optional[CutSampler] = None,
scaler: Optional[GradScaler] = None,
rank: int = 0, rank: int = 0,
) -> None: ) -> None:
"""Save model, optimizer, scheduler and training stats to file. """Save model, optimizer, scheduler and training stats to file.
@ -339,6 +470,12 @@ def save_checkpoint(
It is returned by :func:`get_params`. It is returned by :func:`get_params`.
model: model:
The training model. The training model.
optimizer:
The optimizer used in the training.
sampler:
The sampler for the training dataset.
scaler:
The scaler used for mix precision training.
""" """
if rank != 0: if rank != 0:
return return
@ -349,6 +486,8 @@ def save_checkpoint(
params=params, params=params,
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler, scheduler=scheduler,
sampler=sampler,
scaler=scaler,
rank=rank, rank=rank,
) )
@@ -367,6 +506,7 @@ def compute_loss(
  sp: spm.SentencePieceProcessor,
  batch: dict,
  is_training: bool,
+ warmup: float = 1.0,
  ) -> Tuple[Tensor, MetricsTracker]:
  """
  Compute CTC loss given the model and its inputs.
@@ -383,6 +523,8 @@ def compute_loss(
  True for training. False for validation. When it is True, this
  function enables autograd during computation; when it is False, it
  disables autograd.
+ warmup: a floating point value which increases throughout training;
+ values >= 1.0 are fully warmed up and have all modules present.
  """
  device = model.device
  feature = batch["inputs"]
@@ -398,21 +540,42 @@ def compute_loss(
  y = k2.RaggedTensor(y).to(device)
  with torch.set_grad_enabled(is_training):
- loss = model(x=feature, x_lens=feature_lens, y=y)
+ simple_loss, pruned_loss = model(
+ x=feature,
+ x_lens=feature_lens,
+ y=y,
+ prune_range=params.prune_range,
+ am_scale=params.am_scale,
+ lm_scale=params.lm_scale,
+ warmup=warmup,
+ )
+ # after the main warmup step, we keep pruned_loss_scale small
+ # for the same amount of time (model_warm_step), to avoid
+ # overwhelming the simple_loss and causing it to diverge,
+ # in case it had not fully learned the alignment yet.
+ pruned_loss_scale = (
+ 0.0
+ if warmup < 1.0
+ else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0)
+ )
+ loss = (
+ params.simple_loss_scale * simple_loss
+ + pruned_loss_scale * pruned_loss
+ )
  assert loss.requires_grad == is_training
  info = MetricsTracker()
  with warnings.catch_warnings():
  warnings.simplefilter("ignore")
- with warnings.catch_warnings():
- warnings.simplefilter("ignore")
  info["frames"] = (
  (feature_lens // params.subsampling_factor).sum().item()
  )
  # Note: We use reduction=sum while computing the loss.
  info["loss"] = loss.detach().cpu().item()
+ info["simple_loss"] = simple_loss.detach().cpu().item()
+ info["pruned_loss"] = pruned_loss.detach().cpu().item()
  return loss, info
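Restated as a plain function (the name is mine), the pruned-loss schedule above has three regimes, driven by warmup = batch_idx_train / model_warm_step; the simple loss is always weighted by --simple-loss-scale (0.5 by default):

def pruned_loss_weight(warmup: float) -> float:
    """Weight given to pruned_loss in compute_loss above."""
    if warmup < 1.0:
        return 0.0  # before warm-up: train on the simple loss only
    elif warmup > 1.0 and warmup < 2.0:
        return 0.1  # just after warm-up: keep the pruned loss small
    else:
        return 1.0  # fully warmed up

# e.g. at warmup=1.5: loss = 0.5 * simple_loss + 0.1 * pruned_loss
assert pruned_loss_weight(0.5) == 0.0
assert pruned_loss_weight(1.5) == 0.1
assert pruned_loss_weight(3.0) == 1.0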
@@ -455,11 +618,14 @@ def train_one_epoch(
  params: AttributeDict,
  model: nn.Module,
  optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
  sp: spm.SentencePieceProcessor,
  train_dl: torch.utils.data.DataLoader,
  valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
  tb_writer: Optional[SummaryWriter] = None,
  world_size: int = 1,
+ rank: int = 0,
  ) -> None:
  """Train the model for one epoch.
@@ -474,51 +640,96 @@ def train_one_epoch(
  The model for training.
  optimizer:
  The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
  train_dl:
  Dataloader for the training dataset.
  valid_dl:
  Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mix precision training.
  tb_writer:
  Writer to write log messages to tensorboard.
  world_size:
  Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
  """
  model.train()
  tot_loss = MetricsTracker()
+ cur_batch_idx = params.get("cur_batch_idx", 0)
  for batch_idx, batch in enumerate(train_dl):
+ if batch_idx < cur_batch_idx:
+ continue
+ cur_batch_idx = batch_idx
  params.batch_idx_train += 1
  batch_size = len(batch["supervisions"]["text"])
- loss, loss_info = compute_loss(
- params=params,
- model=model,
- sp=sp,
- batch=batch,
- is_training=True,
- )
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ warmup=(params.batch_idx_train / params.model_warm_step),
+ )
  # summary stats
  tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
  # NOTE: We use reduction==sum and loss is computed over utterances
  # in the batch and there is no normalization to it so far.
+ scaler.scale(loss).backward()
+ scheduler.step_batch(params.batch_idx_train)
+ scaler.step(optimizer)
+ scaler.update()
  optimizer.zero_grad()
- loss.backward()
- clip_grad_norm_(model.parameters(), 5.0, 2.0)
- optimizer.step()
- if batch_idx % params.log_interval == 0:
- logging.info(
- f"Epoch {params.cur_epoch}, "
- f"batch {batch_idx}, loss[{loss_info}], "
- f"tot_loss[{tot_loss}], batch size: {batch_size}"
- )
+ if params.print_diagnostics and batch_idx == 5:
+ return
+ if (
+ params.batch_idx_train > 0
+ and params.batch_idx_train % params.save_every_n == 0
+ ):
+ params.cur_batch_idx = batch_idx
+ save_checkpoint_with_global_batch_idx(
+ out_dir=params.exp_dir,
+ global_batch_idx=params.batch_idx_train,
+ model=model,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+ del params.cur_batch_idx
+ remove_checkpoints(
+ out_dir=params.exp_dir,
+ topk=params.keep_last_k,
+ rank=rank,
+ )
  if batch_idx % params.log_interval == 0:
+ cur_lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}"
+ )
  if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
  loss_info.write_summary(
  tb_writer, "train/current_", params.batch_idx_train
  )
@ -564,8 +775,7 @@ def run(rank, world_size, args):
params = get_params() params = get_params()
params.update(vars(args)) params.update(vars(args))
if params.full_libri is False: if params.full_libri is False:
params.valid_interval = 800 params.valid_interval = 1600
params.warm_step = 8000
fix_random_seed(params.seed) fix_random_seed(params.seed)
if world_size > 1: if world_size > 1:
@ -596,29 +806,39 @@ def run(rank, world_size, args):
logging.info("About to create model") logging.info("About to create model")
model = get_transducer_model(params) model = get_transducer_model(params)
checkpoints = load_checkpoint_if_available(params=params, model=model) num_param = sum([p.numel() for p in model.parameters()])
num_param = sum([p.numel() for p in model.parameters() if p.requires_grad])
logging.info(f"Number of model parameters: {num_param}") logging.info(f"Number of model parameters: {num_param}")
checkpoints = load_checkpoint_if_available(params=params, model=model)
model.to(device) model.to(device)
if world_size > 1: if world_size > 1:
logging.info("Using DDP") logging.info("Using DDP")
model = DDP(model, device_ids=[rank]) model = DDP(model, device_ids=[rank])
model.device = device model.device = device
optimizer = Noam( optimizer = Eve(model.parameters(), lr=params.initial_lr)
model.parameters(),
model_size=params.encoder_hidden_size, scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
factor=params.lr_factor,
warm_step=params.warm_step,
weight_decay=params.weight_decay,
)
if checkpoints and "optimizer" in checkpoints: if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict") logging.info("Loading optimizer state dict")
optimizer.load_state_dict(checkpoints["optimizer"]) optimizer.load_state_dict(checkpoints["optimizer"])
if (
checkpoints
and "scheduler" in checkpoints
and checkpoints["scheduler"] is not None
):
logging.info("Loading scheduler state dict")
scheduler.load_state_dict(checkpoints["scheduler"])
if params.print_diagnostics:
opts = diagnostics.TensorDiagnosticOptions(
2 ** 22
) # allow 4 megabytes per sub-module
diagnostic = diagnostics.attach_diagnostics(model, opts)
librispeech = LibriSpeechAsrDataModule(args) librispeech = LibriSpeechAsrDataModule(args)
train_cuts = librispeech.train_clean_100_cuts() train_cuts = librispeech.train_clean_100_cuts()
@ -628,75 +848,81 @@ def run(rank, world_size, args):
def remove_short_and_long_utt(c: Cut): def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds # Keep only utterances with duration between 1 second and 20 seconds
#
# Caution: There is a reason to select 20.0 here. Please see
# ../local/display_manifest_statistics.py
#
# You should use ../local/display_manifest_statistics.py to get
# an utterance duration distribution for your dataset to select
# the threshold
return 1.0 <= c.duration <= 20.0 return 1.0 <= c.duration <= 20.0
num_in_total = len(train_cuts)
train_cuts = train_cuts.filter(remove_short_and_long_utt) train_cuts = train_cuts.filter(remove_short_and_long_utt)
try: if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
num_left = len(train_cuts) # We only load the sampler's state dict when it loads a checkpoint
num_removed = num_in_total - num_left # saved in the middle of an epoch
removed_percent = num_removed / num_in_total * 100 sampler_state_dict = checkpoints["sampler"]
else:
sampler_state_dict = None
logging.info( train_dl = librispeech.train_dataloaders(
f"Before removing short and long utterances: {num_in_total}" train_cuts, sampler_state_dict=sampler_state_dict
) )
logging.info(f"After removing short and long utterances: {num_left}")
logging.info(
f"Removed {num_removed} utterances ({removed_percent:.5f}%)"
)
except TypeError as e:
# You can ignore this error as previous versions of Lhotse work fine
# for the above code. In recent versions of Lhotse, it uses
# lazy filter, producing cutsets that don't have the __len__ method
logging.info(str(e))
train_dl = librispeech.train_dataloaders(train_cuts)
valid_cuts = librispeech.dev_clean_cuts() valid_cuts = librispeech.dev_clean_cuts()
valid_cuts += librispeech.dev_other_cuts() valid_cuts += librispeech.dev_other_cuts()
valid_dl = librispeech.valid_dataloaders(valid_cuts) valid_dl = librispeech.valid_dataloaders(valid_cuts)
scan_pessimistic_batches_for_oom( if not params.print_diagnostics:
model=model, scan_pessimistic_batches_for_oom(
train_dl=train_dl, model=model,
optimizer=optimizer, train_dl=train_dl,
sp=sp, optimizer=optimizer,
params=params, sp=sp,
) params=params,
)
scaler = GradScaler(enabled=params.use_fp16)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
for epoch in range(params.start_epoch, params.num_epochs): for epoch in range(params.start_epoch, params.num_epochs):
scheduler.step_epoch(epoch)
fix_random_seed(params.seed + epoch) fix_random_seed(params.seed + epoch)
train_dl.sampler.set_epoch(epoch) train_dl.sampler.set_epoch(epoch)
cur_lr = optimizer._rate
if tb_writer is not None: if tb_writer is not None:
tb_writer.add_scalar(
"train/learning_rate", cur_lr, params.batch_idx_train
)
tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
if rank == 0:
logging.info("epoch {}, learning rate {}".format(epoch, cur_lr))
params.cur_epoch = epoch params.cur_epoch = epoch
train_one_epoch( train_one_epoch(
params=params, params=params,
model=model, model=model,
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler,
sp=sp, sp=sp,
train_dl=train_dl, train_dl=train_dl,
valid_dl=valid_dl, valid_dl=valid_dl,
scaler=scaler,
tb_writer=tb_writer, tb_writer=tb_writer,
world_size=world_size, world_size=world_size,
rank=rank,
) )
if params.print_diagnostics:
diagnostic.print_diagnostics()
break
save_checkpoint( save_checkpoint(
params=params, params=params,
model=model, model=model,
optimizer=optimizer, optimizer=optimizer,
scheduler=scheduler,
sampler=train_dl.sampler,
scaler=scaler,
rank=rank, rank=rank,
) )
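run() above replaces the Noam optimizer with the Eve optimizer and the Eden scheduler from the symlinked optim.py, and wraps training in a GradScaler for optional fp16. A minimal sketch of that setup (the helper name and grouping are mine; the defaults are the new parser defaults above):

import torch
from torch.cuda.amp import GradScaler

from optim import Eden, Eve  # provided by the symlinked optim.py above


def make_optimizer(
    model: torch.nn.Module,
    initial_lr: float = 0.003,
    lr_batches: float = 5000,
    lr_epochs: float = 6,
    use_fp16: bool = False,
):
    """Mirror the optimizer/scheduler/scaler setup in run() above."""
    optimizer = Eve(model.parameters(), lr=initial_lr)
    scheduler = Eden(optimizer, lr_batches, lr_epochs)
    scaler = GradScaler(enabled=use_fp16)
    return optimizer, scheduler, scaler

# Per epoch: scheduler.step_epoch(epoch)
# Per batch: scaler.scale(loss).backward(); scheduler.step_batch(batch_idx_train);
#            scaler.step(optimizer); scaler.update(); optimizer.zero_grad()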
@@ -723,17 +949,21 @@ def scan_pessimistic_batches_for_oom(
  for criterion, cuts in batches.items():
  batch = train_dl.dataset[cuts]
  try:
- optimizer.zero_grad()
- loss, _ = compute_loss(
- params=params,
- model=model,
- sp=sp,
- batch=batch,
- is_training=True,
- )
+ # warmup = 0.0 is so that the derivs for the pruned loss stay zero
+ # (i.e. are not remembered by the decaying-average in adam), because
+ # we want to avoid these params being subject to shrinkage in adam.
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, _ = compute_loss(
+ params=params,
+ model=model,
+ sp=sp,
+ batch=batch,
+ is_training=True,
+ warmup=0.0,
+ )
  loss.backward()
- clip_grad_norm_(model.parameters(), 5.0, 2.0)
  optimizer.step()
+ optimizer.zero_grad()
  except RuntimeError as e:
  if "CUDA out of memory" in str(e):
  logging.error(