mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Merge 17d7174cd10e94eadfb31a6f63687fb87393e487 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9
This commit is contained in:
commit
d1bbf051eb
@ -16,6 +16,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import random
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import k2
|
||||
@ -23,11 +25,26 @@ import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from lhotse.dataset import SpecAugment
|
||||
from scaling import ScaledLinear
|
||||
from scaling import ScaledLinear, scale_grad
|
||||
|
||||
from icefall.utils import add_sos, make_pad_mask, time_warp, torch_autocast
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def fork_rng(cpu_state, cuda_state, rng_state, device):
|
||||
with torch.random.fork_rng(devices=[device]):
|
||||
torch.set_rng_state(cpu_state)
|
||||
torch.cuda.set_rng_state(cuda_state, device)
|
||||
|
||||
rng_state2 = random.getstate()
|
||||
random.setstate(rng_state)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
random.setstate(rng_state2)
|
||||
|
||||
|
||||
class AsrModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -159,6 +176,9 @@ class AsrModel(nn.Module):
|
||||
encoder_out_lens: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
target_lengths: torch.Tensor,
|
||||
encoder_out_prev: Optional[torch.Tensor] = None,
|
||||
encoder_out_lens_prev: Optional[torch.Tensor] = None,
|
||||
model_prev=None,
|
||||
) -> torch.Tensor:
|
||||
"""Compute CTC loss.
|
||||
Args:
|
||||
@ -170,9 +190,28 @@ class AsrModel(nn.Module):
|
||||
Target Tensor of shape (sum(target_lengths)). The targets are assumed
|
||||
to be un-padded and concatenated within 1 dimension.
|
||||
"""
|
||||
device = encoder_out.device
|
||||
if model_prev:
|
||||
cpu_state = torch.get_rng_state()
|
||||
cuda_state = torch.cuda.get_rng_state(device)
|
||||
rng_state = random.getstate()
|
||||
|
||||
# Compute CTC log-prob
|
||||
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
if model_prev:
|
||||
with fork_rng(
|
||||
cpu_state=cpu_state,
|
||||
cuda_state=cuda_state,
|
||||
rng_state=rng_state,
|
||||
device=device,
|
||||
):
|
||||
ctc_output_prev = model_prev.ctc_output(encoder_out_prev)
|
||||
|
||||
has_grown = ctc_output > 0.8 * ctc_output_prev
|
||||
grad_scale_tensor = torch.where(has_grown, 0.5, 1.0)
|
||||
ctc_output = scale_grad(ctc_output, grad_scale_tensor)
|
||||
|
||||
ctc_loss = torch.nn.functional.ctc_loss(
|
||||
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||
targets=targets.cpu(),
|
||||
@ -345,6 +384,7 @@ class AsrModel(nn.Module):
|
||||
spec_augment: Optional[SpecAugment] = None,
|
||||
supervision_segments: Optional[torch.Tensor] = None,
|
||||
time_warp_factor: Optional[int] = 80,
|
||||
model_prev=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
@ -418,9 +458,29 @@ class AsrModel(nn.Module):
|
||||
x_lens = x_lens.repeat(2)
|
||||
y = k2.ragged.cat([y, y], axis=0)
|
||||
|
||||
device = x.device
|
||||
if model_prev:
|
||||
cpu_state = torch.get_rng_state()
|
||||
cuda_state = torch.cuda.get_rng_state(device)
|
||||
rng_state = random.getstate()
|
||||
|
||||
# Compute encoder outputs
|
||||
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
|
||||
|
||||
if model_prev:
|
||||
with fork_rng(
|
||||
cpu_state=cpu_state,
|
||||
cuda_state=cuda_state,
|
||||
rng_state=rng_state,
|
||||
device=device,
|
||||
):
|
||||
encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
|
||||
x, x_lens
|
||||
)
|
||||
else:
|
||||
encoder_out_prev = None
|
||||
encoder_out_lens_prev = None
|
||||
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
@ -451,6 +511,9 @@ class AsrModel(nn.Module):
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
targets=targets,
|
||||
target_lengths=y_lens,
|
||||
encoder_out_prev=encoder_out_prev,
|
||||
encoder_out_lens_prev=encoder_out_lens_prev,
|
||||
model_prev=model_prev,
|
||||
)
|
||||
cr_loss = torch.empty(0)
|
||||
else:
|
||||
|
@ -1140,16 +1140,24 @@ def with_loss(x, y, name):
|
||||
|
||||
class ScaleGradFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, x: Tensor, alpha: float) -> Tensor:
|
||||
ctx.alpha = alpha
|
||||
def forward(ctx, x: Tensor, alpha: Union[float, Tensor]) -> Tensor:
|
||||
if isinstance(alpha, Tensor):
|
||||
ctx.save_for_backward(alpha)
|
||||
else:
|
||||
ctx.alpha = alpha
|
||||
return x
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad: Tensor):
|
||||
return grad * ctx.alpha, None
|
||||
if hasattr(ctx, "alpha"):
|
||||
alpha = ctx.alpha
|
||||
else:
|
||||
(alpha,) = ctx.saved_tensors
|
||||
|
||||
return grad * alpha, None
|
||||
|
||||
|
||||
def scale_grad(x: Tensor, alpha: float):
|
||||
def scale_grad(x: Tensor, alpha: Union[float, Tensor]):
|
||||
return ScaleGradFunction.apply(x, alpha)
|
||||
|
||||
|
||||
|
1640
egs/librispeech/ASR/zipformer/train-limit-grad.py
Executable file
1640
egs/librispeech/ASR/zipformer/train-limit-grad.py
Executable file
File diff suppressed because it is too large
Load Diff
Loading…
x
Reference in New Issue
Block a user