Some changes to try to reduce mem consumption; decrease batch size

Daniel Povey 2023-05-28 21:50:34 +08:00
parent 625e39fd1a
commit 137ac513bf
3 changed files with 32 additions and 14 deletions

View File

@@ -541,16 +541,14 @@ class SubformerEncoderLayer(nn.Module):
         attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
 
         if True:
-            selected_attn_weights = attn_weights[0:2]
+            selected_attn_weights = attn_weights[0:1]
             if random.random() < float(self.const_attention_rate):
                 # Make attention weights constant. The intention is to
                 # encourage these modules to do something similar to an
                 # averaging-over-time operation.
                 # only need the mask, can just use the 1st one and expand later
-                selected_attn_weights = selected_attn_weights[0:1]
                 selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
                 selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
-                selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)
 
         na = self.balancer_na(self.nonlin_attention(src,
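
For context: the rate-gated block above turns the attention weights into a constant averaging-over-time pattern by binarizing the (already masked) weights and renormalizing each row to sum to one. Since this commit keeps only the first head's weights (attn_weights[0:1]) rather than the first two, the later reduction to one head and the expand back to two heads are no longer needed. A minimal standalone sketch of the binarize-and-renormalize step (shapes and tensor names are illustrative, not taken from subformer.py):

import torch

# Illustrative shape (num_heads, batch, query_len, key_len); the last two key
# positions are zeroed to stand in for masked-out positions.
attn_weights = torch.softmax(torch.randn(4, 2, 5, 5), dim=-1)
attn_weights[..., 3:] = 0.0

selected = attn_weights[0:1]                      # keep just the first head
selected = (selected > 0.0).to(selected.dtype)    # 1.0 where attention is allowed
selected = selected * (1.0 / selected.sum(dim=-1, keepdim=True))  # uniform over allowed keys
print(selected[0, 0, 0])   # tensor([0.3333, 0.3333, 0.3333, 0.0000, 0.0000])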
@@ -1383,6 +1381,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
         self.score_penalty = AbsValuePenalizer(
             limit=25.0, penalty=1.0e-04, prob=0.1)
+        self.name = None # for diagnostics, will be set in train.py
 
         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads
@@ -1622,7 +1621,7 @@ class MultiheadAttentionWeights(nn.Module):
         self.dropout = dropout
         self.score_penalty = AbsValuePenalizer(
             limit=25.0, penalty=1.0e-04, prob=0.1)
+        self.name = None # for diagnostics, will be set in train.py
 
         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5 that has been used in previous forms of attention,
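
Both self.name additions above give the attention-weight modules a slot that the training script can fill in, so diagnostic messages can say which module they came from. A hedged sketch of how such names might be assigned from the outside (the helper below is hypothetical, not code from train.py):

import torch.nn as nn

def set_module_names(model: nn.Module) -> None:
    # Hypothetical helper: store each submodule's fully qualified name on the
    # module itself, but only for modules that expose a `name` attribute.
    for fq_name, module in model.named_modules():
        if hasattr(module, "name"):
            module.name = fq_name

# Usage sketch: call set_module_names(model) once after building the model,
# so later log messages can identify the attention module they refer to.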

View File

@@ -404,7 +404,7 @@ def get_params() -> AttributeDict:
             "warm_step": 2000,
             "env_info": get_env_info(),
             "bytes_per_segment": 2048,
-            "batch_size": 16,
+            "batch_size": 14,
             "train_file_list": "train.txt",
             "valid_file_list": "valid.txt",
             "num_workers": 4,

View File

@@ -253,8 +253,10 @@ class CutoffEstimator:
 class SoftmaxFunction(torch.autograd.Function):
     """
-    Tries to handle half-precision derivatives in a randomized way that should
-    be more accurate for training than the default behavior.
+    An overloaded version of backprop for softmax that tries to save memory by
+    creating fp16 output (the default version of softmax creates fp32 output).
+    We convert back to fp32 internally during the backward pass, to minimize
+    roundoff error.
     """
 
     @staticmethod
     def forward(ctx, x: Tensor, dim: int):
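
The new docstring says memory is saved by keeping the softmax output in fp16 and converting back to fp32 only inside backward. The forward pass is not shown in this diff beyond its signature; the sketch below is one way a forward consistent with that description could look (the class name, the autocast check, and the fp16 cast are assumptions, not code copied from scaling.py):

import torch
from torch import Tensor

class MemEfficientSoftmax(torch.autograd.Function):
    # Sketch only: keep the tensor saved for backward in fp16 under autocast,
    # roughly halving the memory held between forward and backward; backward
    # upcasts to fp32 to limit roundoff error.
    @staticmethod
    def forward(ctx, x: Tensor, dim: int):
        ans = x.softmax(dim=dim)
        if torch.is_autocast_enabled():
            ans = ans.to(torch.float16)
        ctx.save_for_backward(ans)
        ctx.dim = dim
        return ans

    @staticmethod
    def backward(ctx, ans_grad: Tensor):
        ans, = ctx.saved_tensors
        ans = ans.to(torch.float32)
        ans_grad = ans_grad.to(torch.float32)
        x_grad = ans_grad * ans
        x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
        return x_grad, None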
@@ -272,13 +274,30 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         ans, = ctx.saved_tensors
-        with torch.cuda.amp.autocast(enabled=False):
-            ans_grad = ans_grad.to(torch.float32)
-            ans = ans.to(torch.float32)
-            x_grad = ans_grad * ans
-            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
-            x_grad -= ans
-            return x_grad, None
+        try:
+            with torch.cuda.amp.autocast(enabled=False):
+                if ans.dtype != torch.float32:
+                    ans = ans.to(torch.float32)
+                    x_grad = ans.mul_(ans_grad.to(torch.float32))
+                else:
+                    # out-of-place since it's not a copy
+                    x_grad = ans_grad.to(torch.float32) * ans
+                ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
+                x_grad -= ans
+                return x_grad, None
+        except Exception as e:
+            logging.info(f"Caught exception in SoftmaxFunction backward: {e}, size={list(ans.shape)}, dim={ctx.dim}, will try in half precision.")
+            x_grad = None
+            ans, = ctx.saved_tensors
+            ans_grad.mul_(ans)
+            x_grad = ans_grad
+            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
+            x_grad -= ans
+            return x_grad, None
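
For reference, both the main path and the half-precision fallback above are rearrangements of the softmax Jacobian-vector product: with p = softmax(x) and upstream gradient g, the input gradient is p * (g - sum(p * g)), which the code evaluates as (p * g) - p * sum(p * g) so the two terms can reuse existing buffers in place. A quick numerical check of that identity against autograd (standalone, not part of the commit):

import torch

x = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
g = torch.randn(3, 5, dtype=torch.float64)
dim = -1

p = x.softmax(dim=dim)
p.backward(g)                       # reference gradient from autograd

with torch.no_grad():
    pg = p * g
    manual = pg - p * pg.sum(dim=dim, keepdim=True)

print(torch.allclose(x.grad, manual))   # True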