mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 10:02:22 +00:00
Merge branch 'k2-fsa:master' into dev/zipformer_lstm
This commit is contained in:
commit
359ffce6c9
@ -758,7 +758,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, _ = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -766,6 +766,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss = losses[:2]
|
||||
|
||||
s = params.simple_loss_scale
|
||||
# take down the scale on the simple loss from 1.0 at the start
|
||||
|
@ -343,7 +343,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, _ = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -351,6 +351,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss = losses[:2]
|
||||
|
||||
s = params.simple_loss_scale
|
||||
# take down the scale on the simple loss from 1.0 at the start
|
||||
|
@ -814,7 +814,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -822,6 +822,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -449,7 +449,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -457,6 +457,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -803,7 +803,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -811,6 +811,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -90,7 +90,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
--encoder-dim 128,128,128,128,128,128 \
|
||||
--encoder-unmasked-dim 128,128,128,128,128,128
|
||||
|
||||
python ./zipformer/export_onnx_streaming.py \
|
||||
python ./zipformer/export-onnx-streaming.py \
|
||||
--exp-dir zipformer/exp \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 12 \
|
||||
@ -184,7 +184,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
--encoder-dim 128,128,128,128,128,128 \
|
||||
--encoder-unmasked-dim 128,128,128,128,128,128
|
||||
|
||||
python ./zipformer/export_onnx_streaming.py \
|
||||
python ./zipformer/export-onnx-streaming.py \
|
||||
--exp-dir zipformer/exp_finetune \
|
||||
--tokens data/lang_bpe_500/tokens.txt \
|
||||
--epoch 10 \
|
||||
|
@ -806,7 +806,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -814,6 +814,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -787,7 +787,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -795,6 +795,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -55,7 +55,6 @@ It supports training with:
|
||||
import argparse
|
||||
import copy
|
||||
import logging
|
||||
import random
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
@ -804,7 +803,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -812,6 +811,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -893,7 +893,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -901,6 +901,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -636,8 +636,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
|
||||
"""
|
||||
Forward function. Args:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
x: a Tensor of shape (batch_size, channels, seq_len)
|
||||
chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
|
||||
"""
|
||||
@ -1032,7 +1033,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
|
||||
w.prob = w.max_prob
|
||||
metric.backward()
|
||||
penalty_grad = x_detached.grad
|
||||
scale = w.grad_scale * (
|
||||
scale = float(w.grad_scale) * (
|
||||
x_grad.to(torch.float32).norm()
|
||||
/ (penalty_grad.norm() + 1.0e-20)
|
||||
)
|
||||
@ -1074,7 +1075,7 @@ class Whiten(nn.Module):
|
||||
super(Whiten, self).__init__()
|
||||
assert num_groups >= 1
|
||||
assert float(whitening_limit) >= 1
|
||||
assert grad_scale >= 0
|
||||
assert float(grad_scale) >= 0
|
||||
self.num_groups = num_groups
|
||||
self.whitening_limit = whitening_limit
|
||||
self.grad_scale = grad_scale
|
||||
|
@ -406,7 +406,7 @@ def get_parser():
|
||||
"--context-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
|
||||
help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -429,7 +429,7 @@ def get_parser():
|
||||
"--am-scale",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="The scale to smooth the loss with am (output of encoder network)" "part.",
|
||||
help="The scale to smooth the loss with am (output of encoder network) part.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@ -848,7 +848,7 @@ def compute_loss(
|
||||
True for training. False for validation. When it is True, this
|
||||
function enables autograd during computation; when it is False, it
|
||||
disables autograd.
|
||||
warmup: a floating point value which increases throughout training;
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
device = model.device if isinstance(model, DDP) else next(model.parameters()).device
|
||||
|
@ -890,7 +890,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -898,6 +898,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -903,7 +903,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -911,6 +911,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -1137,7 +1137,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
|
||||
w.prob = w.max_prob
|
||||
metric.backward()
|
||||
penalty_grad = x_detached.grad
|
||||
scale = w.grad_scale * (
|
||||
scale = float(w.grad_scale) * (
|
||||
x_grad.to(torch.float32).norm()
|
||||
/ (penalty_grad.norm() + 1.0e-20)
|
||||
)
|
||||
@ -1179,7 +1179,7 @@ class Whiten(nn.Module):
|
||||
super(Whiten, self).__init__()
|
||||
assert num_groups >= 1
|
||||
assert float(whitening_limit) >= 1
|
||||
assert grad_scale >= 0
|
||||
assert float(grad_scale) >= 0
|
||||
self.num_groups = num_groups
|
||||
self.whitening_limit = whitening_limit
|
||||
self.grad_scale = grad_scale
|
||||
|
@ -792,7 +792,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -800,6 +800,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -754,7 +754,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, _ = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -762,6 +762,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss = losses[:2]
|
||||
|
||||
s = params.simple_loss_scale
|
||||
# take down the scale on the simple loss from 1.0 at the start
|
||||
|
@ -832,7 +832,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -840,6 +840,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -814,7 +814,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -822,6 +822,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -59,7 +59,6 @@ from typing import Any, Dict, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
import sentencepiece as spm
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
@ -791,7 +790,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -799,6 +798,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -67,7 +67,6 @@ import torch.nn as nn
|
||||
from asr_datamodule import SPGISpeechAsrDataModule
|
||||
from decoder import Decoder
|
||||
from joiner import Joiner
|
||||
from lhotse.cut import Cut
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from model import AsrModel
|
||||
@ -792,7 +791,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -800,6 +799,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
|
@ -758,7 +758,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, _ = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -766,6 +766,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss = losses[:2]
|
||||
|
||||
s = params.simple_loss_scale
|
||||
# take down the scale on the simple loss from 1.0 at the start
|
||||
|
@ -91,7 +91,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
--encoder-dim 128,128,128,128,128,128 \
|
||||
--encoder-unmasked-dim 128,128,128,128,128,128
|
||||
|
||||
python ./zipformer/export_onnx_streaming.py \
|
||||
python ./zipformer/export-onnx-streaming.py \
|
||||
--exp-dir zipformer/exp \
|
||||
--tokens data/lang_partial_tone/tokens.txt \
|
||||
--epoch 18 \
|
||||
@ -187,7 +187,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
--encoder-dim 128,128,128,128,128,128 \
|
||||
--encoder-unmasked-dim 128,128,128,128,128,128
|
||||
|
||||
python ./zipformer/export_onnx_streaming.py \
|
||||
python ./zipformer/export-onnx-streaming.py \
|
||||
--exp-dir zipformer/exp_finetune \
|
||||
--tokens data/lang_partial_tone/tokens.txt \
|
||||
--epoch 10 \
|
||||
|
@ -70,8 +70,7 @@ import copy
|
||||
import logging
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import k2
|
||||
import optim
|
||||
@ -80,7 +79,6 @@ import torch.multiprocessing as mp
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import WenetSpeechAsrDataModule
|
||||
from lhotse.cut import Cut, CutSet
|
||||
from lhotse.dataset.sampling.base import CutSampler
|
||||
from lhotse.utils import fix_random_seed
|
||||
from optim import Eden, ScaledAdam
|
||||
from torch import Tensor
|
||||
@ -103,14 +101,13 @@ from train import (
|
||||
|
||||
from icefall import diagnostics
|
||||
from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import load_checkpoint, remove_checkpoints
|
||||
from icefall.checkpoint import remove_checkpoints
|
||||
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
|
||||
from icefall.checkpoint import (
|
||||
save_checkpoint_with_global_batch_idx,
|
||||
update_averaged_model,
|
||||
)
|
||||
from icefall.dist import cleanup_dist, setup_dist
|
||||
from icefall.env import get_env_info
|
||||
from icefall.err import raise_grad_scale_is_too_small_error
|
||||
from icefall.hooks import register_inf_check_hooks
|
||||
from icefall.lexicon import Lexicon
|
||||
@ -296,7 +293,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, ctc_loss = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -304,6 +301,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss, ctc_loss = losses[:3]
|
||||
|
||||
loss = 0.0
|
||||
|
||||
@ -344,40 +342,6 @@ def compute_loss(
|
||||
return loss, info
|
||||
|
||||
|
||||
def compute_validation_loss(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||
valid_dl: torch.utils.data.DataLoader,
|
||||
world_size: int = 1,
|
||||
) -> MetricsTracker:
|
||||
"""Run the validation process."""
|
||||
model.eval()
|
||||
|
||||
tot_loss = MetricsTracker()
|
||||
|
||||
for batch_idx, batch in enumerate(valid_dl):
|
||||
loss, loss_info = compute_loss(
|
||||
params=params,
|
||||
model=model,
|
||||
graph_compiler=graph_compiler,
|
||||
batch=batch,
|
||||
is_training=False,
|
||||
)
|
||||
assert loss.requires_grad is False
|
||||
tot_loss = tot_loss + loss_info
|
||||
|
||||
if world_size > 1:
|
||||
tot_loss.reduce(loss.device)
|
||||
|
||||
loss_value = tot_loss["loss"] / tot_loss["frames"]
|
||||
if loss_value < params.best_valid_loss:
|
||||
params.best_valid_epoch = params.cur_epoch
|
||||
params.best_valid_loss = loss_value
|
||||
|
||||
return tot_loss
|
||||
|
||||
|
||||
def train_one_epoch(
|
||||
params: AttributeDict,
|
||||
model: Union[nn.Module, DDP],
|
||||
|
@ -815,7 +815,7 @@ def compute_loss(
|
||||
y = k2.RaggedTensor(y).to(device)
|
||||
|
||||
with torch.set_grad_enabled(is_training):
|
||||
simple_loss, pruned_loss, _ = model(
|
||||
losses = model(
|
||||
x=feature,
|
||||
x_lens=feature_lens,
|
||||
y=y,
|
||||
@ -823,6 +823,7 @@ def compute_loss(
|
||||
am_scale=params.am_scale,
|
||||
lm_scale=params.lm_scale,
|
||||
)
|
||||
simple_loss, pruned_loss = losses[:2]
|
||||
|
||||
s = params.simple_loss_scale
|
||||
# take down the scale on the simple loss from 1.0 at the start
|
||||
|
Loading…
x
Reference in New Issue
Block a user