Add warmup mode

Daniel Povey 2022-03-14 23:04:51 +08:00
parent 8d17a05dd2
commit a23010fc10
4 changed files with 29 additions and 36 deletions

View File

@@ -88,7 +88,7 @@ class Conformer(Transformer):
     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor
+        self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -112,7 +112,8 @@ class Conformer(Transformer):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)
-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, N, C)
+        x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
+                         warmup_mode=warmup_mode)  # (T, N, C)
         logits = self.encoder_output_layer(x)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -258,7 +259,6 @@ class ConformerEncoder(nn.Module):
         self.num_layers = num_layers
         num_channels = encoder_layer.d_model
         self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
-                                      num_channels=num_channels,
                                       final_weight=0.5,
                                       pure_prob=0.333,
                                       stddev=2.0)
@@ -269,6 +269,7 @@ class ConformerEncoder(nn.Module):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        warmup_mode: bool = False
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
@@ -300,7 +301,7 @@ class ConformerEncoder(nn.Module):
             if i in self.aux_layers:
                 outputs.append(output)
-        output = self.combiner(outputs)
+        output = self.combiner(outputs, warmup_mode)
         return output
@@ -946,17 +947,12 @@ class RandomCombine(torch.nn.Module):
     is a random combination of all the inputs; but which in test time
     will be just the last input.
-    All but the last input will have a linear transform before we
-    randomly combine them; these linear transforms will be initialzed
-    to the identity transform.
     The idea is that the list of Tensors will be a list of outputs of multiple
     conformer layers. This has a similar effect as iterated loss. (See:
     DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
     NETWORKS).
     """
     def __init__(self, num_inputs: int,
-                 num_channels: int,
                  final_weight: float = 0.5,
                  pure_prob: float = 0.5,
                  stddev: float = 2.0) -> None:
@@ -965,7 +961,6 @@ class RandomCombine(torch.nn.Module):
           num_inputs: The number of tensor inputs, which equals the number of layers'
             outputs that are fed into this module. E.g. in an 18-layer neural
             net if we output layers 16, 12, 18, num_inputs would be 3.
-          num_channels: The number of channels on the input, e.g. 512.
           final_weight: The amount of weight or probability we assign to the
             final layer when randomly choosing layers or when choosing
             continuous layer weights.
@@ -991,8 +986,6 @@ class RandomCombine(torch.nn.Module):
         assert pure_prob >= 0 and pure_prob <= 1
         assert final_weight > 0 and final_weight < 1
         assert num_inputs >= 1
-        self.linear = nn.ModuleList([ScaledLinear(num_channels, num_channels, bias=True)
-                                     for _ in range(num_inputs - 1)])
         self.num_inputs = num_inputs
         self.final_weight = final_weight
@@ -1000,14 +993,10 @@ class RandomCombine(torch.nn.Module):
         self.stddev= stddev
         self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item()
-        self._reset_parameters()
-    def _reset_parameters(self):
-        for i in range(len(self.linear)):
-            nn.init.eye_(self.linear[i].weight)
-            nn.init.constant_(self.linear[i].bias, 0.0)
-    def forward(self, inputs: Sequence[Tensor]) -> Tensor:
+    def forward(self, inputs: Sequence[Tensor],
+                warmup_mode: bool) -> Tensor:
         """
         Forward function.
         Args:
@@ -1019,24 +1008,18 @@ class RandomCombine(torch.nn.Module):
         """
         num_inputs = self.num_inputs
         assert len(inputs) == num_inputs
-        if not self.training:
+        if not (self.training and warmup_mode):
             return inputs[-1]
         # Shape of weights: (*, num_inputs)
         num_channels = inputs[0].shape[-1]
         num_frames = inputs[0].numel() // num_channels
-        mod_inputs = []
-        for i in range(num_inputs - 1):
-            mod_inputs.append(self.linear[i](inputs[i]))
-        mod_inputs.append(inputs[num_inputs - 1])
         ndim = inputs[0].ndim
         # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape((num_frames,
+        stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames,
                                                                  num_channels,
                                                                  num_inputs))
         # weights: (num_frames, num_inputs)
         weights = self._get_random_weights(inputs[0].dtype, inputs[0].device,
@@ -1118,12 +1101,14 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
     print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}")
     num_inputs = 3
     num_channels = 50
-    m = RandomCombine(num_inputs=num_inputs, num_channels=num_channels,
-                      final_weight=final_weight, pure_prob=pure_prob, stddev=stddev)
+    m = RandomCombine(num_inputs=num_inputs,
+                      final_weight=final_weight,
+                      pure_prob=pure_prob,
+                      stddev=stddev)
     x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ]
-    y = m(x)
+    y = m(x, True)
     assert y.shape == x[0].shape
     assert torch.allclose(y, x[0])  # .. since actually all ones.
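Net effect of the changes to RandomCombine above: the per-layer linear transforms are gone, and layer outputs are only randomly mixed when both self.training and warmup_mode are true; otherwise the module simply returns the last layer's output. The following is a minimal, hypothetical sketch of that behaviour (random_combine_sketch is a made-up name, and the weight drawing is simplified: the real _get_random_weights also mixes in one-hot "pure" picks controlled by pure_prob):

import torch
from typing import Sequence
from torch import Tensor


def random_combine_sketch(inputs: Sequence[Tensor],
                          training: bool,
                          warmup_mode: bool,
                          final_weight: float = 0.5,
                          stddev: float = 2.0) -> Tensor:
    # Outside (training and warmup_mode), behave like the updated module:
    # just return the final layer's output.
    if not (training and warmup_mode):
        return inputs[-1]
    num_inputs = len(inputs)
    num_channels = inputs[0].shape[-1]
    num_frames = inputs[0].numel() // num_channels
    # Stack along a new trailing dim -> (..., num_inputs), then flatten frames.
    stacked = torch.stack(list(inputs), dim=inputs[0].ndim)
    stacked = stacked.reshape(num_frames, num_channels, num_inputs)
    # Simplified continuous weights: Gaussian logits with a bias toward the
    # final layer so it receives roughly `final_weight` of the mass on average.
    logits = torch.randn(num_frames, num_inputs,
                         dtype=inputs[0].dtype, device=inputs[0].device) * stddev
    final_log_weight = torch.tensor(
        (final_weight / (1 - final_weight)) * (num_inputs - 1)).log().item()
    logits[:, -1] += final_log_weight
    weights = logits.softmax(dim=-1)  # (num_frames, num_inputs), rows sum to 1
    out = torch.matmul(stacked, weights.unsqueeze(-1)).squeeze(-1)
    return out.reshape(inputs[0].shape)


# Mirrors the updated test above: with all-ones inputs, any convex
# combination of the layers is still all ones.
x = [torch.ones(3, 4, 50) for _ in range(3)]
y = random_combine_sketch(x, training=True, warmup_mode=True)
assert y.shape == x[0].shape and torch.allclose(y, x[0])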

View File

@@ -22,7 +22,7 @@ import torch.nn as nn
 class EncoderInterface(nn.Module):
     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor
+        self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -32,6 +32,8 @@ class EncoderInterface(nn.Module):
           x_lens:
             A tensor of shape (batch_size,) containing the number of frames
             in `x` before padding.
+          warmup_mode: for training only, if true then train in
+            "warmup mode" (use this for the first few thousand minibatches).
         Returns:
           Return a tuple containing two tensors:
             - encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
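For illustration only, a hypothetical encoder conforming to the updated EncoderInterface signature might look like the sketch below (ToyEncoder and its dimensions are made up; a real encoder such as the Conformer above would route warmup_mode on to its layer combiner):

import torch
import torch.nn as nn
from typing import Tuple


class ToyEncoder(nn.Module):
    # A stand-in that only demonstrates the new call signature.
    def __init__(self, feature_dim: int = 80, output_dim: int = 512):
        super().__init__()
        self.proj = nn.Linear(feature_dim, output_dim)

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.proj(x), x_lens


encoder = ToyEncoder()
x = torch.randn(2, 100, 80)       # (batch_size, seq_len, feature_dim)
x_lens = torch.tensor([100, 90])  # frames before padding
encoder_out, out_lens = encoder(x, x_lens, warmup_mode=True)
assert encoder_out.shape == (2, 100, 512)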

View File

@@ -62,6 +62,7 @@ class Transducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
+        warmup_mode: bool = False
     ) -> torch.Tensor:
         """
         Args:
@@ -82,7 +83,7 @@ class Transducer(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0
-        encoder_out, x_lens = self.encoder(x, x_lens)
+        encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode)
         assert torch.all(x_lens > 0)
         # Now for the decoder, i.e., the prediction network

View File

@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2",
+        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -203,6 +203,7 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
+            "warmup_minibatches": 3000,  # use warmup mode for 3k minibatches.
             # parameters for conformer
             "feature_dim": 80,
             "encoder_out_dim": 512,
@@ -360,6 +361,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
+    is_warmup_mode: bool = False
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -391,7 +393,8 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)
     with torch.set_grad_enabled(is_training):
-        loss = model(x=feature, x_lens=feature_lens, y=y)
+        loss = model(x=feature, x_lens=feature_lens, y=y,
+                     warmup_mode=is_warmup_mode)
     assert loss.requires_grad == is_training
@@ -423,6 +426,7 @@ def compute_validation_loss(
             sp=sp,
             batch=batch,
             is_training=False,
+            is_warmup_mode=False
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -484,6 +488,7 @@ def train_one_epoch(
             sp=sp,
             batch=batch,
             is_training=True,
+            is_warmup_mode=(params.batch_idx_train<params.warmup_minibatches)
         )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -498,7 +503,6 @@ def train_one_epoch(
         if params.print_diagnostics and batch_idx == 5:
            return
-
        if batch_idx % params.log_interval == 0:
            logging.info(
                f"Epoch {params.cur_epoch}, "
@@ -715,6 +719,7 @@ def scan_pessimistic_batches_for_oom(
                 sp=sp,
                 batch=batch,
                 is_training=True,
+                is_warmup_mode=False
             )
             loss.backward()
             clip_grad_norm_(model.parameters(), 5.0, 2.0)
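As a small sketch of the training-side logic introduced in train.py above: warmup mode is active only while params.batch_idx_train is below params.warmup_minibatches, while validation and the OOM scan always pass is_warmup_mode=False. The helper below is hypothetical and only restates that condition:

# Hypothetical helper mirroring the condition used in train_one_epoch above.
WARMUP_MINIBATCHES = 3000  # matches "warmup_minibatches" in get_params()


def in_warmup(batch_idx_train: int,
              warmup_minibatches: int = WARMUP_MINIBATCHES) -> bool:
    # True for the first few thousand minibatches, False afterwards.
    return batch_idx_train < warmup_minibatches


assert in_warmup(0) and in_warmup(2999)
assert not in_warmup(3000)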