First working version

Fangjun Kuang 2024-10-30 21:11:07 +08:00
parent 1e986c930d
commit 256c446f06
3 changed files with 46 additions and 65 deletions

View File

@@ -25,7 +25,7 @@ import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from lhotse.dataset import SpecAugment
from scaling import ScaledLinear
from scaling import ScaledLinear, scale_grad
from icefall.utils import add_sos, make_pad_mask, time_warp
@@ -198,13 +198,6 @@ class AsrModel(nn.Module):
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
print(
"ctc_output",
ctc_output.detach().mean(),
ctc_output.detach().sum(),
ctc_output.detach().min(),
ctc_output.detach().max(),
)
if model_prev:
with fork_rng(
@@ -213,18 +206,11 @@ class AsrModel(nn.Module):
rng_state=rng_state,
device=device,
):
ctc_output_prev = model_prev.ctc_output(encoder_out)
print(
"ctc_output_prev",
ctc_output_prev.detach().mean(),
ctc_output_prev.detach().sum(),
ctc_output_prev.detach().min(),
ctc_output_prev.detach().max(),
)
print(
"isclose ctc",
(ctc_output - ctc_output).detach().abs().max(),
)
ctc_output_prev = model_prev.ctc_output(encoder_out_prev)
has_grown = ctc_output > 0.8 * ctc_output_prev
grad_scale_tensor = torch.where(has_grown, 0.5, 1.0)
ctc_output = scale_grad(ctc_output, grad_scale_tensor)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
@@ -481,15 +467,6 @@ class AsrModel(nn.Module):
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
print(
"encoder_out",
encoder_out.detach().mean(),
encoder_out.detach().abs().max(),
encoder_out.detach().abs().min(),
encoder_out.detach().sum(),
encoder_out.shape,
)
if model_prev:
with fork_rng(
cpu_state=cpu_state,
@@ -500,19 +477,6 @@ class AsrModel(nn.Module):
encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
x, x_lens
)
print(
"encoder_out_prev",
encoder_out_prev.detach().mean(),
encoder_out_prev.detach().abs().max(),
encoder_out_prev.detach().abs().mean(),
encoder_out_prev.detach().sum(),
encoder_out_prev.shape,
)
print(
"isclose",
(encoder_out - encoder_out_prev).detach().abs().max(),
(encoder_out_lens - encoder_out_lens_prev).detach().abs().max(),
)
else:
encoder_out_prev = None
encoder_out_lens_prev = None
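Note: this file's hunks drop the debug printing and instead damp gradients element-wise wherever the current CTC output has grown beyond 0.8x the output of the slowly updated model_prev; the forward value of ctc_output is unchanged. A minimal, hypothetical usage sketch with toy shapes (assumes the updated scale_grad from the scaling.py change below is importable):

import torch
from scaling import scale_grad  # the updated scale_grad shown below

ctc_output = torch.randn(2, 5, 10, requires_grad=True)  # (N, T, C), toy values
ctc_output_prev = torch.randn(2, 5, 10)                 # stand-in for model_prev's output

has_grown = ctc_output > 0.8 * ctc_output_prev
grad_scale_tensor = torch.where(has_grown, 0.5, 1.0)    # halve grads where it has grown

scaled = scale_grad(ctc_output, grad_scale_tensor)      # forward pass: scaled == ctc_output
scaled.sum().backward()
# ctc_output.grad is now 0.5 where has_grown is True and 1.0 elsewhere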

View File

@@ -1136,16 +1136,24 @@ def with_loss(x, y, name):
class ScaleGradFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor, alpha: float) -> Tensor:
ctx.alpha = alpha
def forward(ctx, x: Tensor, alpha: Union[float, Tensor]) -> Tensor:
if isinstance(alpha, Tensor):
ctx.save_for_backward(alpha)
else:
ctx.alpha = alpha
return x
@staticmethod
def backward(ctx, grad: Tensor):
return grad * ctx.alpha, None
if hasattr(ctx, "alpha"):
alpha = ctx.alpha
else:
(alpha,) = ctx.saved_tensors
return grad * alpha, None
def scale_grad(x: Tensor, alpha: float):
def scale_grad(x: Tensor, alpha: Union[float, Tensor]):
return ScaleGradFunction.apply(x, alpha)
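Note: ScaleGradFunction previously accepted only a float alpha; it now also accepts a Tensor, which is stashed with ctx.save_for_backward so a per-element scale reaches backward(). A quick, hypothetical sanity check of both paths (not part of the commit):

import torch
from scaling import scale_grad

x = torch.ones(4, requires_grad=True)
alpha = torch.tensor([1.0, 0.5, 0.5, 1.0])
scale_grad(x, alpha).sum().backward()
assert torch.allclose(x.grad, alpha)  # per-element gradient scaling

y = torch.ones(4, requires_grad=True)
scale_grad(y, 0.25).sum().backward()
assert torch.allclose(y.grad, torch.full_like(y, 0.25))  # float path unchanged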

View File

@@ -552,9 +552,15 @@ def get_parser():
parser.add_argument(
"--limit-grad-start-batch",
type=int,
# default=1000,
default=2,
help="Limit grad starting from this batch.",
default=1000,
help="Enable grad limit starting from this batch. Set it to 0 to disable it",
)
parser.add_argument(
"--limit-grad-every-n-batch",
type=int,
default=1,
help="Apply grad limit every this number of batch when it is enabled",
)
add_model_arguments(parser)
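Note: --limit-grad-start-batch is the batch index from which the grad limit is enabled (0 turns it off entirely), while --limit-grad-every-n-batch controls how often model_prev is refreshed once the limit is active. A hypothetical restatement of the refresh condition used later in train_one_epoch:

def grad_limit_refresh_due(batch_idx: int, start_batch: int, every_n: int) -> bool:
    # Disabled when start_batch == 0; otherwise active from start_batch onward,
    # with the model_prev refresh happening only every `every_n` batches.
    return 0 < start_batch <= batch_idx and batch_idx % every_n == 0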
@@ -1036,6 +1042,17 @@ def compute_validation_loss(
return tot_loss
@torch.inference_mode()
def update_model_prev(model_prev, model, beta):
# model_prev = beta * model_prev + (1-beta) * model
model_prev_dict = model_prev.state_dict()
model_dict = model.state_dict()
for key in model_prev_dict:
model_prev_dict[key].data.copy_(
model_prev_dict[key].data * beta + model_dict[key].data * (1 - beta)
)
def train_one_epoch(
params: AttributeDict,
model: Union[nn.Module, DDP],
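Note: update_model_prev (added in the hunk above) maintains model_prev as an exponential moving average of the training model, model_prev = beta * model_prev + (1 - beta) * model, applied to every state_dict entry. A hypothetical, parameters-only equivalent using an in-place lerp under no_grad (the state_dict version above additionally covers buffers):

import torch

@torch.no_grad()
def ema_update(model_prev: torch.nn.Module, model: torch.nn.Module, beta: float) -> None:
    # p_prev <- beta * p_prev + (1 - beta) * p, i.e. lerp toward p with weight (1 - beta)
    for p_prev, p in zip(model_prev.parameters(), model.parameters()):
        p_prev.lerp_(p, 1.0 - beta)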
@@ -1115,13 +1132,11 @@ def train_one_epoch(
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
if params.batch_idx_train > params.limit_grad_start_batch:
model_prev = copy.deepcopy(model)
loss, loss_info = compute_loss(
params=params,
model=model,
model_prev=model_prev
if params.batch_idx_train > params.limit_grad_start_batch
if 0 < params.limit_grad_start_batch < params.batch_idx_train
else None,
sp=sp,
batch=batch,
@@ -1140,17 +1155,15 @@ def train_one_epoch(
scaler.update()
optimizer.zero_grad()
if params.batch_idx_train >= params.limit_grad_start_batch:
if (
0 < params.limit_grad_start_batch <= params.batch_idx_train
and params.batch_idx_train % params.limit_grad_every_n_batch == 0
):
if model_prev is None:
model_prev = copy.deepcopy(model)
else:
model_prev = copy.deepcopy(model)
print(
"here",
params.batch_idx_train,
params.limit_grad_start_batch,
model_prev is None,
)
beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
update_model_prev(model_prev=model_prev, model=model, beta=beta)
except Exception as e:
logging.info(f"Caught exception: {e}.")
@@ -1221,6 +1234,7 @@ def train_one_epoch(
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
+ (f", beta: {beta}" if model_prev is not None else "")
)
if tb_writer is not None:
@@ -1622,9 +1636,4 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
# torch.use_deterministic_algorithms(True, warn_only=True)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.enabled = False
main()