Mirror of https://github.com/k2-fsa/icefall.git

Commit: 256c446f06
Parent: 1e986c930d
Commit message: First working version
@@ -25,7 +25,7 @@ import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from lhotse.dataset import SpecAugment
-from scaling import ScaledLinear
+from scaling import ScaledLinear, scale_grad
 
 from icefall.utils import add_sos, make_pad_mask, time_warp
 
@@ -198,13 +198,6 @@ class AsrModel(nn.Module):
 
         # Compute CTC log-prob
         ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
-        print(
-            "ctc_output",
-            ctc_output.detach().mean(),
-            ctc_output.detach().sum(),
-            ctc_output.detach().min(),
-            ctc_output.detach().max(),
-        )
 
         if model_prev:
             with fork_rng(
@@ -213,18 +206,11 @@ class AsrModel(nn.Module):
                 rng_state=rng_state,
                 device=device,
             ):
-                ctc_output_prev = model_prev.ctc_output(encoder_out)
-            print(
-                "ctc_output_prev",
-                ctc_output_prev.detach().mean(),
-                ctc_output_prev.detach().sum(),
-                ctc_output_prev.detach().min(),
-                ctc_output_prev.detach().max(),
-            )
-            print(
-                "isclose ctc",
-                (ctc_output - ctc_output).detach().abs().max(),
-            )
+                ctc_output_prev = model_prev.ctc_output(encoder_out_prev)
+            has_grown = ctc_output > 0.8 * ctc_output_prev
+            grad_scale_tensor = torch.where(has_grown, 0.5, 1.0)
+            ctc_output = scale_grad(ctc_output, grad_scale_tensor)
+
 
         ctc_loss = torch.nn.functional.ctc_loss(
             log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
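The hunk above is the core of the change: the debug prints are dropped, and the CTC log-probs of the current model are compared against those produced by a copy of the model from earlier in training (model_prev); wherever the current output has grown past 0.8 times the reference, the gradient flowing back into the CTC head is halved via scale_grad. A minimal self-contained sketch of that mechanism (the names _ScaleGrad, limit_grad, cur_out and ref_out are illustrative, not from the commit; the autograd function mirrors the scaling hunk further down):

import torch


class _ScaleGrad(torch.autograd.Function):
    """Identity in the forward pass; multiplies the incoming gradient by
    `alpha` (a float or a broadcastable tensor) in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        if isinstance(alpha, torch.Tensor):
            ctx.save_for_backward(alpha)
        else:
            ctx.alpha = alpha
        return x

    @staticmethod
    def backward(ctx, grad):
        alpha = ctx.alpha if hasattr(ctx, "alpha") else ctx.saved_tensors[0]
        return grad * alpha, None


def limit_grad(cur_out: torch.Tensor, ref_out: torch.Tensor) -> torch.Tensor:
    # Halve the gradient wherever the current output exceeds 0.8x the reference,
    # as in the hunk above.
    has_grown = cur_out > 0.8 * ref_out
    grad_scale = torch.where(has_grown, 0.5, 1.0)
    return _ScaleGrad.apply(cur_out, grad_scale)


cur = torch.randn(2, 3, requires_grad=True)
ref = torch.zeros(2, 3)  # stand-in for the previous model's output
limit_grad(cur, ref).sum().backward()
print(cur.grad)  # 0.5 where cur > 0.8 * ref, 1.0 elsewhere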
@@ -481,15 +467,6 @@ class AsrModel(nn.Module):
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
 
-        print(
-            "encoder_out",
-            encoder_out.detach().mean(),
-            encoder_out.detach().abs().max(),
-            encoder_out.detach().abs().min(),
-            encoder_out.detach().sum(),
-            encoder_out.shape,
-        )
-
         if model_prev:
             with fork_rng(
                 cpu_state=cpu_state,
@@ -500,19 +477,6 @@ class AsrModel(nn.Module):
                 encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
                     x, x_lens
                 )
-                print(
-                    "encoder_out_prev",
-                    encoder_out_prev.detach().mean(),
-                    encoder_out_prev.detach().abs().max(),
-                    encoder_out_prev.detach().abs().mean(),
-                    encoder_out_prev.detach().sum(),
-                    encoder_out_prev.shape,
-                )
-                print(
-                    "isclose",
-                    (encoder_out - encoder_out_prev).detach().abs().max(),
-                    (encoder_out_lens - encoder_out_lens_prev).detach().abs().max(),
-                )
         else:
             encoder_out_prev = None
             encoder_out_lens_prev = None
@@ -1136,16 +1136,24 @@ def with_loss(x, y, name):
 
 class ScaleGradFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
-        ctx.alpha = alpha
+    def forward(ctx, x: Tensor, alpha: Union[float, Tensor]) -> Tensor:
+        if isinstance(alpha, Tensor):
+            ctx.save_for_backward(alpha)
+        else:
+            ctx.alpha = alpha
         return x
 
     @staticmethod
     def backward(ctx, grad: Tensor):
-        return grad * ctx.alpha, None
+        if hasattr(ctx, "alpha"):
+            alpha = ctx.alpha
+        else:
+            (alpha,) = ctx.saved_tensors
+
+        return grad * alpha, None
 
 
-def scale_grad(x: Tensor, alpha: float):
+def scale_grad(x: Tensor, alpha: Union[float, Tensor]):
     return ScaleGradFunction.apply(x, alpha)
 
 
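The extended ScaleGradFunction keeps the old float path (stashed on ctx.alpha) and adds a tensor path through save_for_backward, so alpha can now be a per-element gradient scale. For intuition: scale_grad(x, alpha) returns x unchanged in the forward pass and multiplies the incoming gradient by alpha in the backward pass, which matches alpha * x + (1 - alpha) * x.detach(), whose forward value equals x up to rounding and whose gradient w.r.t. x is alpha. A small standalone check of that equivalence, using only public PyTorch APIs rather than importing from the recipe's scaling module:

import torch


def scale_grad_via_detach(x: torch.Tensor, alpha) -> torch.Tensor:
    # Forward value is x (up to rounding); only the first term carries gradient,
    # so d(out)/dx = alpha.
    return alpha * x + (1.0 - alpha) * x.detach()


x1 = torch.randn(4, requires_grad=True)
x2 = x1.detach().clone().requires_grad_(True)
alpha = torch.tensor([1.0, 0.5, 0.5, 1.0])

scale_grad_via_detach(x1, alpha).sum().backward()
x2.sum().backward()  # unscaled reference

print(x1.grad)  # equals alpha: [1.0, 0.5, 0.5, 1.0]
print(x2.grad)  # all ones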
@@ -552,9 +552,15 @@ def get_parser():
     parser.add_argument(
         "--limit-grad-start-batch",
         type=int,
-        # default=1000,
-        default=2,
-        help="Limit grad starting from this batch.",
+        default=1000,
+        help="Enable grad limit starting from this batch. Set it to 0 to disable it",
+    )
+
+    parser.add_argument(
+        "--limit-grad-every-n-batch",
+        type=int,
+        default=1,
+        help="Apply grad limit every this number of batch when it is enabled",
     )
 
     add_model_arguments(parser)
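The first flag now defaults to 1000 and doubles as an on/off switch (0 disables the mechanism), while the new --limit-grad-every-n-batch thins out how often the reference model is refreshed. Based on how the flags are used further down in train_one_epoch, the gating amounts to a predicate like the following sketch (should_limit_grad is an illustrative name, not a function in the commit; note the compute_loss call uses a strict < on the start batch, while the reference-model update uses <=):

def should_limit_grad(batch_idx_train: int, start_batch: int, every_n_batch: int) -> bool:
    # Disabled when --limit-grad-start-batch is 0; otherwise active from
    # `start_batch` onward, but only on every `every_n_batch`-th batch.
    return (
        0 < start_batch <= batch_idx_train
        and batch_idx_train % every_n_batch == 0
    )


assert not should_limit_grad(batch_idx_train=999, start_batch=1000, every_n_batch=1)
assert should_limit_grad(batch_idx_train=1000, start_batch=1000, every_n_batch=1)
assert not should_limit_grad(batch_idx_train=1001, start_batch=1000, every_n_batch=2)
assert not should_limit_grad(batch_idx_train=5000, start_batch=0, every_n_batch=1)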
@@ -1036,6 +1042,17 @@ def compute_validation_loss(
     return tot_loss
 
 
+@torch.inference_mode()
+def update_model_prev(model_prev, model, beta):
+    # model_prev = beta * model_prev + (1-beta) * model
+    model_prev_dict = model_prev.state_dict()
+    model_dict = model.state_dict()
+    for key in model_prev_dict:
+        model_prev_dict[key].data.copy_(
+            model_prev_dict[key].data * beta + model_dict[key].data * (1 - beta)
+        )
+
+
 def train_one_epoch(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
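update_model_prev is an in-place exponential moving average over state_dict entries, which keeps the frozen reference model a smoothed copy of the live one. A self-contained sketch of the same update on a tiny module (ema_update is an illustrative name; the real function runs on the full ASR model and, like the commit, is wrapped in torch.inference_mode since no autograd is needed here):

import torch
import torch.nn as nn


@torch.inference_mode()
def ema_update(model_prev: nn.Module, model: nn.Module, beta: float) -> None:
    # model_prev <- beta * model_prev + (1 - beta) * model, applied in place
    # to every state_dict entry (parameters and buffers alike).
    prev_state = model_prev.state_dict()
    cur_state = model.state_dict()
    for key, prev_val in prev_state.items():
        prev_val.copy_(prev_val * beta + cur_state[key] * (1.0 - beta))


model = nn.Linear(2, 2)
model_prev = nn.Linear(2, 2)
with torch.no_grad():
    model.weight.fill_(1.0)
    model_prev.weight.fill_(0.0)

ema_update(model_prev, model, beta=0.9)
print(model_prev.weight)  # every weight entry becomes (1 - 0.9) * 1.0, i.e. about 0.1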
@@ -1115,13 +1132,11 @@ def train_one_epoch(
             with torch.cuda.amp.autocast(
                 enabled=params.use_autocast, dtype=params.dtype
             ):
-                if params.batch_idx_train > params.limit_grad_start_batch:
-                    model_prev = copy.deepcopy(model)
                 loss, loss_info = compute_loss(
                     params=params,
                     model=model,
                     model_prev=model_prev
-                    if params.batch_idx_train > params.limit_grad_start_batch
+                    if 0 < params.limit_grad_start_batch < params.batch_idx_train
                     else None,
                     sp=sp,
                     batch=batch,
@@ -1140,17 +1155,15 @@ def train_one_epoch(
             scaler.update()
             optimizer.zero_grad()
 
-            if params.batch_idx_train >= params.limit_grad_start_batch:
+            if (
+                0 < params.limit_grad_start_batch <= params.batch_idx_train
+                and params.batch_idx_train % params.limit_grad_every_n_batch == 0
+            ):
                 if model_prev is None:
                     model_prev = copy.deepcopy(model)
                 else:
-                    model_prev = copy.deepcopy(model)
-                    print(
-                        "here",
-                        params.batch_idx_train,
-                        params.limit_grad_start_batch,
-                        model_prev is None,
-                    )
+                    beta = max(0.5, 1.0 - 1.0 / (0.1 * params.batch_idx_train))
+                    update_model_prev(model_prev=model_prev, model=model, beta=beta)
 
         except Exception as e:
             logging.info(f"Caught exception: {e}.")
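With this hunk, the reference model starts out as a deep copy of the live model and is afterwards smoothed toward it with a beta that ramps from 0.5 toward 1.0 as training progresses, so the reference lags further behind the live model later in training. A few worked values of the schedule used above (beta_schedule is just an illustrative helper; printed values are approximate because of floating-point rounding):

def beta_schedule(batch_idx_train: int) -> float:
    return max(0.5, 1.0 - 1.0 / (0.1 * batch_idx_train))


# batch_idx_train:   10     20     100    1000   10000
# beta (approx.):    0.5    0.5    0.9    0.99   0.999
# (the 0.5 floor applies for roughly the first 20 batches)
for t in (10, 20, 100, 1000, 10000):
    print(t, beta_schedule(t))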
@ -1221,6 +1234,7 @@ def train_one_epoch(
|
|||||||
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
|
||||||
f"lr: {cur_lr:.2e}, "
|
f"lr: {cur_lr:.2e}, "
|
||||||
+ (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
|
+ (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
|
||||||
|
+ (f", beta: {beta}" if model_prev is not None else "")
|
||||||
)
|
)
|
||||||
|
|
||||||
if tb_writer is not None:
|
if tb_writer is not None:
|
||||||
@@ -1622,9 +1636,4 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 if __name__ == "__main__":
-    # torch.use_deterministic_algorithms(True, warn_only=True)
-    # torch.backends.cudnn.deterministic = True
-    # torch.backends.cudnn.benchmark = False
-    # torch.backends.cudnn.enabled = False
-
     main()