Some changes to try to reduce mem consumption; decrease batch size

Daniel Povey 2023-05-28 21:50:34 +08:00
parent 625e39fd1a
commit 137ac513bf
3 changed files with 32 additions and 14 deletions

View File

@@ -541,16 +541,14 @@ class SubformerEncoderLayer(nn.Module):
        attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
        if True:
            selected_attn_weights = attn_weights[0:2]
            selected_attn_weights = attn_weights[0:1]
            if random.random() < float(self.const_attention_rate):
                # Make attention weights constant. The intention is to
                # encourage these modules to do something similar to an
                # averaging-over-time operation.
                # only need the mask, can just use the 1st one and expand later
                selected_attn_weights = selected_attn_weights[0:1]
                selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
                selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
                selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
        na = self.balancer_na(self.nonlin_attention(src,
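For readers outside this codebase: the block above sometimes replaces the learned attention weights with a constant averaging pattern, and this commit selects one attention head instead of two for that branch to save memory. The following is a minimal, hypothetical sketch (shapes and the causal mask are invented for illustration, not taken from the repo) of what the (weights > 0) / renormalize trick computes:

import torch

num_heads, batch, tgt_len, src_len = 4, 2, 6, 6
scores = torch.randn(num_heads, batch, tgt_len, src_len)
causal = torch.ones(tgt_len, src_len).tril() == 0           # True above the diagonal
attn_weights = torch.softmax(scores.masked_fill(causal, float("-inf")), dim=-1)

# keep only the first head (the commit changes [0:2] to [0:1])
selected = attn_weights[0:1]
mask = (selected > 0.0).to(selected.dtype)                   # 1.0 wherever attention is allowed
constant = mask * (1.0 / mask.sum(dim=-1, keepdim=True))     # equal weight over allowed positions
print(constant.sum(dim=-1))                                  # every row sums to 1.0: a pure average over time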
@@ -1383,6 +1381,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
        self.score_penalty = AbsValuePenalizer(
            limit=25.0, penalty=1.0e-04, prob=0.1)
        self.name = None # for diagnostics, will be set in train.py
        key_head_dim = query_head_dim
        in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads
@@ -1622,7 +1621,7 @@ class MultiheadAttentionWeights(nn.Module):
        self.dropout = dropout
        self.score_penalty = AbsValuePenalizer(
            limit=25.0, penalty=1.0e-04, prob=0.1)
        self.name = None # for diagnostics, will be set in train.py
        # the initial_scale is supposed to take over the "scaling" factor of
        # head_dim ** -0.5 that has been used in previous forms of attention,
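The two hunks above attach an AbsValuePenalizer with limit=25.0 to the attention-score computation. Its implementation is not shown in this diff; the sketch below is only a plausible stand-in inferred from the call signature (limit, penalty, prob), not the repo's actual class: with some probability during training it adds a small gradient pushing scores whose absolute value exceeds the limit back toward the allowed range, while leaving the forward output untouched.

import random
import torch
from torch import Tensor, nn


class _PenalizeAbsGt(torch.autograd.Function):
    # Identity in the forward pass; in the backward pass, adds the gradient of
    # penalty * relu(|x| - limit) to the incoming gradient.

    @staticmethod
    def forward(ctx, x: Tensor, limit: float, penalty: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.limit = limit
        ctx.penalty = penalty
        return x

    @staticmethod
    def backward(ctx, x_grad: Tensor):
        x, = ctx.saved_tensors
        extra = ctx.penalty * x.sign() * (x.abs() > ctx.limit).to(x_grad.dtype)
        return x_grad + extra, None, None


class AbsValuePenalizerSketch(nn.Module):
    # hypothetical stand-in for the repo's AbsValuePenalizer
    def __init__(self, limit: float, penalty: float = 1.0e-04, prob: float = 0.1):
        super().__init__()
        self.limit, self.penalty, self.prob = limit, penalty, prob

    def forward(self, x: Tensor) -> Tensor:
        if not self.training or random.random() >= self.prob:
            return x
        return _PenalizeAbsGt.apply(x, self.limit, self.penalty)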

View File

@@ -404,7 +404,7 @@ def get_params() -> AttributeDict:
"warm_step": 2000,
"env_info": get_env_info(),
"bytes_per_segment": 2048,
"batch_size": 16,
"batch_size": 14,
"train_file_list": "train.txt",
"valid_file_list": "valid.txt",
"num_workers": 4,

View File

@@ -253,8 +253,10 @@ class CutoffEstimator:
class SoftmaxFunction(torch.autograd.Function):
    """
    Tries to handle half-precision derivatives in a randomized way that should
    be more accurate for training than the default behavior.
    An overloaded version of backprop for softmax that tries to save memory by
    creating fp16 output (the default version of softmax creates fp32 output).
    We convert back to fp32 internally during the backward pass, to minimize
    roundoff error.
    """
    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
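The hunk above ends just at the forward signature; it is the forward pass that creates the half-precision output described by the new docstring. A minimal sketch of what such a forward can look like (a guess consistent with the docstring, not necessarily the exact code in this file):

    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        if torch.is_autocast_enabled():
            # keep the output (and hence the tensor saved for backward) in fp16
            # to halve the memory held between forward and backward.
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.dim = dim
        return ans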
@@ -272,13 +274,30 @@ class SoftmaxFunction(torch.autograd.Function):
    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        ans, = ctx.saved_tensors
        with torch.cuda.amp.autocast(enabled=False):
            ans_grad = ans_grad.to(torch.float32)
            ans = ans.to(torch.float32)
            x_grad = ans_grad * ans
            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
            x_grad -= ans
            return x_grad, None
        try:
            with torch.cuda.amp.autocast(enabled=False):
                if ans.dtype != torch.float32:
                    # `ans` was saved in half precision: make an fp32 copy and
                    # re-use it as the buffer for x_grad, saving one allocation.
                    x_grad = ans.to(torch.float32)
                    x_grad *= ans_grad.to(torch.float32)  # x_grad = ans * ans_grad
                    # subtract ans * sum(x_grad); the fp16 `ans` is promoted to
                    # fp32 inside the temporary.
                    x_grad -= ans * x_grad.sum(dim=ctx.dim, keepdim=True)
                else:
                    # out-of-place, since `ans` here is the saved fp32 tensor
                    # itself, not a copy; it can still be modified in place below
                    # because it is not needed after this backward pass.
                    x_grad = ans_grad.to(torch.float32) * ans
                    ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
                    x_grad -= ans
                return x_grad, None
        except Exception as e:
            logging.info(f"Caught exception in SoftmaxFunction backward: {e}, "
                         f"size={list(ans.shape)}, dim={ctx.dim}, will try in half precision.")
            # Likely out of memory: retry entirely in place, in the original
            # (possibly fp16) dtype, to use as little memory as possible.
            x_grad = None  # drop any partial result before retrying
            ans, = ctx.saved_tensors
            ans_grad.mul_(ans)
            x_grad = ans_grad
            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
            x_grad -= ans
            return x_grad, None
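A quick way to sanity-check a custom backward like the one above is to compare its gradients with torch.softmax in fp32, where no dtype conversion is involved. The snippet assumes the full class is available (a forward along the lines sketched earlier plus this backward) and that it is invoked via SoftmaxFunction.apply, as is conventional for autograd Functions:

import torch

x = torch.randn(3, 5, dtype=torch.float32, requires_grad=True)
g = torch.randn(3, 5)

(ref_grad,) = torch.autograd.grad(torch.softmax(x, dim=-1), x, g)
(custom_grad,) = torch.autograd.grad(SoftmaxFunction.apply(x, -1), x, g)

print(torch.allclose(ref_grad, custom_grad, atol=1.0e-6))  # expected: True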