Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-10 10:32:17 +00:00)

Commit 359ffce6c9: Merge branch 'k2-fsa:master' into dev/zipformer_lstm
@@ -758,7 +758,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -766,6 +766,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]

         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
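Across the recipes touched by this merge, the change follows one pattern: the model's forward call is no longer unpacked directly into named losses; its return value is bound to `losses` and the needed components are sliced off afterwards (`losses[:2]` in transducer-only recipes, `losses[:3]` where a CTC head is present). A minimal sketch of the new call-site pattern, using a hypothetical stand-in rather than the recipes' actual AsrModel forward (which may return additional elements):

# Sketch of the unpacking pattern; `toy_model` is a stand-in, not icefall's API.
import torch


def toy_model(x: torch.Tensor) -> tuple:
    # Pretend these are the transducer "simple", "pruned" and CTC losses.
    simple_loss = 0.1 * x.sum()
    pruned_loss = 0.9 * x.sum()
    ctc_loss = 0.05 * x.abs().sum()
    return simple_loss, pruned_loss, ctc_loss


losses = toy_model(torch.randn(4, 10))
simple_loss, pruned_loss = losses[:2]            # transducer-only recipes
simple_loss, pruned_loss, ctc_loss = losses[:3]  # recipes with a CTC head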
@@ -343,7 +343,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -351,6 +351,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]

         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
@@ -814,7 +814,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -822,6 +822,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -449,7 +449,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -457,6 +457,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -803,7 +803,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -811,6 +811,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -90,7 +90,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
     --encoder-dim 128,128,128,128,128,128 \
     --encoder-unmasked-dim 128,128,128,128,128,128

-  python ./zipformer/export_onnx_streaming.py \
+  python ./zipformer/export-onnx-streaming.py \
     --exp-dir zipformer/exp \
     --tokens data/lang_bpe_500/tokens.txt \
     --epoch 12 \
@@ -184,7 +184,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     --encoder-dim 128,128,128,128,128,128 \
     --encoder-unmasked-dim 128,128,128,128,128,128

-  python ./zipformer/export_onnx_streaming.py \
+  python ./zipformer/export-onnx-streaming.py \
     --exp-dir zipformer/exp_finetune \
     --tokens data/lang_bpe_500/tokens.txt \
     --epoch 10 \
@@ -806,7 +806,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -814,6 +814,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -787,7 +787,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -795,6 +795,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -55,7 +55,6 @@ It supports training with:
 import argparse
 import copy
 import logging
-import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
@@ -804,7 +803,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -812,6 +811,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -893,7 +893,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -901,6 +901,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -636,8 +636,9 @@ class ChunkCausalDepthwiseConv1d(torch.nn.Module):
         )

     def forward(self, x: Tensor, chunk_size: int = -1) -> Tensor:
-        """
-        Forward function. Args:
+        """Forward function.
+
+        Args:
           x: a Tensor of shape (batch_size, channels, seq_len)
           chunk_size: the chunk size, in frames; does not have to divide seq_len exactly.
         """
@@ -1032,7 +1033,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                         w.prob = w.max_prob
                         metric.backward()
                         penalty_grad = x_detached.grad
-                        scale = w.grad_scale * (
+                        scale = float(w.grad_scale) * (
                             x_grad.to(torch.float32).norm()
                             / (penalty_grad.norm() + 1.0e-20)
                         )
@@ -1074,7 +1075,7 @@ class Whiten(nn.Module):
         super(Whiten, self).__init__()
         assert num_groups >= 1
         assert float(whitening_limit) >= 1
-        assert grad_scale >= 0
+        assert float(grad_scale) >= 0
         self.num_groups = num_groups
         self.whitening_limit = whitening_limit
         self.grad_scale = grad_scale
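The `float(...)` casts added around `grad_scale` here (and in the matching hunks later in the diff) suggest that `grad_scale` is no longer guaranteed to be a plain Python float, for example a scheduled value that changes during training. A small self-contained sketch of why the explicit cast matters; `ScheduledValue` is a hypothetical stand-in, not icefall's API:

# Hedged sketch: a value object that can be read as a float but does not
# itself support comparison or arithmetic.
class ScheduledValue:
    def __init__(self, value: float) -> None:
        self._value = value

    def __float__(self) -> float:
        return self._value


grad_scale = ScheduledValue(0.01)
assert float(grad_scale) >= 0      # the bare object defines no __ge__, so compare its float value
scale = float(grad_scale) * 2.0    # likewise, cast before arithmetic
print(scale)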
@@ -406,7 +406,7 @@ def get_parser():
         "--context-size",
         type=int,
         default=2,
-        help="The context size in the decoder. 1 means bigram; " "2 means tri-gram",
+        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
     )

     parser.add_argument(
@@ -429,7 +429,7 @@ def get_parser():
         "--am-scale",
         type=float,
         default=0.0,
-        help="The scale to smooth the loss with am (output of encoder network)" "part.",
+        help="The scale to smooth the loss with am (output of encoder network) part.",
     )

     parser.add_argument(
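The two help-string hunks above remove implicit concatenation of adjacent string literals. That pattern is easy to get wrong: the literals are joined with no separator, which is how the old `--am-scale` help text ended up reading "network)part." with the space missing. A quick illustration:

# Adjacent string literals concatenate with no separator.
old = "The scale to smooth the loss with am (output of encoder network)" "part."
new = "The scale to smooth the loss with am (output of encoder network) part."
print(old)  # ...encoder network)part.   <- missing space
print(new)  # ...encoder network) part.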
@@ -848,7 +848,7 @@ def compute_loss(
         True for training. False for validation. When it is True, this
         function enables autograd during computation; when it is False, it
         disables autograd.
       warmup: a floating point value which increases throughout training;
         values >= 1.0 are fully warmed up and have all modules present.
     """
     device = model.device if isinstance(model, DDP) else next(model.parameters()).device
@@ -890,7 +890,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -898,6 +898,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -903,7 +903,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -911,6 +911,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -1137,7 +1137,7 @@ class WhiteningPenaltyFunction(torch.autograd.Function):
                         w.prob = w.max_prob
                         metric.backward()
                         penalty_grad = x_detached.grad
-                        scale = w.grad_scale * (
+                        scale = float(w.grad_scale) * (
                             x_grad.to(torch.float32).norm()
                             / (penalty_grad.norm() + 1.0e-20)
                         )
@@ -1179,7 +1179,7 @@ class Whiten(nn.Module):
         super(Whiten, self).__init__()
         assert num_groups >= 1
         assert float(whitening_limit) >= 1
-        assert grad_scale >= 0
+        assert float(grad_scale) >= 0
         self.num_groups = num_groups
         self.whitening_limit = whitening_limit
         self.grad_scale = grad_scale
@@ -792,7 +792,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -800,6 +800,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -754,7 +754,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -762,6 +762,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]

         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
@@ -832,7 +832,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -840,6 +840,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -814,7 +814,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -822,6 +822,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -59,7 +59,6 @@ from typing import Any, Dict, Optional, Tuple, Union

 import k2
 import optim
-import sentencepiece as spm
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
@@ -791,7 +790,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -799,6 +798,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -67,7 +67,6 @@ import torch.nn as nn
 from asr_datamodule import SPGISpeechAsrDataModule
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from model import AsrModel
@@ -792,7 +791,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -800,6 +799,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -758,7 +758,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -766,6 +766,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]

         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start
@@ -91,7 +91,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
     --encoder-dim 128,128,128,128,128,128 \
     --encoder-unmasked-dim 128,128,128,128,128,128

-  python ./zipformer/export_onnx_streaming.py \
+  python ./zipformer/export-onnx-streaming.py \
     --exp-dir zipformer/exp \
     --tokens data/lang_partial_tone/tokens.txt \
     --epoch 18 \
@@ -187,7 +187,7 @@ if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
     --encoder-dim 128,128,128,128,128,128 \
     --encoder-unmasked-dim 128,128,128,128,128,128

-  python ./zipformer/export_onnx_streaming.py \
+  python ./zipformer/export-onnx-streaming.py \
     --exp-dir zipformer/exp_finetune \
     --tokens data/lang_partial_tone/tokens.txt \
     --epoch 10 \
@@ -70,8 +70,7 @@ import copy
 import logging
 import warnings
 from pathlib import Path
-from shutil import copyfile
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import k2
 import optim
@@ -80,7 +79,6 @@ import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import WenetSpeechAsrDataModule
 from lhotse.cut import Cut, CutSet
-from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from optim import Eden, ScaledAdam
 from torch import Tensor
@@ -103,14 +101,13 @@ from train import (

 from icefall import diagnostics
 from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler
-from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.checkpoint import (
     save_checkpoint_with_global_batch_idx,
     update_averaged_model,
 )
 from icefall.dist import cleanup_dist, setup_dist
-from icefall.env import get_env_info
 from icefall.err import raise_grad_scale_is_too_small_error
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
@@ -296,7 +293,7 @@ def compute_loss(
     y = k2.RaggedTensor(y)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, ctc_loss = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -304,6 +301,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss, ctc_loss = losses[:3]

         loss = 0.0

@@ -344,40 +342,6 @@ def compute_loss(
     return loss, info


-def compute_validation_loss(
-    params: AttributeDict,
-    model: Union[nn.Module, DDP],
-    graph_compiler: CharCtcTrainingGraphCompiler,
-    valid_dl: torch.utils.data.DataLoader,
-    world_size: int = 1,
-) -> MetricsTracker:
-    """Run the validation process."""
-    model.eval()
-
-    tot_loss = MetricsTracker()
-
-    for batch_idx, batch in enumerate(valid_dl):
-        loss, loss_info = compute_loss(
-            params=params,
-            model=model,
-            graph_compiler=graph_compiler,
-            batch=batch,
-            is_training=False,
-        )
-        assert loss.requires_grad is False
-        tot_loss = tot_loss + loss_info
-
-    if world_size > 1:
-        tot_loss.reduce(loss.device)
-
-    loss_value = tot_loss["loss"] / tot_loss["frames"]
-    if loss_value < params.best_valid_loss:
-        params.best_valid_epoch = params.cur_epoch
-        params.best_valid_loss = loss_value
-
-    return tot_loss
-
-
 def train_one_epoch(
     params: AttributeDict,
     model: Union[nn.Module, DDP],
@@ -815,7 +815,7 @@ def compute_loss(
     y = k2.RaggedTensor(y).to(device)

     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss, _ = model(
+        losses = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -823,6 +823,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
         )
+        simple_loss, pruned_loss = losses[:2]

         s = params.simple_loss_scale
         # take down the scale on the simple loss from 1.0 at the start