Merge 17d7174cd10e94eadfb31a6f63687fb87393e487 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

Fangjun Kuang 2025-07-25 09:16:26 +02:00 committed by GitHub
commit d1bbf051eb
3 changed files with 1716 additions and 5 deletions


@@ -16,6 +16,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import contextlib
+import random
 from typing import Optional, Tuple
 
 import k2
@@ -23,11 +25,26 @@ import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from lhotse.dataset import SpecAugment
-from scaling import ScaledLinear
+from scaling import ScaledLinear, scale_grad
 
 from icefall.utils import add_sos, make_pad_mask, time_warp, torch_autocast
 
 
+@contextlib.contextmanager
+def fork_rng(cpu_state, cuda_state, rng_state, device):
+    """Run the enclosed code with the given torch CPU, torch CUDA, and
+    python `random` RNG states; restore the original states on exit."""
+    with torch.random.fork_rng(devices=[device]):
+        torch.set_rng_state(cpu_state)
+        torch.cuda.set_rng_state(cuda_state, device)
+        rng_state2 = random.getstate()
+        random.setstate(rng_state)
+        try:
+            yield
+        finally:
+            random.setstate(rng_state2)
+
+
 class AsrModel(nn.Module):
     def __init__(
         self,
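The `fork_rng` helper added above replays previously captured RNG states inside a scope and leaves the live states untouched afterwards. A minimal sketch of the intended behavior (not part of the diff; assumes a CUDA device is available and `fork_rng` is defined as above):

    import random
    import torch

    device = torch.device("cuda:0")

    # Capture the current RNG states.
    cpu_state = torch.get_rng_state()
    cuda_state = torch.cuda.get_rng_state(device)
    rng_state = random.getstate()

    a = torch.rand(3, device=device)  # advances the CUDA RNG

    # Replaying the captured states reproduces the same draw.
    with fork_rng(cpu_state, cuda_state, rng_state, device):
        b = torch.rand(3, device=device)

    assert torch.equal(a, b)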
@@ -159,6 +176,9 @@ class AsrModel(nn.Module):
         encoder_out_lens: torch.Tensor,
         targets: torch.Tensor,
         target_lengths: torch.Tensor,
+        encoder_out_prev: Optional[torch.Tensor] = None,
+        encoder_out_lens_prev: Optional[torch.Tensor] = None,
+        model_prev=None,
     ) -> torch.Tensor:
         """Compute CTC loss.
         Args:
@@ -170,9 +190,28 @@ class AsrModel(nn.Module):
             Target Tensor of shape (sum(target_lengths)). The targets are assumed
             to be un-padded and concatenated within 1 dimension.
         """
+        device = encoder_out.device
+        if model_prev:
+            # Save the RNG states so that model_prev can be run under
+            # exactly the same randomness as the current model.
+            cpu_state = torch.get_rng_state()
+            cuda_state = torch.cuda.get_rng_state(device)
+            rng_state = random.getstate()
+
         # Compute CTC log-prob
         ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
+
+        if model_prev:
+            with fork_rng(
+                cpu_state=cpu_state,
+                cuda_state=cuda_state,
+                rng_state=rng_state,
+                device=device,
+            ):
+                ctc_output_prev = model_prev.ctc_output(encoder_out_prev)
+            # Scale down the gradient of entries whose log-prob has grown
+            # relative to the previous model.
+            has_grown = ctc_output > 0.8 * ctc_output_prev
+            grad_scale_tensor = torch.where(has_grown, 0.5, 1.0)
+            ctc_output = scale_grad(ctc_output, grad_scale_tensor)
+
         ctc_loss = torch.nn.functional.ctc_loss(
             log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
             targets=targets.cpu(),
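Since `ctc_output` holds log-probabilities (negative values), the `has_grown` test fires when the current log-prob exceeds 0.8 times the previous one, i.e. when the probability has risen above the previous probability raised to the 0.8 power. A back-of-the-envelope check with illustrative numbers:

    import math

    log_p_prev = math.log(0.5)    # previous model's log-prob, ~ -0.693
    threshold = 0.8 * log_p_prev  # ~ -0.554
    # has_grown iff log_p > threshold, i.e. p > 0.5 ** 0.8 ~ 0.574
    print(math.exp(threshold))    # 0.5743...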
@@ -345,6 +384,7 @@ class AsrModel(nn.Module):
         spec_augment: Optional[SpecAugment] = None,
         supervision_segments: Optional[torch.Tensor] = None,
         time_warp_factor: Optional[int] = 80,
+        model_prev=None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -418,9 +458,29 @@ class AsrModel(nn.Module):
             x_lens = x_lens.repeat(2)
             y = k2.ragged.cat([y, y], axis=0)
 
+        device = x.device
+        if model_prev:
+            # Save the RNG states so that model_prev sees the same
+            # randomness (e.g. dropout masks) as the current model.
+            cpu_state = torch.get_rng_state()
+            cuda_state = torch.cuda.get_rng_state(device)
+            rng_state = random.getstate()
+
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
 
+        if model_prev:
+            with fork_rng(
+                cpu_state=cpu_state,
+                cuda_state=cuda_state,
+                rng_state=rng_state,
+                device=device,
+            ):
+                encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
+                    x, x_lens
+                )
+        else:
+            encoder_out_prev = None
+            encoder_out_lens_prev = None
+
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
@@ -451,6 +511,9 @@ class AsrModel(nn.Module):
                     encoder_out_lens=encoder_out_lens,
                     targets=targets,
                     target_lengths=y_lens,
+                    encoder_out_prev=encoder_out_prev,
+                    encoder_out_lens_prev=encoder_out_lens_prev,
+                    model_prev=model_prev,
                 )
                 cr_loss = torch.empty(0)
             else:
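For context, a hypothetical call site for the extended forward. The input names (`features`, `feature_lens`) and the unpacked loss names are illustrative assumptions, not from this PR; only the five-tensor return type and the new `model_prev` argument come from the diff above:

    # model_prev is e.g. a frozen copy of the model from an earlier checkpoint.
    simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model(
        x=features,
        x_lens=feature_lens,
        y=y,
        model_prev=model_prev,
    )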


@@ -1140,16 +1140,24 @@ def with_loss(x, y, name):
 class ScaleGradFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
-        ctx.alpha = alpha
+    def forward(ctx, x: Tensor, alpha: Union[float, Tensor]) -> Tensor:
+        if isinstance(alpha, Tensor):
+            ctx.save_for_backward(alpha)
+        else:
+            ctx.alpha = alpha
         return x
 
     @staticmethod
     def backward(ctx, grad: Tensor):
-        return grad * ctx.alpha, None
+        if hasattr(ctx, "alpha"):
+            alpha = ctx.alpha
+        else:
+            (alpha,) = ctx.saved_tensors
+        return grad * alpha, None
 
 
-def scale_grad(x: Tensor, alpha: float):
+def scale_grad(x: Tensor, alpha: Union[float, Tensor]):
     return ScaleGradFunction.apply(x, alpha)
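With the `Tensor` branch above, `alpha` can now scale gradients element-wise rather than by a single scalar. A minimal sketch of the extended `scale_grad` (toy values; assumes the patched `scaling.py` is importable):

    import torch
    from scaling import scale_grad

    x = torch.tensor([1.0, 2.0], requires_grad=True)
    prev = torch.tensor([0.5, 4.0])

    alpha = torch.where(x > 0.8 * prev, 0.5, 1.0)  # [0.5, 1.0]
    y = scale_grad(x, alpha)  # identity in the forward pass
    y.sum().backward()
    print(x.grad)             # tensor([0.5000, 1.0000])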

File diff suppressed because it is too large