Add warmup mode

This commit is contained in:
Daniel Povey 2022-03-14 23:04:51 +08:00
parent 8d17a05dd2
commit a23010fc10
4 changed files with 29 additions and 36 deletions

View File

@ -88,7 +88,7 @@ class Conformer(Transformer):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@ -112,7 +112,8 @@ class Conformer(Transformer):
assert x.size(0) == lengths.max().item()
mask = make_pad_mask(lengths)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask) # (T, N, C)
x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
warmup_mode=warmup_mode) # (T, N, C)
logits = self.encoder_output_layer(x)
logits = logits.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
@ -258,7 +259,6 @@ class ConformerEncoder(nn.Module):
self.num_layers = num_layers
num_channels = encoder_layer.d_model
self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
num_channels=num_channels,
final_weight=0.5,
pure_prob=0.333,
stddev=2.0)
@ -269,6 +269,7 @@ class ConformerEncoder(nn.Module):
pos_emb: Tensor,
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup_mode: bool = False
) -> Tensor:
r"""Pass the input through the encoder layers in turn.
@ -300,7 +301,7 @@ class ConformerEncoder(nn.Module):
if i in self.aux_layers:
outputs.append(output)
output = self.combiner(outputs)
output = self.combiner(outputs, warmup_mode)
return output
@ -946,17 +947,12 @@ class RandomCombine(torch.nn.Module):
is a random combination of all the inputs; but which in test time
will be just the last input.
All but the last input will have a linear transform before we
randomly combine them; these linear transforms will be initialzed
to the identity transform.
The idea is that the list of Tensors will be a list of outputs of multiple
conformer layers. This has a similar effect as iterated loss. (See:
DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
NETWORKS).
"""
def __init__(self, num_inputs: int,
num_channels: int,
final_weight: float = 0.5,
pure_prob: float = 0.5,
stddev: float = 2.0) -> None:
@ -965,7 +961,6 @@ class RandomCombine(torch.nn.Module):
num_inputs: The number of tensor inputs, which equals the number of layers'
outputs that are fed into this module. E.g. in an 18-layer neural
net if we output layers 16, 12, 18, num_inputs would be 3.
num_channels: The number of channels on the input, e.g. 512.
final_weight: The amount of weight or probability we assign to the
final layer when randomly choosing layers or when choosing
continuous layer weights.
@ -991,8 +986,6 @@ class RandomCombine(torch.nn.Module):
assert pure_prob >= 0 and pure_prob <= 1
assert final_weight > 0 and final_weight < 1
assert num_inputs >= 1
self.linear = nn.ModuleList([ScaledLinear(num_channels, num_channels, bias=True)
for _ in range(num_inputs - 1)])
self.num_inputs = num_inputs
self.final_weight = final_weight
@ -1000,14 +993,10 @@ class RandomCombine(torch.nn.Module):
self.stddev= stddev
self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item()
self._reset_parameters()
def _reset_parameters(self):
for i in range(len(self.linear)):
nn.init.eye_(self.linear[i].weight)
nn.init.constant_(self.linear[i].bias, 0.0)
def forward(self, inputs: Sequence[Tensor]) -> Tensor:
def forward(self, inputs: Sequence[Tensor],
warmup_mode: bool) -> Tensor:
"""
Forward function.
Args:
@ -1019,24 +1008,18 @@ class RandomCombine(torch.nn.Module):
"""
num_inputs = self.num_inputs
assert len(inputs) == num_inputs
if not self.training:
if not (self.training and warmup_mode):
return inputs[-1]
# Shape of weights: (*, num_inputs)
num_channels = inputs[0].shape[-1]
num_frames = inputs[0].numel() // num_channels
mod_inputs = []
for i in range(num_inputs - 1):
mod_inputs.append(self.linear[i](inputs[i]))
mod_inputs.append(inputs[num_inputs - 1])
ndim = inputs[0].ndim
# stacked_inputs: (num_frames, num_channels, num_inputs)
stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape((num_frames,
num_channels,
num_inputs))
stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames,
num_channels,
num_inputs))
# weights: (num_frames, num_inputs)
weights = self._get_random_weights(inputs[0].dtype, inputs[0].device,
@ -1118,12 +1101,14 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}")
num_inputs = 3
num_channels = 50
m = RandomCombine(num_inputs=num_inputs, num_channels=num_channels,
final_weight=final_weight, pure_prob=pure_prob, stddev=stddev)
m = RandomCombine(num_inputs=num_inputs,
final_weight=final_weight,
pure_prob=pure_prob,
stddev=stddev)
x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ]
y = m(x)
y = m(x, True)
assert y.shape == x[0].shape
assert torch.allclose(y, x[0]) # .. since actually all ones.

View File

@ -22,7 +22,7 @@ import torch.nn as nn
class EncoderInterface(nn.Module):
def forward(
self, x: torch.Tensor, x_lens: torch.Tensor
self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@ -32,6 +32,8 @@ class EncoderInterface(nn.Module):
x_lens:
A tensor of shape (batch_size,) containing the number of frames
in `x` before padding.
warmup_mode: for training only, if true then train in
"warmup mode" (use this for the first few thousand minibatches).
Returns:
Return a tuple containing two tensors:
- encoder_out, a tensor of (batch_size, out_seq_len, output_dim)

View File

@ -62,6 +62,7 @@ class Transducer(nn.Module):
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
warmup_mode: bool = False
) -> torch.Tensor:
"""
Args:
@ -82,7 +83,7 @@ class Transducer(nn.Module):
assert x.size(0) == x_lens.size(0) == y.dim0
encoder_out, x_lens = self.encoder(x, x_lens)
encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode)
assert torch.all(x_lens > 0)
# Now for the decoder, i.e., the prediction network

View File

@ -110,7 +110,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2",
default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@ -203,6 +203,7 @@ def get_params() -> AttributeDict:
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 3000, # For the 100h subset, use 800
"warmup_minibatches": 3000, # use warmup mode for 3k minibatches.
# parameters for conformer
"feature_dim": 80,
"encoder_out_dim": 512,
@ -360,6 +361,7 @@ def compute_loss(
sp: spm.SentencePieceProcessor,
batch: dict,
is_training: bool,
is_warmup_mode: bool = False
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
@ -391,7 +393,8 @@ def compute_loss(
y = k2.RaggedTensor(y).to(device)
with torch.set_grad_enabled(is_training):
loss = model(x=feature, x_lens=feature_lens, y=y)
loss = model(x=feature, x_lens=feature_lens, y=y,
warmup_mode=is_warmup_mode)
assert loss.requires_grad == is_training
@ -423,6 +426,7 @@ def compute_validation_loss(
sp=sp,
batch=batch,
is_training=False,
is_warmup_mode=False
)
assert loss.requires_grad is False
tot_loss = tot_loss + loss_info
@ -484,6 +488,7 @@ def train_one_epoch(
sp=sp,
batch=batch,
is_training=True,
is_warmup_mode=(params.batch_idx_train<params.warmup_minibatches)
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@ -498,7 +503,6 @@ def train_one_epoch(
if params.print_diagnostics and batch_idx == 5:
return
if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, "
@ -715,6 +719,7 @@ def scan_pessimistic_batches_for_oom(
sp=sp,
batch=batch,
is_training=True,
is_warmup_mode=False
)
loss.backward()
clip_grad_norm_(model.parameters(), 5.0, 2.0)