Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

Commit 137ac513bf (parent 625e39fd1a)

Some changes to try to reduce mem consumption; decrease batch size
@@ -541,16 +541,14 @@ class SubformerEncoderLayer(nn.Module):
         attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
 
         if True:
-            selected_attn_weights = attn_weights[0:2]
+            selected_attn_weights = attn_weights[0:1]
             if random.random() < float(self.const_attention_rate):
                 # Make attention weights constant. The intention is to
                 # encourage these modules to do something similar to an
                 # averaging-over-time operation.
                 # only need the mask, can just use the 1st one and expand later
                 selected_attn_weights = selected_attn_weights[0:1]
                 selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
                 selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
                 selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
 
 
         na = self.balancer_na(self.nonlin_attention(src,
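The block above is the `const_attention_rate` trick: occasionally the selected heads' attention weights are replaced by a row-normalised 0/1 mask, so those heads simply average over the frames they are allowed to attend to; slicing `attn_weights[0:1]` instead of `[0:2]` carries only one head's worth of weights through this path, which is presumably part of the memory reduction. A minimal self-contained sketch of the masking step (the tensor shape and values here are assumptions, not the recipe's actual tensors):

    import torch

    # Assumed shape: (num_heads, batch, seq_len, seq_len)
    attn_weights = torch.softmax(torch.randn(4, 2, 6, 6), dim=-1)

    selected = attn_weights[0:1]                     # just the first head
    mask = (selected > 0.0).to(selected.dtype)       # 1.0 wherever attention was possible
    uniform = mask * (1.0 / mask.sum(dim=-1, keepdim=True))
    print(uniform[0, 0, 0])                          # each row sums to 1 -> plain averaging, e.g. all 0.1667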
@@ -1383,6 +1381,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
         self.score_penalty = AbsValuePenalizer(
             limit=25.0, penalty=1.0e-04, prob=0.1)
         self.name = None  # for diagnostics, will be set in train.py
 
         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads
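`AbsValuePenalizer` is defined in the recipe's scaling utilities; judging from its arguments, it presumably leaves the attention scores unchanged in the forward pass but, on a random `prob` fraction of training steps, attaches a small gradient penalty that discourages scores whose absolute value exceeds `limit`. The class below is a hypothetical stand-in that illustrates that behaviour; the name, mechanism and defaults are assumptions, not the repository's implementation:

    import random
    import torch
    from torch import Tensor, nn

    class _AbsPenaltyFn(torch.autograd.Function):
        # Identity in forward; backward adds a constant-magnitude gradient that
        # pushes elements with |x| > limit back towards the limit.
        @staticmethod
        def forward(ctx, x: Tensor, limit: float, penalty: float) -> Tensor:
            ctx.save_for_backward(x)
            ctx.limit, ctx.penalty = limit, penalty
            return x.view_as(x)

        @staticmethod
        def backward(ctx, grad: Tensor):
            (x,) = ctx.saved_tensors
            extra = ctx.penalty * x.sign() * (x.abs() > ctx.limit).to(grad.dtype)
            return grad + extra, None, None

    class AbsValuePenalizerSketch(nn.Module):
        """Hypothetical stand-in for AbsValuePenalizer(limit=25.0, penalty=1.0e-04, prob=0.1)."""

        def __init__(self, limit: float = 25.0, penalty: float = 1.0e-04, prob: float = 0.1):
            super().__init__()
            self.limit, self.penalty, self.prob = limit, penalty, prob

        def forward(self, x: Tensor) -> Tensor:
            if self.training and x.requires_grad and random.random() < self.prob:
                return _AbsPenaltyFn.apply(x, self.limit, self.penalty)
            return x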
@@ -1622,7 +1621,7 @@ class MultiheadAttentionWeights(nn.Module):
         self.dropout = dropout
         self.score_penalty = AbsValuePenalizer(
             limit=25.0, penalty=1.0e-04, prob=0.1)
 
         self.name = None  # for diagnostics, will be set in train.py
 
         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5 that has been used in previous forms of attention,
 
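The comment about `initial_scale` refers to the usual dot-product attention scaling: instead of dividing the scores by `head_dim ** 0.5` when they are computed, zipformer-style recipes fold that factor into the initial scale of the input projection, roughly `head_dim ** -0.25` on each of the query and key sides. A rough sketch of the idea, using a plain `nn.Linear` in place of the recipe's scaled projection (the exact split of the factor here is an assumption):

    import torch
    from torch import nn

    embed_dim, num_heads, query_head_dim = 256, 4, 32

    # Conventional attention: scores = (q @ k.transpose(-2, -1)) * query_head_dim ** -0.5
    # Folding the factor into the projection's initial weights (query_head_dim ** -0.25
    # on each of q and k) gives q @ k^T the same starting magnitude without an
    # explicit divide in the forward pass.
    in_proj = nn.Linear(embed_dim, 2 * num_heads * query_head_dim, bias=True)
    with torch.no_grad():
        in_proj.weight *= query_head_dim ** -0.25
        in_proj.bias *= query_head_dim ** -0.25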
@@ -404,7 +404,7 @@ def get_params() -> AttributeDict:
             "warm_step": 2000,
             "env_info": get_env_info(),
             "bytes_per_segment": 2048,
-            "batch_size": 16,
+            "batch_size": 14,
             "train_file_list": "train.txt",
             "valid_file_list": "valid.txt",
             "num_workers": 4,
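This is the "decrease batch size" part of the commit message. With fixed-length byte segments, the input per optimizer step, and roughly the activation memory with it, scales linearly with `batch_size`, so going from 16 to 14 segments trims about 12.5% per step:

    bytes_per_segment = 2048
    old_batch, new_batch = 16, 14
    print(old_batch * bytes_per_segment)   # 32768 bytes of text per batch
    print(new_batch * bytes_per_segment)   # 28672 bytes of text per batch
    print(1.0 - new_batch / old_batch)     # 0.125, i.e. ~12.5% less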
@@ -253,8 +253,10 @@ class CutoffEstimator:
 
 class SoftmaxFunction(torch.autograd.Function):
     """
-    Tries to handle half-precision derivatives in a randomized way that should
-    be more accurate for training than the default behavior.
+    An overloaded version of backprop for softmax that tries to save memory by
+    creating fp16 output (the default version of softmax creates fp32 output).
+    We convert back to fp32 internally during the backward pass, to minimize
+    roundoff error.
     """
     @staticmethod
     def forward(ctx, x: Tensor, dim: int):
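The new docstring describes where the memory saving comes from on the forward side: when running in mixed precision, the softmax output that gets saved for backward is kept in fp16 (half the storage of fp32), and the backward pass promotes it back to fp32 before computing the gradient. A minimal sketch consistent with that description, not necessarily the exact code in this file:

    import torch
    from torch import Tensor

    class SoftmaxFp16Sketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: Tensor, dim: int) -> Tensor:
            ans = x.softmax(dim=dim)
            if torch.is_autocast_enabled() or x.dtype == torch.float16:
                ans = ans.to(torch.float16)   # storing fp16 halves the saved activation
            ctx.save_for_backward(ans)
            ctx.dim = dim
            return ans

        @staticmethod
        def backward(ctx, ans_grad: Tensor):
            # promote back to fp32; the actual backward code is in the next hunk
            (ans,) = ctx.saved_tensors
            ans = ans.to(torch.float32)
            g = ans_grad.to(torch.float32)
            return ans * (g - (ans * g).sum(dim=ctx.dim, keepdim=True)), None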
@@ -272,13 +274,30 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         ans, = ctx.saved_tensors
-        with torch.cuda.amp.autocast(enabled=False):
-            ans_grad = ans_grad.to(torch.float32)
-            ans = ans.to(torch.float32)
-            x_grad = ans_grad * ans
-            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
-            x_grad -= ans
-            return x_grad, None
+        try:
+            with torch.cuda.amp.autocast(enabled=False):
+                if ans.dtype != torch.float32:
+                    ans = ans.to(torch.float32)
+                    x_grad = ans.mul_(ans_grad.to(torch.float32))
+                else:
+                    # out-of-place since it's not a copy
+                    x_grad = ans_grad.to(torch.float32) * ans
+                ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
+                x_grad -= ans
+                return x_grad, None
+        except Exception as e:
+            logging.info(f"Caught exception in SoftmaxFunction backward: {e}, size={list(ans.shape)}, dim={ctx.dim}, will try in half precision.")
+            x_grad = None
+
+
+            ans, = ctx.saved_tensors
+            ans_grad.mul_(ans)
+            x_grad = ans_grad
+            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
+            x_grad -= ans
+            return x_grad, None
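Both the fp32 path and the half-precision fallback above compute the same identity, x_grad = ans * (ans_grad - sum(ans * ans_grad, dim)). One thing to watch when writing it with in-place ops: the tensor holding ans * ans_grad must not share storage with the tensor used for the ans * sum(...) term, otherwise the final subtraction cancels itself out (for example, writing the product into `ans` via `ans.mul_(...)` and then reusing `ans` for the scaling term). A small reference implementation plus a numerical check against autograd, offered as a sketch rather than the repository's code:

    import torch
    from torch import Tensor

    def softmax_backward_fp32(ans: Tensor, ans_grad: Tensor, dim: int) -> Tensor:
        """Reference: x_grad = ans * (ans_grad - sum(ans * ans_grad, dim))."""
        ans = ans.to(torch.float32)
        x_grad = ans_grad.to(torch.float32) * ans      # ans * ans_grad, in its own storage
        x_grad -= ans * x_grad.sum(dim=dim, keepdim=True)
        return x_grad

    # numerical check against autograd
    x = torch.randn(3, 5, requires_grad=True)
    y = x.softmax(dim=-1)
    g = torch.randn_like(y)
    (y * g).sum().backward()
    assert torch.allclose(softmax_backward_fp32(y.detach(), g, dim=-1), x.grad, atol=1e-6)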