mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00

Merge 17d7174cd10e94eadfb31a6f63687fb87393e487 into 34fc1fdf0d8ff520e2bb18267d046ca207c78ef9

This commit is contained in commit d1bbf051eb.
In model.py:

@@ -16,6 +16,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import contextlib
+import random
 from typing import Optional, Tuple
 
 import k2
@@ -23,11 +25,26 @@ import torch
 import torch.nn as nn
 from encoder_interface import EncoderInterface
 from lhotse.dataset import SpecAugment
-from scaling import ScaledLinear
+from scaling import ScaledLinear, scale_grad
 
 from icefall.utils import add_sos, make_pad_mask, time_warp, torch_autocast
 
 
+@contextlib.contextmanager
+def fork_rng(cpu_state, cuda_state, rng_state, device):
+    with torch.random.fork_rng(devices=[device]):
+        torch.set_rng_state(cpu_state)
+        torch.cuda.set_rng_state(cuda_state, device)
+
+        rng_state2 = random.getstate()
+        random.setstate(rng_state)
+
+        try:
+            yield
+        finally:
+            random.setstate(rng_state2)
+
+
 class AsrModel(nn.Module):
     def __init__(
         self,
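A note on the helper just added: fork_rng re-enters previously captured CPU, CUDA, and Python-random states, so the wrapped block reproduces the exact random draws (e.g. dropout masks) of an earlier computation, while the surrounding RNG streams are left undisturbed on exit. A minimal sketch of the intended usage (a hypothetical standalone snippet; it assumes a CUDA device is available, since the helper unconditionally touches CUDA state):

    import random
    import torch

    device = torch.device("cuda", 0)

    # Capture all three RNG states before the first stochastic computation.
    cpu_state = torch.get_rng_state()
    cuda_state = torch.cuda.get_rng_state(device)
    rng_state = random.getstate()

    x = torch.ones(4, 4, device=device)
    drop_a = torch.nn.functional.dropout(x, p=0.5)  # consumes CUDA RNG

    # Replay the captured states so a second pass sees the same draws.
    with fork_rng(cpu_state, cuda_state, rng_state, device):
        drop_b = torch.nn.functional.dropout(x, p=0.5)

    assert torch.equal(drop_a, drop_b)  # identical dropout mask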
@@ -159,6 +176,9 @@ class AsrModel(nn.Module):
         encoder_out_lens: torch.Tensor,
         targets: torch.Tensor,
         target_lengths: torch.Tensor,
+        encoder_out_prev: Optional[torch.Tensor] = None,
+        encoder_out_lens_prev: Optional[torch.Tensor] = None,
+        model_prev=None,
     ) -> torch.Tensor:
         """Compute CTC loss.
         Args:
@@ -170,9 +190,28 @@ class AsrModel(nn.Module):
             Target Tensor of shape (sum(target_lengths)). The targets are assumed
             to be un-padded and concatenated within 1 dimension.
         """
+        device = encoder_out.device
+        if model_prev:
+            cpu_state = torch.get_rng_state()
+            cuda_state = torch.cuda.get_rng_state(device)
+            rng_state = random.getstate()
+
         # Compute CTC log-prob
         ctc_output = self.ctc_output(encoder_out)  # (N, T, C)
 
+        if model_prev:
+            with fork_rng(
+                cpu_state=cpu_state,
+                cuda_state=cuda_state,
+                rng_state=rng_state,
+                device=device,
+            ):
+                ctc_output_prev = model_prev.ctc_output(encoder_out_prev)
+
+            has_grown = ctc_output > 0.8 * ctc_output_prev
+            grad_scale_tensor = torch.where(has_grown, 0.5, 1.0)
+            ctc_output = scale_grad(ctc_output, grad_scale_tensor)
+
         ctc_loss = torch.nn.functional.ctc_loss(
             log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
             targets=targets.cpu(),
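The block above implements the "limit grad" idea of this PR: wherever the current model's CTC output has grown past 0.8 of the previous model's output, the gradient flowing back into ctc_output is halved, while the forward value is left unchanged. A small self-contained sketch of the same mechanism (illustrative numbers only):

    import torch
    from scaling import scale_grad  # the helper patched further below

    out = torch.tensor([0.2, 0.9, 0.5], requires_grad=True)
    out_prev = torch.tensor([0.5, 0.5, 0.5])

    has_grown = out > 0.8 * out_prev          # [False, True, True]
    scale = torch.where(has_grown, 0.5, 1.0)  # halve grads where grown

    scaled = scale_grad(out, scale)  # forward: identical to `out`
    scaled.sum().backward()
    print(out.grad)  # tensor([1.0000, 0.5000, 0.5000])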
@@ -345,6 +384,7 @@ class AsrModel(nn.Module):
         spec_augment: Optional[SpecAugment] = None,
         supervision_segments: Optional[torch.Tensor] = None,
         time_warp_factor: Optional[int] = 80,
+        model_prev=None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -418,9 +458,29 @@ class AsrModel(nn.Module):
             x_lens = x_lens.repeat(2)
             y = k2.ragged.cat([y, y], axis=0)
 
+        device = x.device
+        if model_prev:
+            cpu_state = torch.get_rng_state()
+            cuda_state = torch.cuda.get_rng_state(device)
+            rng_state = random.getstate()
+
         # Compute encoder outputs
         encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
 
+        if model_prev:
+            with fork_rng(
+                cpu_state=cpu_state,
+                cuda_state=cuda_state,
+                rng_state=rng_state,
+                device=device,
+            ):
+                encoder_out_prev, encoder_out_lens_prev = model_prev.forward_encoder(
+                    x, x_lens
+                )
+        else:
+            encoder_out_prev = None
+            encoder_out_lens_prev = None
+
         row_splits = y.shape.row_splits(1)
         y_lens = row_splits[1:] - row_splits[:-1]
 
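Note that the RNG states are captured before the live model's forward pass and replayed around model_prev's: both encoders therefore consume identical random draws (dropout, etc.), so their outputs differ only through their parameters. This diff does not show how callers construct model_prev; one plausible arrangement (purely hypothetical, not taken from this patch) is a periodically refreshed frozen snapshot:

    import copy

    model_prev = None
    for step, batch in enumerate(train_loader):  # train_loader assumed given
        if step % 500 == 0:
            # The snapshot stays in train mode so that, under fork_rng, its
            # dropout draws match the live model's.
            model_prev = copy.deepcopy(model)
            for p in model_prev.parameters():
                p.requires_grad_(False)
        # Schematic call; the real forward() takes more arguments.
        losses = model(x=batch["x"], x_lens=batch["x_lens"], y=batch["y"],
                       model_prev=model_prev)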
@@ -451,6 +511,9 @@ class AsrModel(nn.Module):
                 encoder_out_lens=encoder_out_lens,
                 targets=targets,
                 target_lengths=y_lens,
+                encoder_out_prev=encoder_out_prev,
+                encoder_out_lens_prev=encoder_out_lens_prev,
+                model_prev=model_prev,
             )
             cr_loss = torch.empty(0)
         else:
In scaling.py:

@@ -1140,16 +1140,24 @@ def with_loss(x, y, name):
 
 class ScaleGradFunction(torch.autograd.Function):
     @staticmethod
-    def forward(ctx, x: Tensor, alpha: float) -> Tensor:
-        ctx.alpha = alpha
+    def forward(ctx, x: Tensor, alpha: Union[float, Tensor]) -> Tensor:
+        if isinstance(alpha, Tensor):
+            ctx.save_for_backward(alpha)
+        else:
+            ctx.alpha = alpha
         return x
 
     @staticmethod
     def backward(ctx, grad: Tensor):
-        return grad * ctx.alpha, None
+        if hasattr(ctx, "alpha"):
+            alpha = ctx.alpha
+        else:
+            (alpha,) = ctx.saved_tensors
+
+        return grad * alpha, None
 
 
-def scale_grad(x: Tensor, alpha: float):
+def scale_grad(x: Tensor, alpha: Union[float, Tensor]):
     return ScaleGradFunction.apply(x, alpha)
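With this change, alpha may be a per-element Tensor as well as a float: tensor alphas are stashed via ctx.save_for_backward, float alphas as a plain attribute, and backward picks whichever was set. A quick sanity check of both paths (a sketch, assuming the patched scaling.py is importable):

    import torch
    from scaling import scale_grad

    x = torch.ones(3, requires_grad=True)

    # Float path: gradient uniformly scaled.
    scale_grad(x, 0.1).sum().backward()
    print(x.grad)  # tensor([0.1000, 0.1000, 0.1000])

    x.grad = None

    # Tensor path: gradient scaled elementwise.
    alpha = torch.tensor([1.0, 0.5, 0.0])
    scale_grad(x, alpha).sum().backward()
    print(x.grad)  # tensor([1.0000, 0.5000, 0.0000])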
egs/librispeech/ASR/zipformer/train-limit-grad.py (new executable file, 1640 lines): diff suppressed because it is too large.