Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 18:12:19 +00:00)

Commit: 256c446f06
Parent: 1e986c930d
Commit message: First working version
@@ -25,7 +25,7 @@ import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from lhotse.dataset import SpecAugment
-from scaling import ScaledLinear
+from scaling import ScaledLinear, scale_grad

 from icefall.utils import add_sos, make_pad_mask, time_warp
@@ -198,13 +198,6 @@ class AsrModel(nn.Module):

         # Compute CTC log-prob
         ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
-        print(
-            "ctc_output",
-            ctc_output.detach().mean(),
-            ctc_output.detach().sum(),
-            ctc_output.detach().min(),
-            ctc_output.detach().max(),
-        )

         if model_prev:
             with fork_rng(
@@ -213,18 +206,11 @@ class AsrModel(nn.Module):
                 rng_state=rng_state,
                 device=device,
             ):
-                ctc_output_prev = model_prev.ctc_output(encoder_out)
-                print(
-                    "ctc_output_prev",
-                    ctc_output_prev.detach().mean(),
-                    ctc_output_prev.detach().sum(),
-                    ctc_output_prev.detach().min(),
-                    ctc_output_prev.detach().max(),
-                )
-            print(
-                "isclose ctc",
-                (ctc_output - ctc_output).detach().abs().max(),
-            )
+                ctc_output_prev = model_prev.ctc_output(encoder_out_prev)
+
+            has_grown = ctc_output > 0.8 * ctc_output_prev
+            grad_scale_tensor = torch.where(has_grown, 0.5, 1.0)
+            ctc_output = scale_grad(ctc_output, grad_scale_tensor)

         ctc_loss = torch.nn.functional.ctc_loss(
             log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
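The three added lines above halve the CTC gradient wherever the current model's output has already grown past 0.8 times the previous model's output, while leaving the forward value untouched. A standalone sketch of that effect on a toy tensor (not part of the commit), using a gradient hook in place of the repo's scale_grad helper; shapes and values are made up:

import torch

ctc_output = torch.randn(4, 10, requires_grad=True)   # stand-in for the real CTC log-probs
ctc_output_prev = torch.randn(4, 10)                   # from the frozen previous model

has_grown = ctc_output > 0.8 * ctc_output_prev
grad_scale_tensor = torch.where(has_grown, 0.5, 1.0)

# Same effect here as scale_grad(ctc_output, grad_scale_tensor): the forward
# value is unchanged, only the gradient flowing back through this tensor is scaled.
ctc_output.register_hook(lambda g: g * grad_scale_tensor)

ctc_output.sum().backward()
print(ctc_output.grad)   # 0.5 where has_grown is True, 1.0 elsewhere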
@@ -481,15 +467,6 @@ class AsrModel(nn.Module):
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
-
-        print(
-            "encoder_out",
-            encoder_out.detach().mean(),
-            encoder_out.detach().abs().max(),
-            encoder_out.detach().abs().min(),
-            encoder_out.detach().sum(),
-            encoder_out.shape,
-        )

         if model_prev:
             with fork_rng(
                 cpu_state=cpu_state,
@@ -500,19 +477,6 @@ class AsrModel(nn.Module):
                 encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
                     x, x_lens
                 )
-                print(
-                    "encoder_out_prev",
-                    encoder_out_prev.detach().mean(),
-                    encoder_out_prev.detach().abs().max(),
-                    encoder_out_prev.detach().abs().mean(),
-                    encoder_out_prev.detach().sum(),
-                    encoder_out_prev.shape,
-                )
-            print(
-                "isclose",
-                (encoder_out - encoder_out_prev).detach().abs().max(),
-                (encoder_out_lens - encoder_out_lens_prev).detach().abs().max(),
-            )
         else:
             encoder_out_prev = None
             encoder_out_lens_prev = None
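The fork_rng context manager used in the hunks above is not shown in this diff. A hypothetical sketch of what a helper with this signature could look like, assuming its job is to temporarily restore previously captured CPU/CUDA RNG states so that model_prev sees the same dropout and SpecAugment randomness as the live model; the real implementation lives elsewhere in the recipe and may differ:

from contextlib import contextmanager

import torch


@contextmanager
def fork_rng(cpu_state, rng_state, device):
    # Hypothetical helper: restore the saved RNG states, run the block, then
    # put the current states back.
    saved_cpu = torch.get_rng_state()
    saved_cuda = torch.cuda.get_rng_state(device) if torch.cuda.is_available() else None
    try:
        torch.set_rng_state(cpu_state)
        if saved_cuda is not None and rng_state is not None:
            torch.cuda.set_rng_state(rng_state, device)
        yield
    finally:
        torch.set_rng_state(saved_cpu)
        if saved_cuda is not None:
            torch.cuda.set_rng_state(saved_cuda, device)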
@@ -1136,16 +1136,24 @@ def with_loss(x, y, name):

 class ScaleGradFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
-        ctx.alpha = alpha
+    def forward(ctx, x: Tensor, alpha: Union[float, Tensor]) -> Tensor:
+        if isinstance(alpha, Tensor):
+            ctx.save_for_backward(alpha)
+        else:
+            ctx.alpha = alpha
         return x

     @staticmethod
     def backward(ctx, grad: Tensor):
-        return grad * ctx.alpha, None
+        if hasattr(ctx, "alpha"):
+            alpha = ctx.alpha
+        else:
+            (alpha,) = ctx.saved_tensors
+
+        return grad * alpha, None


-def scale_grad(x: Tensor, alpha: float):
+def scale_grad(x: Tensor, alpha: Union[float, Tensor]):
     return ScaleGradFunction.apply(x, alpha)

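Pulled out of the diff, the updated helper can be exercised on its own. The quick check below is not in the commit; it shows a float alpha scaling every gradient entry uniformly and a tensor alpha scaling them element-wise:

from typing import Union

import torch
from torch import Tensor


class ScaleGradFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor, alpha: Union[float, Tensor]) -> Tensor:
        # Save the (possibly tensor-valued) scale for the backward pass.
        if isinstance(alpha, Tensor):
            ctx.save_for_backward(alpha)
        else:
            ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad: Tensor):
        if hasattr(ctx, "alpha"):
            alpha = ctx.alpha
        else:
            (alpha,) = ctx.saved_tensors
        return grad * alpha, None


def scale_grad(x: Tensor, alpha: Union[float, Tensor]):
    return ScaleGradFunction.apply(x, alpha)


x = torch.ones(3, requires_grad=True)
scale_grad(x, 0.1).sum().backward()
print(x.grad)                                     # tensor([0.1000, 0.1000, 0.1000])

y = torch.ones(3, requires_grad=True)
scale_grad(y, torch.tensor([1.0, 0.5, 0.0])).sum().backward()
print(y.grad)                                     # tensor([1.0000, 0.5000, 0.0000])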
@@ -552,9 +552,15 @@ def get_parser():
     parser.add_argument(
         "--limit-grad-start-batch",
         type=int,
-        # default=1000,
-        default=2,
-        help="Limit grad starting from this batch.",
+        default=1000,
+        help="Enable grad limit starting from this batch. Set it to 0 to disable it",
+    )
+
+    parser.add_argument(
+        "--limit-grad-every-n-batch",
+        type=int,
+        default=1,
+        help="Apply grad limit every this number of batch when it is enabled",
     )

     add_model_arguments(parser)
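How the two flags interact with batch_idx_train can be read off the gating condition used later in train_one_epoch. A small illustration (the helper name and sample indices below are made up for this sketch):

def grad_limit_active(batch_idx, start_batch=1000, every_n=1):
    # Mirrors: 0 < params.limit_grad_start_batch <= params.batch_idx_train
    #          and params.batch_idx_train % params.limit_grad_every_n_batch == 0
    return 0 < start_batch <= batch_idx and batch_idx % every_n == 0


print(grad_limit_active(999))                  # False: before --limit-grad-start-batch
print(grad_limit_active(1000))                 # True: the limit kicks in
print(grad_limit_active(1001, every_n=2))      # False: model_prev refreshed only every 2nd batch
print(grad_limit_active(5000, start_batch=0))  # False: 0 disables the limit entirely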
@@ -1036,6 +1042,17 @@ def compute_validation_loss(
     return tot_loss


+@torch.inference_mode()
+def update_model_prev(model_prev, model, beta):
+    # model_prev = beta * model_prev + (1-beta) * model
+    model_prev_dict = model_prev.state_dict()
+    model_dict = model.state_dict()
+    for key in model_prev_dict:
+        model_prev_dict[key].data.copy_(
+            model_prev_dict[key].data * beta + model_dict[key].data * (1 - beta)
+        )
+
+
 def train_one_epoch(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
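A quick standalone check of update_model_prev (not in the commit): after one call, every parameter of model_prev should equal beta * old_model_prev + (1 - beta) * model. The module size and the beta value here are arbitrary:

import copy

import torch
import torch.nn as nn


@torch.inference_mode()
def update_model_prev(model_prev, model, beta):
    # Same body as the added function above.
    model_prev_dict = model_prev.state_dict()
    model_dict = model.state_dict()
    for key in model_prev_dict:
        model_prev_dict[key].data.copy_(
            model_prev_dict[key].data * beta + model_dict[key].data * (1 - beta)
        )


model = nn.Linear(2, 2)
model_prev = copy.deepcopy(model)

# Pretend one optimizer step moved the live model away from model_prev.
with torch.no_grad():
    model.weight.add_(1.0)

old_prev_weight = model_prev.weight.detach().clone()
update_model_prev(model_prev=model_prev, model=model, beta=0.9)
expected = 0.9 * old_prev_weight + 0.1 * model.weight
print(torch.allclose(model_prev.weight, expected))   # True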
@@ -1115,13 +1132,11 @@ def train_one_epoch(
             with torch.cuda.amp.autocast(
                 enabled=params.use_autocast, dtype=params.dtype
             ):
-                if params.batch_idx_train > params.limit_grad_start_batch:
-                    model_prev = copy.deepcopy(model)
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
                     model_prev=model_prev
-                    if params.batch_idx_train > params.limit_grad_start_batch
+                    if 0 < params.limit_grad_start_batch < params.batch_idx_train
                     else None,
                     sp=sp,
                     batch=batch,
@@ -1140,17 +1155,15 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()

-            if params.batch_idx_train >= params.limit_grad_start_batch:
+            if (
+                0 < params.limit_grad_start_batch <= params.batch_idx_train
+                and params.batch_idx_train % params.limit_grad_every_n_batch == 0
+            ):
                 if model_prev is None:
                     model_prev = copy.deepcopy(model)
                 else:
-                    model_prev = copy.deepcopy(model)
-                    print(
-                        "here",
-                        params.batch_idx_train,
-                        params.limit_grad_start_batch,
-                        model_prev is None,
-                    )
+                    beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
+                    update_model_prev(model_prev=model_prev, model=model, beta=beta)

         except Exception as e:
             logging.info(f"Caught exception: {e}.")
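For reference, the beta schedule added above, beta = max(0.5, 1.0 - 1.0 / (0.1 * batch_idx_train)), floors at 0.5 early on and approaches 1.0 as training proceeds, so model_prev becomes an increasingly slow-moving average of the live model. With the default --limit-grad-start-batch of 1000, the first update already uses beta = 0.99:

for batch_idx in [10, 20, 40, 100, 1000, 10000]:
    beta = max(0.5, 1.0 - 1.0 / (0.1 * batch_idx))
    print(batch_idx, beta)
# -> 0.5, 0.5, 0.75, 0.9, 0.99, 0.999 (approximately)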
@@ -1221,6 +1234,7 @@ def train_one_epoch(
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
                 + (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
+                + (f", beta: {beta}" if model_prev is not None else "")
             )

             if tb_writer is not None:
@@ -1622,9 +1636,4 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)

 if __name__ == "__main__":
-    # torch.use_deterministic_algorithms(True, warn_only=True)
-    # torch.backends.cudnn.deterministic = True
-    # torch.backends.cudnn.benchmark = False
-    # torch.backends.cudnn.enabled = False
-
     main()