Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-12-11 06:55:27 +00:00
commit 137ac513bf
parent 625e39fd1a

    Some changes to try to reduce mem consumption; decrease batch size
@@ -541,16 +541,14 @@ class SubformerEncoderLayer(nn.Module):
         attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)

         if True:
-            selected_attn_weights = attn_weights[0:2]
+            selected_attn_weights = attn_weights[0:1]
             if random.random() < float(self.const_attention_rate):
                 # Make attention weights constant. The intention is to
                 # encourage these modules to do something similar to an
                 # averaging-over-time operation.
                 # only need the mask, can just use the 1st one and expand later
-                selected_attn_weights = selected_attn_weights[0:1]
                 selected_attn_weights = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
                 selected_attn_weights = selected_attn_weights * (1.0 / selected_attn_weights.sum(dim=-1, keepdim=True))
-                selected_attn_weights = selected_attn_weights.expand(2, -1, -1, -1)


         na = self.balancer_na(self.nonlin_attention(src,
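
The change above keeps only the first attention head's weights for the constant-attention branch, instead of the first two, so the extra copies that expand(2, ...) used to create are no longer needed. For reference, here is a minimal standalone sketch of what the surviving lines compute; the (num_heads, batch, tgt_len, src_len) layout and the helper name are assumptions for illustration, not code from the repository:

    import torch

    def constify_first_head(attn_weights: torch.Tensor) -> torch.Tensor:
        # attn_weights: (num_heads, batch, tgt_len, src_len); each row sums to 1.
        w = attn_weights[0:1]                 # keep just the first head
        mask = (w > 0.0).to(w.dtype)          # 1.0 wherever attention is allowed
        # uniform weight over the allowed positions: an averaging-over-time op
        return mask * (1.0 / mask.sum(dim=-1, keepdim=True))

    if __name__ == "__main__":
        w = torch.softmax(torch.randn(4, 2, 5, 5), dim=-1)
        print(constify_first_head(w)[0, 0, 0])   # uniform row: five values of 0.2
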
@@ -1383,6 +1381,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         self.pos_emb_skip_rate = copy.deepcopy(pos_emb_skip_rate)
         self.score_penalty = AbsValuePenalizer(
             limit=25.0, penalty=1.0e-04, prob=0.1)
+        self.name = None  # for diagnostics, will be set in train.py

         key_head_dim = query_head_dim
         in_proj_dim = (query_head_dim + key_head_dim + pos_dim) * num_heads
@@ -1622,7 +1621,7 @@ class MultiheadAttentionWeights(nn.Module):
         self.dropout = dropout
         self.score_penalty = AbsValuePenalizer(
             limit=25.0, penalty=1.0e-04, prob=0.1)
-
+        self.name = None  # for diagnostics, will be set in train.py

         # the initial_scale is supposed to take over the "scaling" factor of
         # head_dim ** -0.5 that has been used in previous forms of attention,
@@ -404,7 +404,7 @@ def get_params() -> AttributeDict:
             "warm_step": 2000,
             "env_info": get_env_info(),
             "bytes_per_segment": 2048,
-            "batch_size": 16,
+            "batch_size": 14,
             "train_file_list": "train.txt",
             "valid_file_list": "valid.txt",
             "num_workers": 4,
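
The "batch_size" change from 16 to 14 is the "decrease batch size" half of the commit message. A rough back-of-the-envelope sketch, assuming (as the surrounding params suggest) that each example is a fixed chunk of bytes_per_segment bytes and that activation memory in the model also grows roughly linearly with batch size:

    bytes_per_segment = 2048
    for batch_size in (16, 14):
        # 16 -> 32768 bytes of raw text per batch, 14 -> 28672 (a ~12.5% cut)
        print(batch_size, batch_size * bytes_per_segment)
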
@@ -253,8 +253,10 @@ class CutoffEstimator:

 class SoftmaxFunction(torch.autograd.Function):
     """
-    Tries to handle half-precision derivatives in a randomized way that should
-    be more accurate for training than the default behavior.
+    An overloaded version of backprop for softmax that tries to save memory by
+    creating fp16 output (the default version of softmax creates fp32 output).
+    We convert back to fp32 internally during the backward pass, to minimize
+    roundoff error.
     """
     @staticmethod
     def forward(ctx, x: Tensor, dim: int):
@@ -272,13 +274,30 @@ class SoftmaxFunction(torch.autograd.Function):
     @staticmethod
     def backward(ctx, ans_grad: Tensor):
         ans, = ctx.saved_tensors
-        with torch.cuda.amp.autocast(enabled=False):
-            ans_grad = ans_grad.to(torch.float32)
-            ans = ans.to(torch.float32)
-            x_grad = ans_grad * ans
-            ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
-            x_grad -= ans
-            return x_grad, None
+        try:
+            with torch.cuda.amp.autocast(enabled=False):
+                if ans.dtype != torch.float32:
+                    ans = ans.to(torch.float32)
+                    x_grad = ans.mul_(ans_grad.to(torch.float32))
+                else:
+                    # out-of-place since it's not a copy
+                    x_grad = ans_grad.to(torch.float32) * ans
+                ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
+                x_grad -= ans
+                return x_grad, None
+        except Exception as e:
+            logging.info(f"Caught exception in SoftmaxFunction backward: {e}, size={list(ans.shape)}, dim={ctx.dim}, will try in half precision.")
+            x_grad = None
+
+
+        ans, = ctx.saved_tensors
+        ans_grad.mul_(ans)
+        x_grad = ans_grad
+        ans *= x_grad.sum(dim=ctx.dim, keepdim=True)
+        x_grad -= ans
+        return x_grad, None
+
+


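
Taken together, the docstring and backward changes describe a softmax whose output is kept in half precision to save memory, with the backward arithmetic done in fp32 and a half-precision fallback if the fp32 path fails. Below is a hedged, self-contained sketch of that idea; the class name MemEfficientSoftmax and all surrounding details are illustrative assumptions, not the file's actual code:

    import logging
    import torch
    from torch import Tensor

    class MemEfficientSoftmax(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: Tensor, dim: int) -> Tensor:
            ans = x.softmax(dim=dim)
            if torch.is_autocast_enabled():
                ans = ans.to(torch.float16)   # store the smaller fp16 tensor
            ctx.save_for_backward(ans)
            ctx.dim = dim
            return ans

        @staticmethod
        def backward(ctx, ans_grad: Tensor):
            (ans,) = ctx.saved_tensors
            try:
                # fp32 path: more accurate, but the casts need extra memory
                a = ans.to(torch.float32)
                g = ans_grad.to(torch.float32)
                x_grad = g * a
                x_grad = x_grad - a * x_grad.sum(dim=ctx.dim, keepdim=True)
                return x_grad, None
            except Exception as e:
                logging.info(f"fp32 softmax backward failed ({e}); retrying in half precision")
                x_grad = ans_grad * ans
                x_grad = x_grad - ans * x_grad.sum(dim=ctx.dim, keepdim=True)
                return x_grad, None

    if __name__ == "__main__":
        x = torch.randn(3, 5, requires_grad=True)
        y = MemEfficientSoftmax.apply(x, -1)
        y.sum().backward()            # the gradient of sum(softmax(x)) is ~0
        print(x.grad.abs().max())     # prints a value close to zero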