Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-10 18:42:19 +00:00)

Commit: 9a8aa1f54a
Parent: cef6348703
Message: Change how warmup works.
@@ -88,7 +88,7 @@ class Conformer(Transformer):
 
 
     def forward(
-        self, x: torch.Tensor, x_lens: torch.Tensor, warmup_mode: bool = False
+        self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -97,6 +97,10 @@ class Conformer(Transformer):
           x_lens:
             A tensor of shape (batch_size,) containing the number of frames in
             `x` before padding.
+          warmup:
+            A floating point value that gradually increases from 0 throughout
+            training; when it is >= 1.0 we are "fully warmed up".  It is used
+            to turn modules on sequentially.
         Returns:
           Return a tuple containing 2 tensors:
             - logits, its shape is (batch_size, output_seq_len, output_dim)
@@ -113,7 +117,7 @@ class Conformer(Transformer):
         mask = make_pad_mask(lengths)
 
         x = self.encoder(x, pos_emb, src_key_padding_mask=mask,
-                         warmup_mode=warmup_mode)  # (T, N, C)
+                         warmup=warmup)  # (T, N, C)
 
         logits = self.encoder_output_layer(x)
         logits = logits.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
@@ -193,6 +197,8 @@ class ConformerEncoderLayer(nn.Module):
         pos_emb: Tensor,
         src_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        warmup: float = 1.0,
+        position: float = 0.0
     ) -> Tensor:
         """
         Pass the input through the encoder layer.
@@ -202,6 +208,11 @@ class ConformerEncoderLayer(nn.Module):
             pos_emb: Positional embedding tensor (required).
             src_mask: the mask for the src sequence (optional).
             src_key_padding_mask: the mask for the src keys per batch (optional).
+            warmup: controls selective activation of layers; if < 1.0, it's possible that
+               not all modules will be included.
+            position: the position of this module in the encoder stack (relates to
+               warmup); a value 0 <= position < 1.0.
+
 
         Shape:
             src: (S, N, E).
@@ -210,9 +221,9 @@ class ConformerEncoderLayer(nn.Module):
             src_key_padding_mask: (N, S).
             S is the source sequence length, N is the batch size, E is the feature number
         """
 
         # macaron style feed forward module
-        src = src + self.dropout(self.feed_forward_macaron(src))
+        src = torch.add(src, self.dropout(self.feed_forward_macaron(src)),
+                        alpha=(0.0 if warmup < 0.2 * (position + 1) else 1.0))
 
 
         # multi-headed self-attention module
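
A minimal sketch (not from this commit; toy tensors only) of what the gate above does: torch.add(src, y, alpha=a) computes src + a * y, so with alpha=0.0 the sub-module's output is discarded and the residual branch is an identity, while alpha=1.0 restores the usual residual add.

    import torch

    src = torch.randn(4, 8)      # stand-in for the layer input
    delta = torch.randn(4, 8)    # stand-in for a sub-module's (dropped-out) output

    warmup, position = 0.1, 0.0  # early in training, first layer in the stack
    alpha = 0.0 if warmup < 0.2 * (position + 1) else 1.0

    out = torch.add(src, delta, alpha=alpha)  # src + alpha * delta
    if alpha == 0.0:
        assert torch.equal(out, src)  # sub-module contribution is discarded
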
@@ -224,13 +235,16 @@ class ConformerEncoderLayer(nn.Module):
             attn_mask=src_mask,
             key_padding_mask=src_key_padding_mask,
         )[0]
-        src = src + self.dropout(src_att)
+        src = torch.add(src, self.dropout(src_att),
+                        alpha=(0.0 if warmup < 0.2 * (position + 2) else 1.0))
 
         # convolution module
-        src = src + self.dropout(self.conv_module(src))
+        src = torch.add(src, self.dropout(self.conv_module(src)),
+                        alpha=(0.0 if warmup < 0.2 * (position + 3) else 1.0))
 
         # feed forward module
-        src = src + self.dropout(self.feed_forward(src))
+        src = torch.add(src, self.dropout(self.feed_forward(src)),
+                        alpha=(0.0 if warmup < 0.2 * (position + 4) else 1.0))
 
         src = self.norm_final(self.balancer(src))
 
@@ -262,10 +276,6 @@ class ConformerEncoder(nn.Module):
         assert num_layers - 1 not in aux_layers
         self.num_layers = num_layers
         num_channels = encoder_layer.d_model
-        self.combiner = RandomCombine(num_inputs=len(self.aux_layers),
-                                      final_weight=0.5,
-                                      pure_prob=0.333,
-                                      stddev=2.0)
 
     def forward(
         self,
@@ -273,7 +283,7 @@ class ConformerEncoder(nn.Module):
         pos_emb: Tensor,
         mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
-        warmup_mode: bool = False
+        warmup: float = 1.0
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
 
@@ -293,7 +303,7 @@ class ConformerEncoder(nn.Module):
         """
         output = src
 
-        outputs = []
+        num_layers = len(self.layers)
 
         for i, mod in enumerate(self.layers):
             output = mod(
@@ -301,11 +311,10 @@ class ConformerEncoder(nn.Module):
                 pos_emb,
                 src_mask=mask,
                 src_key_padding_mask=src_key_padding_mask,
+                warmup=warmup,
+                position=(i / num_layers),
             )
-            if i in self.aux_layers:
-                outputs.append(output)
 
-        output = self.combiner(outputs, warmup_mode)
         return output
 
 
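
To see how warmup and position interact: each layer is called with position = i / num_layers, and its four gated sub-modules (macaron feed-forward, self-attention, convolution, feed-forward) switch on once warmup >= 0.2 * (position + k) for k = 1..4, so earlier layers and earlier sub-modules come online first. A small illustrative sketch (not from this commit; the 12-layer depth is assumed) of that schedule:

    # Count how many of a layer's four gated sub-modules are active
    # for a given warmup value and layer position.
    def active_submodules(warmup: float, position: float) -> int:
        return sum(warmup >= 0.2 * (position + k) for k in range(1, 5))

    num_layers = 12  # assumed stack depth, for illustration only
    for warmup in (0.0, 0.5, 1.0):
        counts = [active_submodules(warmup, i / num_layers) for i in range(num_layers)]
        print(f"warmup={warmup}: active sub-modules per layer = {counts}")
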
@@ -922,187 +931,9 @@ class Identity(torch.nn.Module):
         return x
 
 
-class RandomCombine(torch.nn.Module):
-    """
-    This module combines a list of Tensors, all with the same shape, to
-    produce a single output of that same shape which, in training time,
-    is a random combination of all the inputs; but which in test time
-    will be just the last input.
-
-    The idea is that the list of Tensors will be a list of outputs of multiple
-    conformer layers.  This has a similar effect as iterated loss. (See:
-    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
-    NETWORKS).
-    """
-    def __init__(self, num_inputs: int,
-                 final_weight: float = 0.5,
-                 pure_prob: float = 0.5,
-                 stddev: float = 2.0) -> None:
-        """
-        Args:
-          num_inputs: The number of tensor inputs, which equals the number of layers'
-            outputs that are fed into this module.  E.g. in an 18-layer neural
-            net if we output layers 16, 12, 18, num_inputs would be 3.
-          final_weight: The amount of weight or probability we assign to the
-            final layer when randomly choosing layers or when choosing
-            continuous layer weights.
-          pure_prob: The probability, on each frame, with which we choose
-            only a single layer to output (rather than an interpolation)
-          stddev: A standard deviation that we add to log-probs for computing
-            randomized weights.
-
-        The method of choosing which layers,
-        or combinations of layers, to use, is conceptually as follows.
-            With probability `pure_prob`:
-                With probability `final_weight`: choose final layer,
-                Else: choose random non-final layer.
-            Else:
-                Choose initial log-weights that correspond to assigning
-                weight `final_weight` to the final layer and equal
-                weights to other layers; then add Gaussian noise
-                with variance `stddev` to these log-weights, and normalize
-                to weights (note: the average weight assigned to the
-                final layer here will not be `final_weight` if stddev>0).
-        """
-        super(RandomCombine, self).__init__()
-        assert pure_prob >= 0 and pure_prob <= 1
-        assert final_weight > 0 and final_weight < 1
-        assert num_inputs >= 1
-
-        self.num_inputs = num_inputs
-        self.final_weight = final_weight
-        self.pure_prob = pure_prob
-        self.stddev= stddev
-
-        self.final_log_weight = torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)).log().item()
-
-
-    def forward(self, inputs: Sequence[Tensor],
-                warmup_mode: bool) -> Tensor:
-        """
-        Forward function.
-        Args:
-          inputs: a list of Tensor, e.g. from various layers of a transformer.
-            All must be the same shape, of (*, num_channels)
-        Returns:
-          a Tensor of shape (*, num_channels).  In test mode
-          this is just the final input.
-        """
-        num_inputs = self.num_inputs
-        assert len(inputs) == num_inputs
-        if not (self.training and warmup_mode):
-            return inputs[-1]
-
-        # Shape of weights: (*, num_inputs)
-        num_channels = inputs[0].shape[-1]
-        num_frames = inputs[0].numel() // num_channels
-
-        ndim = inputs[0].ndim
-        # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(inputs, dim=ndim).reshape((num_frames,
-                                                                num_channels,
-                                                                num_inputs))
-
-        # weights: (num_frames, num_inputs)
-        weights = self._get_random_weights(inputs[0].dtype, inputs[0].device,
-                                           num_frames)
-
-        weights = weights.reshape(num_frames, num_inputs, 1)
-        # ans: (num_frames, num_channels, 1)
-        ans = torch.matmul(stacked_inputs, weights)
-        # ans: (*, num_channels)
-        ans = ans.reshape(*tuple(inputs[0].shape[:-1]), num_channels)
-
-        if __name__ == "__main__":
-            # for testing only...
-            print("Weights = ", weights.reshape(num_frames, num_inputs))
-        return ans
-
-
-    def _get_random_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int) -> Tensor:
-        """
-        Return a tensor of random weights, of shape (num_frames, self.num_inputs),
-        Args:
-          dtype: the data-type desired for the answer, e.g. float, double
-          device: the device needed for the answer
-          num_frames: the number of sets of weights desired
-        Returns: a tensor of shape (num_frames, self.num_inputs), such that
-          ans.sum(dim=1) is all ones.
-
-        """
-        pure_prob = self.pure_prob
-        if pure_prob == 0.0:
-            return self._get_random_mixed_weights(dtype, device, num_frames)
-        elif pure_prob == 1.0:
-            return self._get_random_pure_weights(dtype, device, num_frames)
-        else:
-            p = self._get_random_pure_weights(dtype, device, num_frames)
-            m = self._get_random_mixed_weights(dtype, device, num_frames)
-            return torch.where(torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m)
-
-    def _get_random_pure_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int):
-        """
-        Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs),
-        Args:
-          dtype: the data-type desired for the answer, e.g. float, double
-          device: the device needed for the answer
-          num_frames: the number of sets of weights desired
-        Returns: a one-hot tensor of shape (num_frames, self.num_inputs), with
-          exactly one weight equal to 1.0 on each frame.
-        """
-
-        final_prob = self.final_weight
-
-        # final contains self.num_inputs - 1 in all elements
-        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
-        # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.
-        nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device)
-
-        indexes = torch.where(torch.rand(num_frames, device=device) < final_prob,
-                              final, nonfinal)
-        ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to(dtype=dtype)
-        return ans
-
-
-    def _get_random_mixed_weights(self, dtype: torch.dtype, device: torch.device, num_frames: int):
-        """
-        Return a tensor of random one-hot weights, of shape (num_frames, self.num_inputs),
-        Args:
-          dtype: the data-type desired for the answer, e.g. float, double
-          device: the device needed for the answer
-          num_frames: the number of sets of weights desired
-        Returns: a tensor of shape (num_frames, self.num_inputs), which elements in [0..1] that
-          sum to one over the second axis, i.e. ans.sum(dim=1) is all ones.
-        """
-        logprobs = torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) * self.stddev
-        logprobs[:,-1] += self.final_log_weight
-        return logprobs.softmax(dim=1)
-
-
-def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
-    print(f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}")
-    num_inputs = 3
-    num_channels = 50
-    m = RandomCombine(num_inputs=num_inputs,
-                      final_weight=final_weight,
-                      pure_prob=pure_prob,
-                      stddev=stddev)
-
-    x = [ torch.ones(3, 4, num_channels) for _ in range(num_inputs) ]
-
-    y = m(x, True)
-    assert y.shape == x[0].shape
-    assert torch.allclose(y, x[0])  # .. since actually all ones.
-
-
 if __name__ == '__main__':
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.0)
-    _test_random_combine(0.999, 0, 0.0)
-    _test_random_combine(0.5, 0, 0.3)
-    _test_random_combine(0.5, 1, 0.3)
-    _test_random_combine(0.5, 0.5, 0.3)
-
     feature_dim = 50
     c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4)
     batch_size = 5
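
The removed RandomCombine biased its softmaxed layer weights by final_log_weight = log((final_weight / (1 - final_weight)) * (num_inputs - 1)). A quick numerical check (not part of the commit) that, with zero noise, this bias assigns exactly final_weight to the final input:

    import torch

    num_inputs, final_weight = 3, 0.5
    final_log_weight = torch.tensor(
        (final_weight / (1 - final_weight)) * (num_inputs - 1)
    ).log().item()

    logprobs = torch.zeros(num_inputs)
    logprobs[-1] += final_log_weight
    weights = logprobs.softmax(dim=0)  # [0.25, 0.25, 0.5]
    assert abs(weights[-1].item() - final_weight) < 1e-5
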
@@ -1110,4 +941,4 @@ if __name__ == '__main__':
     # Just make sure the forward pass runs.
     f = c(torch.randn(batch_size, seq_len, feature_dim),
           torch.full((batch_size,), seq_len, dtype=torch.int64),
-          warmup_mode=True)
+          warmup=0.5)

@@ -66,7 +66,7 @@ class Transducer(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
-        warmup_mode: bool = False
+        warmup: float = 1.0,
     ) -> torch.Tensor:
         """
         Args:
@@ -87,6 +87,9 @@ class Transducer(nn.Module):
           lm_scale:
             The scale to smooth the loss with lm (output of predictor network)
             part
+          warmup:
+            A value warmup >= 0 that determines which modules are active, values
+            warmup > 1 "are fully warmed up" and all modules will be active.
         Returns:
           Return the transducer loss.
 
@@ -102,7 +105,7 @@ class Transducer(nn.Module):
 
         assert x.size(0) == x_lens.size(0) == y.dim0
 
-        encoder_out, x_lens = self.encoder(x, x_lens, warmup_mode=warmup_mode)
+        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
         assert torch.all(x_lens > 0)
 
         # Now for the decoder, i.e., the prediction network

@@ -296,7 +296,7 @@ def get_params() -> AttributeDict:
             "embedding_dim": 512,
             # parameters for Noam
             "warm_step": 60000,  # For the 100h subset, use 8k
-            "model_warm_step": 3000,  # arg given to model, not for lrate
+            "model_warm_step": 4000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
         }
     )
@@ -454,7 +454,7 @@ def compute_loss(
     sp: spm.SentencePieceProcessor,
     batch: dict,
     is_training: bool,
-    warmup_mode: bool = False
+    warmup: float = 1.0
 ) -> Tuple[Tensor, MetricsTracker]:
     """
     Compute CTC loss given the model and its inputs.
@@ -471,6 +471,8 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
+      warmup: a floating point value which increases throughout training;
+        values >= 1.0 are fully warmed up and have all modules present.
     """
     device = model.device
     feature = batch["inputs"]
@@ -493,10 +495,10 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
-            warmup_mode=warmup_mode,
+            warmup=warmup,
         )
         loss = (params.simple_loss_scale * simple_loss +
-                (pruned_loss * 0.0 if warmup_mode else pruned_loss))
+                (pruned_loss * 0.0 if warmup < 1.0 else pruned_loss))
 
     assert loss.requires_grad == is_training
 
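
With this change the pruned loss is zeroed out for as long as warmup < 1.0 rather than only while a boolean warm-up flag is set. A small sketch with made-up loss values (simple_loss_scale = 0.5 is an assumption, not necessarily this recipe's default):

    simple_loss, pruned_loss = 120.0, 45.0  # illustrative per-batch values
    simple_loss_scale = 0.5                 # assumed stand-in for params.simple_loss_scale

    for warmup in (0.3, 1.2):
        loss = (simple_loss_scale * simple_loss +
                (pruned_loss * 0.0 if warmup < 1.0 else pruned_loss))
        print(f"warmup={warmup}: loss={loss}")  # 60.0 while warming up, then 105.0
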
@@ -601,7 +603,7 @@ def train_one_epoch(
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                warmup_mode=(params.batch_idx_train < params.model_warm_step)
+                warmup=(params.batch_idx_train / params.model_warm_step)
             )
             # summary stats
             tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
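
Here warmup becomes the ratio batch_idx_train / model_warm_step, so it crosses 1.0 exactly at model_warm_step (4000 after this commit) and keeps growing afterwards. For example:

    model_warm_step = 4000
    for batch_idx_train in (0, 1000, 4000, 10000):
        print(batch_idx_train, batch_idx_train / model_warm_step)  # 0.0, 0.25, 1.0, 2.5
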
@@ -855,7 +857,6 @@ def scan_pessimistic_batches_for_oom(
                 sp=sp,
                 batch=batch,
                 is_training=True,
-                warmup_mode=True  # may use slightly more memory
             )
             loss.backward()
             optimizer.step()