Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-10 18:42:19 +00:00)

Commit: Add warmup mode
Parent: 8d17a05dd2
Commit: a23010fc10
@@ -88,7 +88,7 @@ class Conformer(Transformer):


     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor
+        self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -112,7 +112,8 @@ class Conformer(Transformer):
         assert x.size(0) == lengths.max().item()
         mask = make_pad_mask(lengths)

-        x = self.encoder(x, pos_emb, src_key_padding_mask=mask)  # (T, N, C)
+        x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
+                         warmup_mode=warmup_mode)  # (T, N, C)

         logits = self.encoder_output_layer(x)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -258,7 +259,6 @@ class ConformerEncoder(nn.Module):
         self.num_layers = num_layers
         num_channels = encoder_layer.d_model
         self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
-                                      num_channels=num_channels,
                                       final_weight=0.5,
                                       pure_prob=0.333,
                                       stddev=2.0)
@@ -269,6 +269,7 @@ class ConformerEncoder(nn.Module):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        warmup_mode: bool = False
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.

@@ -300,7 +301,7 @@ class ConformerEncoder(nn.Module):
             if i in self.aux_layers:
                 outputs.append(output)

-        output = self.combiner(outputs)
+        output = self.combiner(outputs, warmup_mode)
         return output


@@ -946,17 +947,12 @@ class RandomCombine(torch.nn.Module):
     is a random combination of all the inputs; but which in test time
     will be just the last input.

-    All but the last input will have a linear transform before we
-    randomly combine them; these linear transforms will be initialzed
-    to the identity transform.
-
     The idea is that the list of Tensors will be a list of outputs of multiple
     conformer layers. This has a similar effect as iterated loss. (See:
     DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
     NETWORKS).
     """
     def __init__(self, num_inputs: int,
-                 num_channels: int,
                  final_weight: float = 0.5,
                  pure_prob: float = 0.5,
                  stddev: float = 2.0) -> None:
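For orientation, a small construction sketch (not part of the patch): after this change RandomCombine takes no num_channels argument and applies no per-input ScaledLinear transforms, so only the combination hyperparameters remain. The values below mirror the ConformerEncoder hunk above.

    # Hedged sketch only; RandomCombine is the class patched above.
    combiner = RandomCombine(num_inputs=3,      # e.g. outputs tapped from three conformer layers
                             final_weight=0.5,  # weight/probability given to the last layer
                             pure_prob=0.333,   # probability of picking a single layer outright
                             stddev=2.0)        # spread of the random log-weights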
@@ -965,7 +961,6 @@ class RandomCombine(torch.nn.Module):
           num_inputs: The number of tensor inputs, which equals the number of layers'
             outputs that are fed into this module.  E.g. in an 18-layer neural
             net if we output layers 16, 12, 18, num_inputs would be 3.
-          num_channels: The number of channels on the input, e.g. 512.
           final_weight: The amount of weight or probability we assign to the
             final layer when randomly choosing layers or when choosing
             continuous layer weights.
@@ -991,8 +986,6 @@ class RandomCombine(torch.nn.Module):
         assert pure_prob >= 0 and pure_prob <= 1
         assert final_weight > 0 and final_weight < 1
         assert num_inputs >= 1
-        self.linear = nn.ModuleList([ScaledLinear(num_channels, num_channels, bias=True)
-                                     for _ in range(num_inputs - 1)])

         self.num_inputs = num_inputs
         self.final_weight = final_weight
@@ -1000,14 +993,10 @@ class RandomCombine(torch.nn.Module):
         self.stddev= stddev

         self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item()
-        self._reset_parameters()

-    def _reset_parameters(self):
-        for i in range(len(self.linear)):
-            nn.init.eye_(self.linear[i].weight)
-            nn.init.constant_(self.linear[i].bias, 0.0)

-    def forward(self, inputs: Sequence[Tensor]) -> Tensor:
+    def forward(self, inputs: Sequence[Tensor],
+                warmup_mode: bool) -> Tensor:
         """
         Forward function.
         Args:
@@ -1019,24 +1008,18 @@ class RandomCombine(torch.nn.Module):
         """
         num_inputs = self.num_inputs
         assert len(inputs) == num_inputs
-        if not self.training:
+        if not (self.training and warmup_mode):
             return inputs[-1]

         # Shape of weights: (*, num_inputs)
         num_channels = inputs[0].shape[-1]
         num_frames = inputs[0].numel() // num_channels

-        mod_inputs = []
-        for i in range(num_inputs - 1):
-            mod_inputs.append(self.linear[i](inputs[i]))
-        mod_inputs.append(inputs[num_inputs - 1])
-
-
         ndim = inputs[0].ndim
         # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape((num_frames,
+        stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames,
                                                                     num_channels,
                                                                     num_inputs))

         # weights: (num_frames, num_inputs)
         weights = self._get_random_weights(inputs[0].dtype, inputs[0].device,
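A minimal standalone sketch (assumed names, not the module's real code) of the gate this hunk introduces: layer outputs are randomly combined only while the model is training and warmup_mode is set; in every other case the deepest layer's output passes through untouched.

    import torch
    from typing import Sequence

    def combine_or_passthrough(inputs: Sequence[torch.Tensor],
                               weights: torch.Tensor,
                               training: bool,
                               warmup_mode: bool) -> torch.Tensor:
        # Mirrors the gate above: combine only during training *and* warmup mode.
        if not (training and warmup_mode):
            return inputs[-1]
        # `weights` stands in for the per-frame weights the real module draws in
        # _get_random_weights(); here it is just (num_inputs,) summing to 1.
        stacked = torch.stack(list(inputs), dim=-1)   # (..., num_channels, num_inputs)
        return (stacked * weights).sum(dim=-1)        # weighted combination

    outs = [torch.randn(4, 8) for _ in range(3)]      # three layer outputs
    w = torch.tensor([0.25, 0.25, 0.5])
    y_warm = combine_or_passthrough(outs, w, training=True, warmup_mode=True)
    y_test = combine_or_passthrough(outs, w, training=False, warmup_mode=False)
    assert y_test is outs[-1]                         # passthrough outside warmup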
@@ -1118,12 +1101,14 @@ def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
     print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}")
     num_inputs = 3
     num_channels = 50
-    m = RandomCombine(num_inputs=num_inputs, num_channels=num_channels,
-                      final_weight=final_weight, pure_prob=pure_prob, stddev=stddev)
+    m = RandomCombine(num_inputs=num_inputs,
+                      final_weight=final_weight,
+                      pure_prob=pure_prob,
+                      stddev=stddev)

     x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ]

-    y = m(x)
+    y = m(x, True)
     assert y.shape == x[0].shape
     assert torch.allclose(y, x[0])  # .. since actually all ones.

@@ -22,7 +22,7 @@ import torch.nn as nn

 class EncoderInterface(nn.Module):
     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor
+        self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -32,6 +32,8 @@ class EncoderInterface(nn.Module):
           x_lens:
             A tensor of shape (batch_size,) containing the number of frames
             in `x` before padding.
+          warmup_mode: for training only, if true then train in
+            "warmup mode" (use this for the first few thousand minibatches).
         Returns:
           Return a tuple containing two tensors:
             - encoder_out, a tensor of (batch_size, out_seq_len, output_dim)
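A toy usage sketch (stand-in class, not icefall code) of the extended interface: any encoder implementing EncoderInterface, such as the Conformer patched above, is now called with the extra flag.

    import torch
    from typing import Tuple

    class ToyEncoder(torch.nn.Module):
        # Stand-in showing only the call shape documented above.
        def forward(self, x: torch.Tensor, x_lens: torch.Tensor,
                    warmup_mode: bool) -> Tuple[torch.Tensor, torch.Tensor]:
            return x, x_lens   # a real encoder would subsample and encode here

    enc = ToyEncoder()
    x = torch.zeros(2, 100, 80)          # (batch_size, num_frames, feature_dim)
    x_lens = torch.tensor([100, 80])
    encoder_out, out_lens = enc(x, x_lens, warmup_mode=True)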
@@ -62,6 +62,7 @@ class Transducer(nn.Module):
         x: torch.Tensor,
         x_lens: torch.Tensor,
         y: k2.RaggedTensor,
+        warmup_mode: bool = False
     ) -> torch.Tensor:
         """
         Args:
@@ -82,7 +83,7 @@ class Transducer(nn.Module):

         assert x.size(0) == x_lens.size(0) == y.dim0

-        encoder_out, x_lens = self.encoder(x, x_lens)
+        encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode)
         assert torch.all(x_lens > 0)

         # Now for the decoder, i.e., the prediction network
@@ -110,7 +110,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2",
+        default="transducer_stateless/randcombine1_expscale3_rework2c_maxabs1000_maxp0.95_noexp_convderiv2warmup",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -203,6 +203,7 @@ def get_params() -> AttributeDict:
             "log_interval": 50,
             "reset_interval": 200,
             "valid_interval": 3000,  # For the 100h subset, use 800
+            "warmup_minibatches": 3000,  # use warmup mode for 3k minibatches.
             # parameters for conformer
             "feature_dim": 80,
             "encoder_out_dim": 512,
@@ -360,6 +361,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
+    is_warmup_mode: bool = False
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -391,7 +393,8 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        loss = model(x=feature, x_lens=feature_lens, y=y)
+        loss = model(x=feature, x_lens=feature_lens, y=y,
+                     warmup_mode=is_warmup_mode)

     assert loss.requires_grad == is_training

@@ -423,6 +426,7 @@ def compute_validation_loss(
             sp=sp,
             batch=batch,
             is_training=False,
+            is_warmup_mode=False
         )
         assert loss.requires_grad is False
         tot_loss = tot_loss + loss_info
@@ -484,6 +488,7 @@ def train_one_epoch(
             sp=sp,
             batch=batch,
             is_training=True,
+            is_warmup_mode=(params.batch_idx_train<params.warmup_minibatches)
         )
         # summary stats
         tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
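Putting the get_params and train_one_epoch hunks together, the schedule is a simple threshold on the global minibatch counter. A small sketch of the values it produces, assuming the default of 3000 warmup minibatches:

    warmup_minibatches = 3000                      # from get_params() above
    for batch_idx_train in (0, 1500, 2999, 3000, 10000):
        is_warmup_mode = batch_idx_train < warmup_minibatches
        print(batch_idx_train, is_warmup_mode)     # True for the first 3000 steps, then False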
@@ -498,7 +503,6 @@ def train_one_epoch(
         if params.print_diagnostics and batch_idx == 5:
             return

-
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, "
@@ -715,6 +719,7 @@ def scan_pessimistic_batches_for_oom(
             sp=sp,
             batch=batch,
             is_training=True,
+            is_warmup_mode=False
         )
         loss.backward()
         clip_grad_norm_(model.parameters(), 5.0, 2.0)