Merge changes from pruned_transducer_stateless4->5

commit 1651fe0d42 (parent c7cf229f56)
Author: Daniel Povey
Date: 2022-05-31 13:00:11 +08:00

6 changed files with 569 additions and 123 deletions

File: export.py

@@ -20,26 +20,27 @@
 # to a single one using model averaging.

 """
 Usage:

-./pruned_transducer_stateless2/export.py \
-  --exp-dir ./pruned_transducer_stateless2/exp \
+./pruned_transducer_stateless5/export.py \
+  --exp-dir ./pruned_transducer_stateless5/exp \
   --bpe-model data/lang_bpe_500/bpe.model \
   --epoch 20 \
   --avg 10

 It will generate a file exp_dir/pretrained.pt

-To use the generated file with `pruned_transducer_stateless2/decode.py`,
+To use the generated file with `pruned_transducer_stateless5/decode.py`,
 you can do:

     cd /path/to/exp_dir
     ln -s pretrained.pt epoch-9999.pt

     cd /path/to/egs/librispeech/ASR

-    ./pruned_transducer_stateless2/decode.py \
-        --exp-dir ./pruned_transducer_stateless2/exp \
+    ./pruned_transducer_stateless5/decode.py \
+        --exp-dir ./pruned_transducer_stateless5/exp \
         --epoch 9999 \
         --avg 1 \
-        --max-duration 100 \
+        --max-duration 600 \
+        --decoding-method greedy_search \
         --bpe-model data/lang_bpe_500/bpe.model
 """
@@ -49,10 +50,11 @@ from pathlib import Path

 import sentencepiece as spm
 import torch
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,
+    average_checkpoints_with_averaged_model,
     find_checkpoints,
     load_checkpoint,
 )
@@ -69,7 +71,7 @@ def get_parser():
         type=int,
         default=28,
         help="""It specifies the checkpoint to use for averaging.
-        Note: Epoch counts from 0.
+        Note: Epoch counts from 1.
         You can specify --avg to use more checkpoints for model averaging.""",
     )
@@ -92,10 +94,21 @@ def get_parser():
         "'--epoch' and '--iter'",
     )

+    parser.add_argument(
+        "--use-averaged-model",
+        type=str2bool,
+        default=False,
+        help="Whether to load averaged model. Currently it only supports "
+        "using --epoch. If True, it would decode with the averaged model "
+        "over the epoch range from `epoch-avg` (excluded) to `epoch`."
+        "Actually only the models with epoch number of `epoch-avg` and "
+        "`epoch` are loaded for averaging. ",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="pruned_transducer_stateless5/exp",
         help="""It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
         """,
@@ -124,6 +137,8 @@ def get_parser():
         "2 means tri-gram",
     )

+    add_model_arguments(parser)
+
     return parser
@@ -131,6 +146,8 @@ def main():
     args = get_parser().parse_args()
     args.exp_dir = Path(args.exp_dir)

+    assert args.jit is False, "Support torchscript will be added later"
+
     params = get_params()
     params.update(vars(args))
@@ -152,12 +169,11 @@ def main():
     logging.info("About to create model")
     model = get_transducer_model(params)

-    model.to(device)
-    if params.iter > 0:
-        filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
-            : params.avg
-        ]
-        if len(filenames) == 0:
-            raise ValueError(
-                f"No checkpoints found for"
+    if not params.use_averaged_model:
+        if params.iter > 0:
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"

@@ -177,11 +193,58 @@ def main():
         start = params.epoch - params.avg + 1
         filenames = []
         for i in range(start, params.epoch + 1):
-            if start >= 0:
+            if i >= 1:
                 filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
         logging.info(f"averaging {filenames}")
         model.to(device)
         model.load_state_dict(average_checkpoints(filenames, device=device))
+    else:
+        if params.iter > 0:
+            filenames = find_checkpoints(
+                params.exp_dir, iteration=-params.iter
+            )[: params.avg + 1]
+            if len(filenames) == 0:
+                raise ValueError(
+                    f"No checkpoints found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            elif len(filenames) < params.avg + 1:
+                raise ValueError(
+                    f"Not enough checkpoints ({len(filenames)}) found for"
+                    f" --iter {params.iter}, --avg {params.avg}"
+                )
+            filename_start = filenames[-1]
+            filename_end = filenames[0]
+            logging.info(
+                "Calculating the averaged model over iteration checkpoints"
+                f" from {filename_start} (excluded) to {filename_end}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            assert params.avg > 0, params.avg
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )

     model.eval()
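
A minimal sketch (illustrative names, not the icefall API) of the idea behind `average_checkpoints_with_averaged_model` used above: assuming each checkpoint stores a running average of the parameters, as the `--use-averaged-model` help text implies, the mean over the window (start, end] is recoverable from the two endpoint checkpoints alone.

    # If avg_n = (1/n) * sum_{i<=n} model_i, then the mean over (start, end]
    # is (n_end * avg_end - n_start * avg_start) / (n_end - n_start).
    from typing import Dict

    import torch

    def window_average(
        avg_start: Dict[str, torch.Tensor],  # running average at `start`
        avg_end: Dict[str, torch.Tensor],    # running average at `end`
        n_start: int,                        # updates seen by `start`
        n_end: int,                          # updates seen by `end`
    ) -> Dict[str, torch.Tensor]:
        scale = 1.0 / (n_end - n_start)
        return {
            k: (avg_end[k] * n_end - avg_start[k] * n_start) * scale
            for k in avg_end
        }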
@@ -189,11 +252,6 @@ def main():
     model.eval()

     if params.jit:
-        # We won't use the forward() method of the model in C++, so just ignore
-        # it here.
-        # Otherwise, one of its arguments is a ragged tensor and is not
-        # torch scriptabe.
-        model.__class__.forward = torch.jit.ignore(model.__class__.forward)
         logging.info("Using torch.jit.script")
         model = torch.jit.script(model)
         filename = params.exp_dir / "cpu_jit.pt"

File: __init__.py (deleted)

@@ -1 +0,0 @@
-../pruned_transducer_stateless2/__init__.py

File: conformer.py

@@ -18,7 +18,7 @@
 import copy
 import math
 import warnings
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple

 import torch
 from encoder_interface import EncoderInterface
@@ -61,6 +61,7 @@ class Conformer(EncoderInterface):
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
         cnn_module_kernel: int = 31,
+        aux_layer_period: int = 3,
     ) -> None:
         super(Conformer, self).__init__()
@@ -86,7 +87,11 @@ class Conformer(EncoderInterface):
             layer_dropout,
             cnn_module_kernel,
         )
-        self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
+        self.encoder = ConformerEncoder(
+            encoder_layer,
+            num_encoder_layers,
+            aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
+        )

     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
@@ -112,13 +117,10 @@ class Conformer(EncoderInterface):
         x, pos_emb = self.encoder_pos(x)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

-        with warnings.catch_warnings():
-            warnings.simplefilter("ignore")
-            # Caution: We assume the subsampling factor is 4!
-            lengths = ((x_lens - 1) // 2 - 1) // 2
+        # Caution: We assume the subsampling factor is 4!
+        # lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
+        #
+        # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
+        lengths = (((x_lens - 1) >> 1) - 1) >> 1
         assert x.size(0) == lengths.max().item()

         mask = make_pad_mask(lengths)
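
The shift-based form above is equivalent to the commented-out floor division for the integer lengths involved, while avoiding the `floor_divide` deprecation warning that the deleted `warnings.catch_warnings()` block used to suppress. A quick check (a sketch, not part of the commit):

    import torch

    x_lens = torch.arange(1, 10_000)  # frame counts are >= 1
    floor_div = ((x_lens - 1) // 2 - 1) // 2
    shifted = (((x_lens - 1) >> 1) - 1) >> 1
    assert torch.equal(floor_div, shifted)  # arithmetic shift == floor division by 2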
@@ -282,13 +284,30 @@ class ConformerEncoder(nn.Module):
     >>> out = conformer_encoder(src, pos_emb)
     """

-    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: List[int],
+    ) -> None:
         super().__init__()
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
         self.num_layers = num_layers

+        assert num_layers - 1 not in aux_layers
+        self.aux_layers = set(aux_layers + [num_layers - 1])
+
+        num_channels = encoder_layer.norm_final.num_channels
+        self.combiner = RandomCombine(
+            num_inputs=len(self.aux_layers),
+            num_channels=num_channels,
+            final_weight=0.5,
+            pure_prob=0.333,
+            stddev=2.0,
+        )
+
     def forward(
         self,
         src: Tensor,
@@ -315,6 +334,8 @@ class ConformerEncoder(nn.Module):
         """
         output = src

+        outputs = []
+
         for i, mod in enumerate(self.layers):
             output = mod(
                 output,

@@ -323,6 +344,10 @@ class ConformerEncoder(nn.Module):
                 src_key_padding_mask=src_key_padding_mask,
                 warmup=warmup,
             )
+            if i in self.aux_layers:
+                outputs.append(output)
+
+        output = self.combiner(outputs)

         return output
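
With the defaults added in this diff (`num_encoder_layers=24`, `aux_layer_period=3`), the combiner receives nine tensors per forward pass. A small sketch of which layers are tapped:

    num_layers, aux_layer_period = 24, 3  # defaults in this diff
    aux_layers = list(range(0, num_layers - 1, aux_layer_period))
    tapped = sorted(set(aux_layers + [num_layers - 1]))
    print(tapped)       # [0, 3, 6, 9, 12, 15, 18, 21, 23]
    print(len(tapped))  # 9, the num_inputs given to RandomCombine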
@@ -1022,15 +1047,281 @@ class Conv2dSubsampling(nn.Module):
         x = self.out_balancer(x)
         return x


+class RandomCombine(nn.Module):
+    """
+    This module combines a list of Tensors, all with the same shape, to
+    produce a single output of that same shape which, in training time,
+    is a random combination of all the inputs; but which in test time
+    will be just the last input.
+
+    All but the last input will have a linear transform before we
+    randomly combine them; these linear transforms will be initialized
+    to the identity transform.
+
+    The idea is that the list of Tensors will be a list of outputs of multiple
+    conformer layers.  This has a similar effect as iterated loss. (See:
+    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
+    NETWORKS).
+    """
+
+    def __init__(
+        self,
+        num_inputs: int,
+        num_channels: int,
+        final_weight: float = 0.5,
+        pure_prob: float = 0.5,
+        stddev: float = 2.0,
+    ) -> None:
+        """
+        Args:
+          num_inputs:
+            The number of tensor inputs, which equals the number of layers'
+            outputs that are fed into this module.  E.g. in an 18-layer neural
+            net if we output layers 16, 12, 18, num_inputs would be 3.
+          num_channels:
+            The number of channels on the input, e.g. 512.
+          final_weight:
+            The amount of weight or probability we assign to the
+            final layer when randomly choosing layers or when choosing
+            continuous layer weights.
+          pure_prob:
+            The probability, on each frame, with which we choose
+            only a single layer to output (rather than an interpolation)
+          stddev:
+            A standard deviation that we add to log-probs for computing
+            randomized weights.
+
+        The method of choosing which layers, or combinations of layers, to use,
+        is conceptually as follows::
+
+            With probability `pure_prob`::
+                With probability `final_weight`: choose final layer,
+                Else: choose random non-final layer.
+            Else::
+                Choose initial log-weights that correspond to assigning
+                weight `final_weight` to the final layer and equal
+                weights to other layers; then add Gaussian noise
+                with variance `stddev` to these log-weights, and normalize
+                to weights (note: the average weight assigned to the
+                final layer here will not be `final_weight` if stddev>0).
+        """
+        super().__init__()
+        assert 0 <= pure_prob <= 1, pure_prob
+        assert 0 < final_weight < 1, final_weight
+        assert num_inputs >= 1
+
+        self.linear = nn.ModuleList(
+            [
+                nn.Linear(num_channels, num_channels, bias=True)
+                for _ in range(num_inputs - 1)
+            ]
+        )
+
+        self.num_inputs = num_inputs
+        self.final_weight = final_weight
+        self.pure_prob = pure_prob
+        self.stddev = stddev
+
+        self.final_log_weight = (
+            torch.tensor(
+                (final_weight / (1 - final_weight)) * (self.num_inputs - 1)
+            )
+            .log()
+            .item()
+        )
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        for i in range(len(self.linear)):
+            nn.init.eye_(self.linear[i].weight)
+            nn.init.constant_(self.linear[i].bias, 0.0)
+    def forward(self, inputs: List[Tensor]) -> Tensor:
+        """Forward function.
+        Args:
+          inputs:
+            A list of Tensor, e.g. from various layers of a transformer.
+            All must be the same shape, of (*, num_channels)
+        Returns:
+          A Tensor of shape (*, num_channels).  In test mode
+          this is just the final input.
+        """
+        num_inputs = self.num_inputs
+        assert len(inputs) == num_inputs
+        if not self.training:
+            return inputs[-1]
+
+        # Shape of weights: (*, num_inputs)
+        num_channels = inputs[0].shape[-1]
+        num_frames = inputs[0].numel() // num_channels
+
+        mod_inputs = []
+        for i in range(num_inputs - 1):
+            mod_inputs.append(self.linear[i](inputs[i]))
+        mod_inputs.append(inputs[num_inputs - 1])
+
+        ndim = inputs[0].ndim
+        # stacked_inputs: (num_frames, num_channels, num_inputs)
+        stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape(
+            (num_frames, num_channels, num_inputs)
+        )
+
+        # weights: (num_frames, num_inputs)
+        weights = self._get_random_weights(
+            inputs[0].dtype, inputs[0].device, num_frames
+        )
+
+        weights = weights.reshape(num_frames, num_inputs, 1)
+        # ans: (num_frames, num_channels, 1)
+        ans = torch.matmul(stacked_inputs, weights)
+        # ans: (*, num_channels)
+        ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
+
+        if __name__ == "__main__":
+            # for testing only...
+            print("Weights = ", weights.reshape(num_frames, num_inputs))
+        return ans
+    def _get_random_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ) -> Tensor:
+        """Return a tensor of random weights, of shape
+        `(num_frames, self.num_inputs)`,
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired
+        Returns:
+          A tensor of shape (num_frames, self.num_inputs), such that
+          `ans.sum(dim=1)` is all ones.
+        """
+        pure_prob = self.pure_prob
+        if pure_prob == 0.0:
+            return self._get_random_mixed_weights(dtype, device, num_frames)
+        elif pure_prob == 1.0:
+            return self._get_random_pure_weights(dtype, device, num_frames)
+        else:
+            p = self._get_random_pure_weights(dtype, device, num_frames)
+            m = self._get_random_mixed_weights(dtype, device, num_frames)
+            return torch.where(
+                torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
+            )
+
+    def _get_random_pure_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ):
+        """Return a tensor of random one-hot weights, of shape
+        `(num_frames, self.num_inputs)`,
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired.
+        Returns:
+          A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
+          exactly one weight equal to 1.0 on each frame.
+        """
+        final_prob = self.final_weight
+
+        # final contains self.num_inputs - 1 in all elements
+        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
+        # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
+        nonfinal = torch.randint(
+            self.num_inputs - 1, (num_frames,), device=device
+        )
+
+        indexes = torch.where(
+            torch.rand(num_frames, device=device) < final_prob, final, nonfinal
+        )
+        ans = torch.nn.functional.one_hot(
+            indexes, num_classes=self.num_inputs
+        ).to(dtype=dtype)
+        return ans
+
+    def _get_random_mixed_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ):
+        """Return a tensor of random one-hot weights, of shape
+        `(num_frames, self.num_inputs)`,
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired.
+        Returns:
+          A tensor of shape (num_frames, self.num_inputs), which elements
+          in [0..1] that sum to one over the second axis, i.e.
+          `ans.sum(dim=1)` is all ones.
+        """
+        logprobs = (
+            torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
+            * self.stddev
+        )
+        logprobs[:, -1] += self.final_log_weight
+        return logprobs.softmax(dim=1)
+
+
+def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
+    print(
+        f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}"
+    )
+    num_inputs = 3
+    num_channels = 50
+    m = RandomCombine(
+        num_inputs=num_inputs,
+        num_channels=num_channels,
+        final_weight=final_weight,
+        pure_prob=pure_prob,
+        stddev=stddev,
+    )
+
+    x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
+
+    y = m(x)
+    assert y.shape == x[0].shape
+    assert torch.allclose(y, x[0])  # .. since actually all ones.
+
+
+def _test_random_combine_main():
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.0)
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.3)
+    _test_random_combine(0.5, 1, 0.3)
+    _test_random_combine(0.5, 0.5, 0.3)
+
+
-if __name__ == "__main__":
+def _test_conformer_main():
     feature_dim = 50
     c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
     batch_size = 5
     seq_len = 20
+
+    feature_dim = 50
+    # Just make sure the forward pass runs.
+    c = Conformer(
+        num_features=feature_dim, d_model=128, nhead=4
+    )
+    batch_size = 5
+    seq_len = 20
     # Just make sure the forward pass runs.
     f = c(
         torch.randn(batch_size, seq_len, feature_dim),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
         warmup=0.5,
     )
+    f  # to remove flake8 warnings
+
+
+if __name__ == "__main__":
+    _test_conformer_main()
+    _test_random_combine_main()
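
A quick check of the `final_log_weight` expression in `RandomCombine.__init__` (a sketch, not part of the commit): with `stddev == 0`, the softmax of the initial log-weights gives the final input exactly `final_weight` and splits the remainder evenly, which is where `(final_weight / (1 - final_weight)) * (num_inputs - 1)` comes from.

    import torch

    num_inputs, final_weight = 3, 0.5
    final_log_weight = (
        torch.tensor((final_weight / (1 - final_weight)) * (num_inputs - 1))
        .log()
        .item()
    )
    logprobs = torch.zeros(num_inputs)
    logprobs[-1] += final_log_weight
    print(logprobs.softmax(dim=0))  # tensor([0.2500, 0.2500, 0.5000])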

File: decode.py

@@ -19,36 +19,36 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless4/decode.py \
-    --epoch 30 \
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless4/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp \
     --max-duration 600 \
     --decoding-method greedy_search

 (2) beam search (not recommended)
-./pruned_transducer_stateless4/decode.py \
-    --epoch 30 \
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless4/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp \
     --max-duration 600 \
     --decoding-method beam_search \
     --beam-size 4

 (3) modified beam search
-./pruned_transducer_stateless4/decode.py \
-    --epoch 30 \
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless4/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp \
     --max-duration 600 \
     --decoding-method modified_beam_search \
     --beam-size 4

 (4) fast beam search
-./pruned_transducer_stateless4/decode.py \
-    --epoch 30 \
+./pruned_transducer_stateless7/decode.py \
+    --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless4/exp \
+    --exp-dir ./pruned_transducer_stateless7/exp \
     --max-duration 600 \
     --decoding-method fast_beam_search \
     --beam 4 \
@@ -75,7 +75,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model

 from icefall.checkpoint import (
     average_checkpoints,
@@ -139,7 +139,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless4/exp",
+        default="pruned_transducer_stateless7/exp",
         help="The experiment dir",
     )
@@ -212,6 +212,8 @@ def get_parser():
         Used only when --decoding_method is greedy_search""",
     )

+    add_model_arguments(parser)
+
     return parser
@@ -374,7 +376,7 @@ def decode_dataset(
     if params.decoding_method == "greedy_search":
         log_interval = 50
     else:
-        log_interval = 10
+        log_interval = 20

     results = defaultdict(list)
     for batch_idx, batch in enumerate(dl):
File: test_model.py

@@ -23,27 +23,42 @@ To run this file, do:
     python ./pruned_transducer_stateless4/test_model.py
 """

+import torch
 from train import get_params, get_transducer_model


-def test_model():
+def test_model_1():
     params = get_params()
     params.vocab_size = 500
     params.blank_id = 0
     params.context_size = 2
-    params.unk_id = 2
+    params.num_encoder_layers = 24
+    params.dim_feedforward = 1536  # 384 * 4
+    params.encoder_dim = 384
+    model = get_transducer_model(params)
+    num_param = sum([p.numel() for p in model.parameters()])
+    print(f"Number of model parameters: {num_param}")
+
+
+# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
+def test_model_M():
+    params = get_params()
+    params.vocab_size = 500
+    params.blank_id = 0
+    params.context_size = 2
+    params.num_encoder_layers = 18
+    params.dim_feedforward = 1024
+    params.encoder_dim = 256
+    params.nhead = 4
+    params.decoder_dim = 512
+    params.joiner_dim = 512

     model = get_transducer_model(params)

     num_param = sum([p.numel() for p in model.parameters()])
     print(f"Number of model parameters: {num_param}")
+    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
+    torch.jit.script(model)


 def main():
-    test_model()
+    # test_model_1()
+    test_model_M()


 if __name__ == "__main__":
File: train.py

@@ -1,5 +1,5 @@
 #!/usr/bin/env python3
-# Copyright  2021  Xiaomi Corp.  (authors: Fangjun Kuang,
+# Copyright  2021-2022  Xiaomi Corp.  (authors: Fangjun Kuang,
 #                                          Wei Kang,
-#                                          Mingshuang Luo,)
+#                                          Mingshuang Luo,
+#                                          Zengwei Yao)
@@ -22,22 +22,22 @@ Usage:
 export CUDA_VISIBLE_DEVICES="0,1,2,3"

-./pruned_transducer_stateless4/train.py \
+./pruned_transducer_stateless7/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
-  --exp-dir pruned_transducer_stateless2/exp \
+  --exp-dir pruned_transducer_stateless7/exp \
   --full-libri 1 \
   --max-duration 300

 # For mix precision training:

-./pruned_transducer_stateless4/train.py \
+./pruned_transducer_stateless7/train.py \
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 1 \
   --use-fp16 1 \
-  --exp-dir pruned_transducer_stateless2/exp \
+  --exp-dir pruned_transducer_stateless7/exp \
   --full-libri 1 \
   --max-duration 550
@@ -88,6 +88,53 @@ LRSchedulerType = Union[
 ]


+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=24,
+        help="Number of conformer encoder layers..",
+    )
+
+    parser.add_argument(
+        "--dim-feedforward",
+        type=int,
+        default=1536,
+        help="Feedforward dimension of the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--nhead",
+        type=int,
+        default=8,
+        help="Number of attention heads in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--encoder-dim",
+        type=int,
+        default=384,
+        help="Attention dimension in the conformer encoder layer.",
+    )
+
+    parser.add_argument(
+        "--decoder-dim",
+        type=int,
+        default=512,
+        help="Embedding dimension in the decoder model.",
+    )
+
+    parser.add_argument(
+        "--joiner-dim",
+        type=int,
+        default=512,
+        help="""Dimension used in the joiner model.
+        Outputs from the encoder and decoder model are projected
+        to this dimension before adding.
+        """,
+    )
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
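
These arguments replace the fixed model entries that a later hunk removes from `get_params()`; `params.update(vars(args))` in `main()` is what carries them into the model. A minimal sketch, assuming the `add_model_arguments` defined in the hunk above is in scope:

    import argparse

    parser = argparse.ArgumentParser()
    add_model_arguments(parser)
    args = parser.parse_args(["--encoder-dim", "256", "--nhead", "4"])
    print(args.encoder_dim, args.nhead, args.num_encoder_layers)  # 256 4 24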
@@ -143,7 +190,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="pruned_transducer_stateless7/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -161,16 +208,16 @@ def get_parser():
         "--initial-lr",
         type=float,
         default=0.003,
-        help="""The initial learning rate. This value should not need to be
-        changed.""",
+        help="The initial learning rate. This value should not need "
+        "to be changed.",
     )

     parser.add_argument(
         "--lr-batches",
         type=float,
         default=5000,
-        help="""Number of steps that affects how rapidly the learning rate decreases.
-        We suggest not to change this.""",
+        help="""Number of steps that affects how rapidly the learning rate
+        decreases. We suggest not to change this.""",
     )

     parser.add_argument(
@@ -240,7 +287,7 @@ def get_parser():
     parser.add_argument(
         "--save-every-n",
         type=int,
-        default=8000,
+        default=4000,
         help="""Save checkpoint after processing this number of batches"
         periodically. We save checkpoint to exp-dir/ whenever
         params.batch_idx_train % save_every_n == 0. The checkpoint filename
@@ -253,7 +300,7 @@ def get_parser():
     parser.add_argument(
         "--keep-last-k",
         type=int,
-        default=20,
+        default=30,
         help="""Only keep this number of checkpoints on disk.
         For instance, if it is 3, there are only 3 checkpoints
         in the exp-dir with filenames `checkpoint-xxx.pt`.
@@ -281,6 +328,8 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    add_model_arguments(parser)
+
     return parser
@@ -341,14 +390,6 @@ def get_params() -> AttributeDict:
             # parameters for conformer
             "feature_dim": 80,
             "subsampling_factor": 4,
-            "encoder_dim": 512,
-            "nhead": 8,
-            "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
-            # parameters for decoder
-            "decoder_dim": 512,
-            # parameters for joiner
-            "joiner_dim": 512,
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
@@ -704,6 +745,7 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])

+        try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
@@ -723,8 +765,11 @@ def train_one_epoch(
             scaler.step(optimizer)
             scaler.update()
             optimizer.zero_grad()
+        except:  # noqa
+            display_and_save_batch(batch, params=params, sp=sp)
+            raise

-        if params.print_diagnostics and batch_idx == 30:
+        if params.print_diagnostics and batch_idx == 5:
             return

         if (
@@ -888,7 +933,10 @@ def run(rank, world_size, args):
         scheduler.load_state_dict(checkpoints["scheduler"])

     if params.print_diagnostics:
-        diagnostic = diagnostics.attach_diagnostics(model)
+        opts = diagnostics.TensorDiagnosticOptions(
+            2 ** 22
+        )  # allow 4 megabytes per sub-module
+        diagnostic = diagnostics.attach_diagnostics(model, opts)

     librispeech = LibriSpeechAsrDataModule(args)
@@ -986,6 +1034,38 @@ def run(rank, world_size, args):
     cleanup_dist()


+def display_and_save_batch(
+    batch: dict,
+    params: AttributeDict,
+    sp: spm.SentencePieceProcessor,
+) -> None:
+    """Display the batch statistics and save the batch into disk.
+
+    Args:
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      params:
+        Parameters for training. See :func:`get_params`.
+      sp:
+        The BPE model.
+    """
+    from lhotse.utils import uuid4
+
+    filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+    logging.info(f"Saving batch to {filename}")
+    torch.save(batch, filename)
+
+    supervisions = batch["supervisions"]
+    features = batch["inputs"]
+
+    logging.info(f"features shape: {features.shape}")
+
+    y = sp.encode(supervisions["text"], out_type=int)
+    num_tokens = sum(len(i) for i in y)
+    logging.info(f"num tokens: {num_tokens}")
+
+
 def scan_pessimistic_batches_for_oom(
     model: Union[nn.Module, DDP],
     train_dl: torch.utils.data.DataLoader,
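
A batch saved by `display_and_save_batch` can be reloaded to reproduce a failure offline; a hypothetical snippet (the uuid-based filename differs per run, so it is passed in here):

    import sys

    import torch

    # Usage: python inspect_batch.py exp/batch-<uuid>.pt
    batch = torch.load(sys.argv[1])
    print(batch["inputs"].shape)              # features: (N, T, C)
    print(batch["supervisions"]["text"][:2])  # first two transcripts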
@@ -1017,7 +1097,7 @@ def scan_pessimistic_batches_for_oom(
             loss.backward()
             optimizer.step()
             optimizer.zero_grad()
-        except RuntimeError as e:
+        except Exception as e:
             if "CUDA out of memory" in str(e):
                 logging.error(
                     "Your GPU ran out of memory with the current "

@@ -1026,6 +1106,7 @@ def scan_pessimistic_batches_for_oom(
                     f"Failing criterion: {criterion} "
                     f"(={crit_values[criterion]}) ..."
                 )
+            display_and_save_batch(batch, params=params, sp=sp)
             raise