mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-10 01:24:19 +00:00
Merge changes from pruned_transducer_stateless4->5
This commit is contained in:
parent
c7cf229f56
commit
1651fe0d42
@ -20,26 +20,27 @@
|
|||||||
# to a single one using model averaging.
|
# to a single one using model averaging.
|
||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
./pruned_transducer_stateless2/export.py \
|
./pruned_transducer_stateless5/export.py \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model \
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
--epoch 20 \
|
--epoch 20 \
|
||||||
--avg 10
|
--avg 10
|
||||||
|
|
||||||
It will generate a file exp_dir/pretrained.pt
|
It will generate a file exp_dir/pretrained.pt
|
||||||
|
|
||||||
To use the generated file with `pruned_transducer_stateless2/decode.py`,
|
To use the generated file with `pruned_transducer_stateless5/decode.py`,
|
||||||
you can do:
|
you can do:
|
||||||
|
|
||||||
cd /path/to/exp_dir
|
cd /path/to/exp_dir
|
||||||
ln -s pretrained.pt epoch-9999.pt
|
ln -s pretrained.pt epoch-9999.pt
|
||||||
|
|
||||||
cd /path/to/egs/librispeech/ASR
|
cd /path/to/egs/librispeech/ASR
|
||||||
./pruned_transducer_stateless2/decode.py \
|
./pruned_transducer_stateless5/decode.py \
|
||||||
--exp-dir ./pruned_transducer_stateless2/exp \
|
--exp-dir ./pruned_transducer_stateless5/exp \
|
||||||
--epoch 9999 \
|
--epoch 9999 \
|
||||||
--avg 1 \
|
--avg 1 \
|
||||||
--max-duration 100 \
|
--max-duration 600 \
|
||||||
|
--decoding-method greedy_search \
|
||||||
--bpe-model data/lang_bpe_500/bpe.model
|
--bpe-model data/lang_bpe_500/bpe.model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -49,10 +50,11 @@ from pathlib import Path
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
from train import get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
@ -69,7 +71,7 @@ def get_parser():
|
|||||||
type=int,
|
type=int,
|
||||||
default=28,
|
default=28,
|
||||||
help="""It specifies the checkpoint to use for averaging.
|
help="""It specifies the checkpoint to use for averaging.
|
||||||
Note: Epoch counts from 0.
|
Note: Epoch counts from 1.
|
||||||
You can specify --avg to use more checkpoints for model averaging.""",
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -92,10 +94,21 @@ def get_parser():
|
|||||||
"'--epoch' and '--iter'",
|
"'--epoch' and '--iter'",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="pruned_transducer_stateless2/exp",
|
default="pruned_transducer_stateless5/exp",
|
||||||
help="""It specifies the directory where all training related
|
help="""It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
""",
|
""",
|
||||||
@ -124,6 +137,8 @@ def get_parser():
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -131,6 +146,8 @@ def main():
|
|||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
|
||||||
|
assert args.jit is False, "Support torchscript will be added later"
|
||||||
|
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.update(vars(args))
|
params.update(vars(args))
|
||||||
|
|
||||||
@ -152,12 +169,11 @@ def main():
|
|||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
model = get_transducer_model(params)
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
model.to(device)
|
if not params.use_averaged_model:
|
||||||
|
|
||||||
if params.iter > 0:
|
if params.iter > 0:
|
||||||
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
filenames = find_checkpoints(
|
||||||
: params.avg
|
params.exp_dir, iteration=-params.iter
|
||||||
]
|
)[: params.avg]
|
||||||
if len(filenames) == 0:
|
if len(filenames) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"No checkpoints found for"
|
f"No checkpoints found for"
|
||||||
@ -177,11 +193,58 @@ def main():
|
|||||||
start = params.epoch - params.avg + 1
|
start = params.epoch - params.avg + 1
|
||||||
filenames = []
|
filenames = []
|
||||||
for i in range(start, params.epoch + 1):
|
for i in range(start, params.epoch + 1):
|
||||||
if start >= 0:
|
if i >= 1:
|
||||||
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
logging.info(f"averaging {filenames}")
|
logging.info(f"averaging {filenames}")
|
||||||
model.to(device)
|
model.to(device)
|
||||||
model.load_state_dict(average_checkpoints(filenames, device=device))
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(
|
||||||
|
params.exp_dir, iteration=-params.iter
|
||||||
|
)[: params.avg + 1]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
@ -189,11 +252,6 @@ def main():
|
|||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
if params.jit:
|
if params.jit:
|
||||||
# We won't use the forward() method of the model in C++, so just ignore
|
|
||||||
# it here.
|
|
||||||
# Otherwise, one of its arguments is a ragged tensor and is not
|
|
||||||
# torch scriptabe.
|
|
||||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
model = torch.jit.script(model)
|
model = torch.jit.script(model)
|
||||||
filename = params.exp_dir / "cpu_jit.pt"
|
filename = params.exp_dir / "cpu_jit.pt"
|
||||||
|
@ -1 +0,0 @@
|
|||||||
../pruned_transducer_stateless2/__init__.py
|
|
@ -18,7 +18,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
@ -61,6 +61,7 @@ class Conformer(EncoderInterface):
|
|||||||
dropout: float = 0.1,
|
dropout: float = 0.1,
|
||||||
layer_dropout: float = 0.075,
|
layer_dropout: float = 0.075,
|
||||||
cnn_module_kernel: int = 31,
|
cnn_module_kernel: int = 31,
|
||||||
|
aux_layer_period: int = 3,
|
||||||
) -> None:
|
) -> None:
|
||||||
super(Conformer, self).__init__()
|
super(Conformer, self).__init__()
|
||||||
|
|
||||||
@ -86,7 +87,11 @@ class Conformer(EncoderInterface):
|
|||||||
layer_dropout,
|
layer_dropout,
|
||||||
cnn_module_kernel,
|
cnn_module_kernel,
|
||||||
)
|
)
|
||||||
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)
|
self.encoder = ConformerEncoder(
|
||||||
|
encoder_layer,
|
||||||
|
num_encoder_layers,
|
||||||
|
aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
|
||||||
@ -112,13 +117,10 @@ class Conformer(EncoderInterface):
|
|||||||
x, pos_emb = self.encoder_pos(x)
|
x, pos_emb = self.encoder_pos(x)
|
||||||
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
|
with warnings.catch_warnings():
|
||||||
|
warnings.simplefilter("ignore")
|
||||||
# Caution: We assume the subsampling factor is 4!
|
# Caution: We assume the subsampling factor is 4!
|
||||||
|
lengths = ((x_lens - 1) // 2 - 1) // 2
|
||||||
# lengths = ((x_lens - 1) // 2 - 1) // 2 # issue an warning
|
|
||||||
#
|
|
||||||
# Note: rounding_mode in torch.div() is available only in torch >= 1.8.0
|
|
||||||
lengths = (((x_lens - 1) >> 1) - 1) >> 1
|
|
||||||
|
|
||||||
assert x.size(0) == lengths.max().item()
|
assert x.size(0) == lengths.max().item()
|
||||||
mask = make_pad_mask(lengths)
|
mask = make_pad_mask(lengths)
|
||||||
|
|
||||||
@ -282,13 +284,30 @@ class ConformerEncoder(nn.Module):
|
|||||||
>>> out = conformer_encoder(src, pos_emb)
|
>>> out = conformer_encoder(src, pos_emb)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder_layer: nn.Module,
|
||||||
|
num_layers: int,
|
||||||
|
aux_layers: List[int],
|
||||||
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.layers = nn.ModuleList(
|
self.layers = nn.ModuleList(
|
||||||
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
|
||||||
)
|
)
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
|
|
||||||
|
assert num_layers - 1 not in aux_layers
|
||||||
|
self.aux_layers = set(aux_layers + [num_layers - 1])
|
||||||
|
|
||||||
|
num_channels = encoder_layer.norm_final.num_channels
|
||||||
|
self.combiner = RandomCombine(
|
||||||
|
num_inputs=len(self.aux_layers),
|
||||||
|
num_channels=num_channels,
|
||||||
|
final_weight=0.5,
|
||||||
|
pure_prob=0.333,
|
||||||
|
stddev=2.0,
|
||||||
|
)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
src: Tensor,
|
src: Tensor,
|
||||||
@ -315,6 +334,8 @@ class ConformerEncoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
output = src
|
output = src
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
|
||||||
for i, mod in enumerate(self.layers):
|
for i, mod in enumerate(self.layers):
|
||||||
output = mod(
|
output = mod(
|
||||||
output,
|
output,
|
||||||
@ -323,6 +344,10 @@ class ConformerEncoder(nn.Module):
|
|||||||
src_key_padding_mask=src_key_padding_mask,
|
src_key_padding_mask=src_key_padding_mask,
|
||||||
warmup=warmup,
|
warmup=warmup,
|
||||||
)
|
)
|
||||||
|
if i in self.aux_layers:
|
||||||
|
outputs.append(output)
|
||||||
|
|
||||||
|
output = self.combiner(outputs)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@ -1022,15 +1047,281 @@ class Conv2dSubsampling(nn.Module):
|
|||||||
x = self.out_balancer(x)
|
x = self.out_balancer(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
class RandomCombine(nn.Module):
|
||||||
|
"""
|
||||||
|
This module combines a list of Tensors, all with the same shape, to
|
||||||
|
produce a single output of that same shape which, in training time,
|
||||||
|
is a random combination of all the inputs; but which in test time
|
||||||
|
will be just the last input.
|
||||||
|
|
||||||
|
All but the last input will have a linear transform before we
|
||||||
|
randomly combine them; these linear transforms will be initialized
|
||||||
|
to the identity transform.
|
||||||
|
|
||||||
|
The idea is that the list of Tensors will be a list of outputs of multiple
|
||||||
|
conformer layers. This has a similar effect as iterated loss. (See:
|
||||||
|
DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
|
||||||
|
NETWORKS).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_inputs: int,
|
||||||
|
num_channels: int,
|
||||||
|
final_weight: float = 0.5,
|
||||||
|
pure_prob: float = 0.5,
|
||||||
|
stddev: float = 2.0,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
num_inputs:
|
||||||
|
The number of tensor inputs, which equals the number of layers'
|
||||||
|
outputs that are fed into this module. E.g. in an 18-layer neural
|
||||||
|
net if we output layers 16, 12, 18, num_inputs would be 3.
|
||||||
|
num_channels:
|
||||||
|
The number of channels on the input, e.g. 512.
|
||||||
|
final_weight:
|
||||||
|
The amount of weight or probability we assign to the
|
||||||
|
final layer when randomly choosing layers or when choosing
|
||||||
|
continuous layer weights.
|
||||||
|
pure_prob:
|
||||||
|
The probability, on each frame, with which we choose
|
||||||
|
only a single layer to output (rather than an interpolation)
|
||||||
|
stddev:
|
||||||
|
A standard deviation that we add to log-probs for computing
|
||||||
|
randomized weights.
|
||||||
|
|
||||||
|
The method of choosing which layers, or combinations of layers, to use,
|
||||||
|
is conceptually as follows::
|
||||||
|
|
||||||
|
With probability `pure_prob`::
|
||||||
|
With probability `final_weight`: choose final layer,
|
||||||
|
Else: choose random non-final layer.
|
||||||
|
Else::
|
||||||
|
Choose initial log-weights that correspond to assigning
|
||||||
|
weight `final_weight` to the final layer and equal
|
||||||
|
weights to other layers; then add Gaussian noise
|
||||||
|
with variance `stddev` to these log-weights, and normalize
|
||||||
|
to weights (note: the average weight assigned to the
|
||||||
|
final layer here will not be `final_weight` if stddev>0).
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
assert 0 <= pure_prob <= 1, pure_prob
|
||||||
|
assert 0 < final_weight < 1, final_weight
|
||||||
|
assert num_inputs >= 1
|
||||||
|
|
||||||
|
self.linear = nn.ModuleList(
|
||||||
|
[
|
||||||
|
nn.Linear(num_channels, num_channels, bias=True)
|
||||||
|
for _ in range(num_inputs - 1)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_inputs = num_inputs
|
||||||
|
self.final_weight = final_weight
|
||||||
|
self.pure_prob = pure_prob
|
||||||
|
self.stddev = stddev
|
||||||
|
|
||||||
|
self.final_log_weight = (
|
||||||
|
torch.tensor(
|
||||||
|
(final_weight / (1 - final_weight)) * (self.num_inputs - 1)
|
||||||
|
)
|
||||||
|
.log()
|
||||||
|
.item()
|
||||||
|
)
|
||||||
|
self._reset_parameters()
|
||||||
|
|
||||||
|
def _reset_parameters(self):
|
||||||
|
for i in range(len(self.linear)):
|
||||||
|
nn.init.eye_(self.linear[i].weight)
|
||||||
|
nn.init.constant_(self.linear[i].bias, 0.0)
|
||||||
|
|
||||||
|
def forward(self, inputs: List[Tensor]) -> Tensor:
|
||||||
|
"""Forward function.
|
||||||
|
Args:
|
||||||
|
inputs:
|
||||||
|
A list of Tensor, e.g. from various layers of a transformer.
|
||||||
|
All must be the same shape, of (*, num_channels)
|
||||||
|
Returns:
|
||||||
|
A Tensor of shape (*, num_channels). In test mode
|
||||||
|
this is just the final input.
|
||||||
|
"""
|
||||||
|
num_inputs = self.num_inputs
|
||||||
|
assert len(inputs) == num_inputs
|
||||||
|
if not self.training:
|
||||||
|
return inputs[-1]
|
||||||
|
|
||||||
|
# Shape of weights: (*, num_inputs)
|
||||||
|
num_channels = inputs[0].shape[-1]
|
||||||
|
num_frames = inputs[0].numel() // num_channels
|
||||||
|
|
||||||
|
mod_inputs = []
|
||||||
|
for i in range(num_inputs - 1):
|
||||||
|
mod_inputs.append(self.linear[i](inputs[i]))
|
||||||
|
mod_inputs.append(inputs[num_inputs - 1])
|
||||||
|
|
||||||
|
ndim = inputs[0].ndim
|
||||||
|
# stacked_inputs: (num_frames, num_channels, num_inputs)
|
||||||
|
stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape(
|
||||||
|
(num_frames, num_channels, num_inputs)
|
||||||
|
)
|
||||||
|
|
||||||
|
# weights: (num_frames, num_inputs)
|
||||||
|
weights = self._get_random_weights(
|
||||||
|
inputs[0].dtype, inputs[0].device, num_frames
|
||||||
|
)
|
||||||
|
|
||||||
|
weights = weights.reshape(num_frames, num_inputs, 1)
|
||||||
|
# ans: (num_frames, num_channels, 1)
|
||||||
|
ans = torch.matmul(stacked_inputs, weights)
|
||||||
|
# ans: (*, num_channels)
|
||||||
|
ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# for testing only...
|
||||||
|
print("Weights = ", weights.reshape(num_frames, num_inputs))
|
||||||
|
return ans
|
||||||
|
|
||||||
|
def _get_random_weights(
|
||||||
|
self, dtype: torch.dtype, device: torch.device, num_frames: int
|
||||||
|
) -> Tensor:
|
||||||
|
"""Return a tensor of random weights, of shape
|
||||||
|
`(num_frames, self.num_inputs)`,
|
||||||
|
Args:
|
||||||
|
dtype:
|
||||||
|
The data-type desired for the answer, e.g. float, double.
|
||||||
|
device:
|
||||||
|
The device needed for the answer.
|
||||||
|
num_frames:
|
||||||
|
The number of sets of weights desired
|
||||||
|
Returns:
|
||||||
|
A tensor of shape (num_frames, self.num_inputs), such that
|
||||||
|
`ans.sum(dim=1)` is all ones.
|
||||||
|
"""
|
||||||
|
pure_prob = self.pure_prob
|
||||||
|
if pure_prob == 0.0:
|
||||||
|
return self._get_random_mixed_weights(dtype, device, num_frames)
|
||||||
|
elif pure_prob == 1.0:
|
||||||
|
return self._get_random_pure_weights(dtype, device, num_frames)
|
||||||
|
else:
|
||||||
|
p = self._get_random_pure_weights(dtype, device, num_frames)
|
||||||
|
m = self._get_random_mixed_weights(dtype, device, num_frames)
|
||||||
|
return torch.where(
|
||||||
|
torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_random_pure_weights(
|
||||||
|
self, dtype: torch.dtype, device: torch.device, num_frames: int
|
||||||
|
):
|
||||||
|
"""Return a tensor of random one-hot weights, of shape
|
||||||
|
`(num_frames, self.num_inputs)`,
|
||||||
|
Args:
|
||||||
|
dtype:
|
||||||
|
The data-type desired for the answer, e.g. float, double.
|
||||||
|
device:
|
||||||
|
The device needed for the answer.
|
||||||
|
num_frames:
|
||||||
|
The number of sets of weights desired.
|
||||||
|
Returns:
|
||||||
|
A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
|
||||||
|
exactly one weight equal to 1.0 on each frame.
|
||||||
|
"""
|
||||||
|
final_prob = self.final_weight
|
||||||
|
|
||||||
|
# final contains self.num_inputs - 1 in all elements
|
||||||
|
final = torch.full((num_frames,), self.num_inputs - 1, device=device)
|
||||||
|
# nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
|
||||||
|
nonfinal = torch.randint(
|
||||||
|
self.num_inputs - 1, (num_frames,), device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
indexes = torch.where(
|
||||||
|
torch.rand(num_frames, device=device) < final_prob, final, nonfinal
|
||||||
|
)
|
||||||
|
ans = torch.nn.functional.one_hot(
|
||||||
|
indexes, num_classes=self.num_inputs
|
||||||
|
).to(dtype=dtype)
|
||||||
|
return ans
|
||||||
|
|
||||||
|
def _get_random_mixed_weights(
|
||||||
|
self, dtype: torch.dtype, device: torch.device, num_frames: int
|
||||||
|
):
|
||||||
|
"""Return a tensor of random one-hot weights, of shape
|
||||||
|
`(num_frames, self.num_inputs)`,
|
||||||
|
Args:
|
||||||
|
dtype:
|
||||||
|
The data-type desired for the answer, e.g. float, double.
|
||||||
|
device:
|
||||||
|
The device needed for the answer.
|
||||||
|
num_frames:
|
||||||
|
The number of sets of weights desired.
|
||||||
|
Returns:
|
||||||
|
A tensor of shape (num_frames, self.num_inputs), which elements
|
||||||
|
in [0..1] that sum to one over the second axis, i.e.
|
||||||
|
`ans.sum(dim=1)` is all ones.
|
||||||
|
"""
|
||||||
|
logprobs = (
|
||||||
|
torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
|
||||||
|
* self.stddev
|
||||||
|
)
|
||||||
|
logprobs[:, -1] += self.final_log_weight
|
||||||
|
return logprobs.softmax(dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
|
||||||
|
print(
|
||||||
|
f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}"
|
||||||
|
)
|
||||||
|
num_inputs = 3
|
||||||
|
num_channels = 50
|
||||||
|
m = RandomCombine(
|
||||||
|
num_inputs=num_inputs,
|
||||||
|
num_channels=num_channels,
|
||||||
|
final_weight=final_weight,
|
||||||
|
pure_prob=pure_prob,
|
||||||
|
stddev=stddev,
|
||||||
|
)
|
||||||
|
|
||||||
|
x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
|
||||||
|
|
||||||
|
y = m(x)
|
||||||
|
assert y.shape == x[0].shape
|
||||||
|
assert torch.allclose(y, x[0]) # .. since actually all ones.
|
||||||
|
|
||||||
|
|
||||||
|
def _test_random_combine_main():
|
||||||
|
_test_random_combine(0.999, 0, 0.0)
|
||||||
|
_test_random_combine(0.5, 0, 0.0)
|
||||||
|
_test_random_combine(0.999, 0, 0.0)
|
||||||
|
_test_random_combine(0.5, 0, 0.3)
|
||||||
|
_test_random_combine(0.5, 1, 0.3)
|
||||||
|
_test_random_combine(0.5, 0.5, 0.3)
|
||||||
|
|
||||||
|
|
||||||
|
def _test_conformer_main():
|
||||||
feature_dim = 50
|
feature_dim = 50
|
||||||
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
c = Conformer(num_features=feature_dim, d_model=128, nhead=4)
|
||||||
batch_size = 5
|
batch_size = 5
|
||||||
seq_len = 20
|
seq_len = 20
|
||||||
|
feature_dim = 50
|
||||||
|
# Just make sure the forward pass runs.
|
||||||
|
|
||||||
|
c = Conformer(
|
||||||
|
num_features=feature_dim, d_model=128, nhead=4
|
||||||
|
)
|
||||||
|
batch_size = 5
|
||||||
|
seq_len = 20
|
||||||
# Just make sure the forward pass runs.
|
# Just make sure the forward pass runs.
|
||||||
f = c(
|
f = c(
|
||||||
torch.randn(batch_size, seq_len, feature_dim),
|
torch.randn(batch_size, seq_len, feature_dim),
|
||||||
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
torch.full((batch_size,), seq_len, dtype=torch.int64),
|
||||||
warmup=0.5,
|
warmup=0.5,
|
||||||
)
|
)
|
||||||
|
f # to remove flake8 warnings
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
_test_conformer_main()
|
||||||
|
_test_random_combine_main()
|
||||||
|
@ -19,36 +19,36 @@
|
|||||||
"""
|
"""
|
||||||
Usage:
|
Usage:
|
||||||
(1) greedy search
|
(1) greedy search
|
||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless7/decode.py \
|
||||||
--epoch 30 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method greedy_search
|
--decoding-method greedy_search
|
||||||
|
|
||||||
(2) beam search (not recommended)
|
(2) beam search (not recommended)
|
||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless7/decode.py \
|
||||||
--epoch 30 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method beam_search \
|
--decoding-method beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(3) modified beam search
|
(3) modified beam search
|
||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless7/decode.py \
|
||||||
--epoch 30 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method modified_beam_search \
|
--decoding-method modified_beam_search \
|
||||||
--beam-size 4
|
--beam-size 4
|
||||||
|
|
||||||
(4) fast beam search
|
(4) fast beam search
|
||||||
./pruned_transducer_stateless4/decode.py \
|
./pruned_transducer_stateless7/decode.py \
|
||||||
--epoch 30 \
|
--epoch 28 \
|
||||||
--avg 15 \
|
--avg 15 \
|
||||||
--exp-dir ./pruned_transducer_stateless4/exp \
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
--max-duration 600 \
|
--max-duration 600 \
|
||||||
--decoding-method fast_beam_search \
|
--decoding-method fast_beam_search \
|
||||||
--beam 4 \
|
--beam 4 \
|
||||||
@ -75,7 +75,7 @@ from beam_search import (
|
|||||||
greedy_search_batch,
|
greedy_search_batch,
|
||||||
modified_beam_search,
|
modified_beam_search,
|
||||||
)
|
)
|
||||||
from train import get_params, get_transducer_model
|
from train import add_model_arguments, get_params, get_transducer_model
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
average_checkpoints,
|
average_checkpoints,
|
||||||
@ -139,7 +139,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="pruned_transducer_stateless4/exp",
|
default="pruned_transducer_stateless7/exp",
|
||||||
help="The experiment dir",
|
help="The experiment dir",
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -212,6 +212,8 @@ def get_parser():
|
|||||||
Used only when --decoding_method is greedy_search""",
|
Used only when --decoding_method is greedy_search""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -374,7 +376,7 @@ def decode_dataset(
|
|||||||
if params.decoding_method == "greedy_search":
|
if params.decoding_method == "greedy_search":
|
||||||
log_interval = 50
|
log_interval = 50
|
||||||
else:
|
else:
|
||||||
log_interval = 10
|
log_interval = 20
|
||||||
|
|
||||||
results = defaultdict(list)
|
results = defaultdict(list)
|
||||||
for batch_idx, batch in enumerate(dl):
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
@ -23,27 +23,42 @@ To run this file, do:
|
|||||||
python ./pruned_transducer_stateless4/test_model.py
|
python ./pruned_transducer_stateless4/test_model.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
|
||||||
from train import get_params, get_transducer_model
|
from train import get_params, get_transducer_model
|
||||||
|
|
||||||
|
|
||||||
def test_model():
|
def test_model_1():
|
||||||
params = get_params()
|
params = get_params()
|
||||||
params.vocab_size = 500
|
params.vocab_size = 500
|
||||||
params.blank_id = 0
|
params.blank_id = 0
|
||||||
params.context_size = 2
|
params.context_size = 2
|
||||||
params.unk_id = 2
|
params.num_encoder_layers = 24
|
||||||
|
params.dim_feedforward = 1536 # 384 * 4
|
||||||
|
params.encoder_dim = 384
|
||||||
|
model = get_transducer_model(params)
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
print(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
|
||||||
|
# See Table 1 from https://arxiv.org/pdf/2005.08100.pdf
|
||||||
|
def test_model_M():
|
||||||
|
params = get_params()
|
||||||
|
params.vocab_size = 500
|
||||||
|
params.blank_id = 0
|
||||||
|
params.context_size = 2
|
||||||
|
params.num_encoder_layers = 18
|
||||||
|
params.dim_feedforward = 1024
|
||||||
|
params.encoder_dim = 256
|
||||||
|
params.nhead = 4
|
||||||
|
params.decoder_dim = 512
|
||||||
|
params.joiner_dim = 512
|
||||||
model = get_transducer_model(params)
|
model = get_transducer_model(params)
|
||||||
|
|
||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
print(f"Number of model parameters: {num_param}")
|
print(f"Number of model parameters: {num_param}")
|
||||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
|
||||||
torch.jit.script(model)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
test_model()
|
# test_model_1()
|
||||||
|
test_model_M()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang,
|
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
# Wei Kang,
|
# Wei Kang,
|
||||||
# Mingshuang Luo,)
|
# Mingshuang Luo,)
|
||||||
# Zengwei Yao)
|
# Zengwei Yao)
|
||||||
@ -22,22 +22,22 @@ Usage:
|
|||||||
|
|
||||||
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||||
|
|
||||||
./pruned_transducer_stateless4/train.py \
|
./pruned_transducer_stateless7/train.py \
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 1 \
|
--start-epoch 1 \
|
||||||
--exp-dir pruned_transducer_stateless2/exp \
|
--exp-dir pruned_transducer_stateless7/exp \
|
||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 300
|
--max-duration 300
|
||||||
|
|
||||||
# For mix precision training:
|
# For mix precision training:
|
||||||
|
|
||||||
./pruned_transducer_stateless4/train.py \
|
./pruned_transducer_stateless7/train.py \
|
||||||
--world-size 4 \
|
--world-size 4 \
|
||||||
--num-epochs 30 \
|
--num-epochs 30 \
|
||||||
--start-epoch 1 \
|
--start-epoch 1 \
|
||||||
--use-fp16 1 \
|
--use-fp16 1 \
|
||||||
--exp-dir pruned_transducer_stateless2/exp \
|
--exp-dir pruned_transducer_stateless7/exp \
|
||||||
--full-libri 1 \
|
--full-libri 1 \
|
||||||
--max-duration 550
|
--max-duration 550
|
||||||
|
|
||||||
@ -88,6 +88,53 @@ LRSchedulerType = Union[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||||
|
parser.add_argument(
|
||||||
|
"--num-encoder-layers",
|
||||||
|
type=int,
|
||||||
|
default=24,
|
||||||
|
help="Number of conformer encoder layers..",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--dim-feedforward",
|
||||||
|
type=int,
|
||||||
|
default=1536,
|
||||||
|
help="Feedforward dimension of the conformer encoder layer.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--nhead",
|
||||||
|
type=int,
|
||||||
|
default=8,
|
||||||
|
help="Number of attention heads in the conformer encoder layer.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--encoder-dim",
|
||||||
|
type=int,
|
||||||
|
default=384,
|
||||||
|
help="Attention dimension in the conformer encoder layer.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoder-dim",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="Embedding dimension in the decoder model.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--joiner-dim",
|
||||||
|
type=int,
|
||||||
|
default=512,
|
||||||
|
help="""Dimension used in the joiner model.
|
||||||
|
Outputs from the encoder and decoder model are projected
|
||||||
|
to this dimension before adding.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
@ -143,7 +190,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="pruned_transducer_stateless2/exp",
|
default="pruned_transducer_stateless7/exp",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
@ -161,16 +208,16 @@ def get_parser():
|
|||||||
"--initial-lr",
|
"--initial-lr",
|
||||||
type=float,
|
type=float,
|
||||||
default=0.003,
|
default=0.003,
|
||||||
help="""The initial learning rate. This value should not need to be
|
help="The initial learning rate. This value should not need "
|
||||||
changed.""",
|
"to be changed.",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--lr-batches",
|
"--lr-batches",
|
||||||
type=float,
|
type=float,
|
||||||
default=5000,
|
default=5000,
|
||||||
help="""Number of steps that affects how rapidly the learning rate decreases.
|
help="""Number of steps that affects how rapidly the learning rate
|
||||||
We suggest not to change this.""",
|
decreases. We suggest not to change this.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@ -240,7 +287,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--save-every-n",
|
"--save-every-n",
|
||||||
type=int,
|
type=int,
|
||||||
default=8000,
|
default=4000,
|
||||||
help="""Save checkpoint after processing this number of batches"
|
help="""Save checkpoint after processing this number of batches"
|
||||||
periodically. We save checkpoint to exp-dir/ whenever
|
periodically. We save checkpoint to exp-dir/ whenever
|
||||||
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
params.batch_idx_train % save_every_n == 0. The checkpoint filename
|
||||||
@ -253,7 +300,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--keep-last-k",
|
"--keep-last-k",
|
||||||
type=int,
|
type=int,
|
||||||
default=20,
|
default=30,
|
||||||
help="""Only keep this number of checkpoints on disk.
|
help="""Only keep this number of checkpoints on disk.
|
||||||
For instance, if it is 3, there are only 3 checkpoints
|
For instance, if it is 3, there are only 3 checkpoints
|
||||||
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
in the exp-dir with filenames `checkpoint-xxx.pt`.
|
||||||
@ -281,6 +328,8 @@ def get_parser():
|
|||||||
help="Whether to use half precision training.",
|
help="Whether to use half precision training.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -341,14 +390,6 @@ def get_params() -> AttributeDict:
|
|||||||
# parameters for conformer
|
# parameters for conformer
|
||||||
"feature_dim": 80,
|
"feature_dim": 80,
|
||||||
"subsampling_factor": 4,
|
"subsampling_factor": 4,
|
||||||
"encoder_dim": 512,
|
|
||||||
"nhead": 8,
|
|
||||||
"dim_feedforward": 2048,
|
|
||||||
"num_encoder_layers": 12,
|
|
||||||
# parameters for decoder
|
|
||||||
"decoder_dim": 512,
|
|
||||||
# parameters for joiner
|
|
||||||
"joiner_dim": 512,
|
|
||||||
# parameters for Noam
|
# parameters for Noam
|
||||||
"model_warm_step": 3000, # arg given to model, not for lrate
|
"model_warm_step": 3000, # arg given to model, not for lrate
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
@ -704,6 +745,7 @@ def train_one_epoch(
|
|||||||
params.batch_idx_train += 1
|
params.batch_idx_train += 1
|
||||||
batch_size = len(batch["supervisions"]["text"])
|
batch_size = len(batch["supervisions"]["text"])
|
||||||
|
|
||||||
|
try:
|
||||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||||
loss, loss_info = compute_loss(
|
loss, loss_info = compute_loss(
|
||||||
params=params,
|
params=params,
|
||||||
@ -723,8 +765,11 @@ def train_one_epoch(
|
|||||||
scaler.step(optimizer)
|
scaler.step(optimizer)
|
||||||
scaler.update()
|
scaler.update()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
except: # noqa
|
||||||
|
display_and_save_batch(batch, params=params, sp=sp)
|
||||||
|
raise
|
||||||
|
|
||||||
if params.print_diagnostics and batch_idx == 30:
|
if params.print_diagnostics and batch_idx == 5:
|
||||||
return
|
return
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@ -888,7 +933,10 @@ def run(rank, world_size, args):
|
|||||||
scheduler.load_state_dict(checkpoints["scheduler"])
|
scheduler.load_state_dict(checkpoints["scheduler"])
|
||||||
|
|
||||||
if params.print_diagnostics:
|
if params.print_diagnostics:
|
||||||
diagnostic = diagnostics.attach_diagnostics(model)
|
opts = diagnostics.TensorDiagnosticOptions(
|
||||||
|
2 ** 22
|
||||||
|
) # allow 4 megabytes per sub-module
|
||||||
|
diagnostic = diagnostics.attach_diagnostics(model, opts)
|
||||||
|
|
||||||
librispeech = LibriSpeechAsrDataModule(args)
|
librispeech = LibriSpeechAsrDataModule(args)
|
||||||
|
|
||||||
@ -986,6 +1034,38 @@ def run(rank, world_size, args):
|
|||||||
cleanup_dist()
|
cleanup_dist()
|
||||||
|
|
||||||
|
|
||||||
|
def display_and_save_batch(
|
||||||
|
batch: dict,
|
||||||
|
params: AttributeDict,
|
||||||
|
sp: spm.SentencePieceProcessor,
|
||||||
|
) -> None:
|
||||||
|
"""Display the batch statistics and save the batch into disk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch:
|
||||||
|
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
||||||
|
for the content in it.
|
||||||
|
params:
|
||||||
|
Parameters for training. See :func:`get_params`.
|
||||||
|
sp:
|
||||||
|
The BPE model.
|
||||||
|
"""
|
||||||
|
from lhotse.utils import uuid4
|
||||||
|
|
||||||
|
filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
|
||||||
|
logging.info(f"Saving batch to {filename}")
|
||||||
|
torch.save(batch, filename)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
features = batch["inputs"]
|
||||||
|
|
||||||
|
logging.info(f"features shape: {features.shape}")
|
||||||
|
|
||||||
|
y = sp.encode(supervisions["text"], out_type=int)
|
||||||
|
num_tokens = sum(len(i) for i in y)
|
||||||
|
logging.info(f"num tokens: {num_tokens}")
|
||||||
|
|
||||||
|
|
||||||
def scan_pessimistic_batches_for_oom(
|
def scan_pessimistic_batches_for_oom(
|
||||||
model: Union[nn.Module, DDP],
|
model: Union[nn.Module, DDP],
|
||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
@ -1017,7 +1097,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
except RuntimeError as e:
|
except Exception as e:
|
||||||
if "CUDA out of memory" in str(e):
|
if "CUDA out of memory" in str(e):
|
||||||
logging.error(
|
logging.error(
|
||||||
"Your GPU ran out of memory with the current "
|
"Your GPU ran out of memory with the current "
|
||||||
@ -1026,6 +1106,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
f"Failing criterion: {criterion} "
|
f"Failing criterion: {criterion} "
|
||||||
f"(={crit_values[criterion]}) ..."
|
f"(={crit_values[criterion]}) ..."
|
||||||
)
|
)
|
||||||
|
display_and_save_batch(batch, params=params, sp=sp)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user