Add cr-ctc loss and ctc-decode in aishell (#1980)

This commit is contained in:
Mistmoon 2025-07-08 14:47:24 +08:00 committed by GitHub
parent fba5e67d5e
commit 9293edc62f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 738 additions and 19 deletions

View File

@ -1,5 +1,63 @@
## Results
### Aishell training results (zipformer + CR-CTC)
See <https://github.com/k2-fsa/icefall/pull/1980> for more details.
[zipformer](./zipformer)
#### Non-streaming
##### medium-scale model, number of model parameters: 66218471, i.e., 66.2 M
| decoding method | test | dev | comment |
|--------------------------------------|------------|------------|---------------------|
| ctc-greedy-search | 3.98 | 3.69 | --epoch 60 --avg 28 |
| ctc-prefix-beam-search | 3.98 | 3.70 | --epoch 60 --avg 21 |
The training command using 2 32G-V100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
./zipformer/train.py \
--world-size 2 \
--num-epochs 60 \
--start-epoch 1 \
--use-fp16 1 \
--context-size 1 \
--enable-musan 0 \
--exp-dir zipformer/exp \
--max-duration 500 \
--base-lr 0.045 \
--lr-batches 7500 \
--lr-epochs 18 \
--spec-aug-time-warp-factor 20 \
--use-ctc 1 \
--use-cr-ctc 1 \
--use-transducer 0 \
--enable-spec-aug 0 \
--cr-loss-scale 0.2
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search ctc-prefix-beam-search; do
./zipformer/ctc_decode.py \
--epoch 60 \
--avg 28 \
--exp-dir zipformer/exp \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--max-duration 600 \
--decoding-method $m
done
```
Pretrained models, training logs, decoding logs, tensorboard and decoding results
are available at
<https://huggingface.co/MistMoon/icefall-asr-aishell-zipformer-medium-cr-ctc-20250702>
### Aishell training results (Fine-tuning Pretrained Models)
#### Whisper
[./whisper](./whisper)

View File

@ -0,0 +1,540 @@
#!/usr/bin/env python3
#
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
# Liyong Guo,
# Quandong Wang,
# Zengwei Yao,
# Zhifeng Han,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) ctc-greedy-search (with cr-ctc)
./zipformer/ctc_decode.py \
--epoch 60 \
--avg 28 \
--exp-dir ./zipformer/exp \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--max-duration 600 \
--decoding-method ctc-greedy-search
(2) ctc-prefix-beam-search (with cr-ctc)
./zipformer/ctc_decode.py \
--epoch 60 \
--avg 21 \
--exp-dir zipformer/exp \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--max-duration 600 \
--decoding-method ctc-prefix-beam-search
"""
import argparse
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import AishellAsrDataModule
from lhotse.cut import Cut
from train import add_model_arguments, get_model, get_params
from icefall.checkpoint import (
average_checkpoints,
average_checkpoints_with_averaged_model,
find_checkpoints,
load_checkpoint,
)
from icefall.decode import (
ctc_greedy_search,
ctc_prefix_beam_search,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
LOG_EPS = math.log(1e-10)
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--epoch",
type=int,
default=30,
help="""It specifies the checkpoint to use for decoding.
Note: Epoch counts from 1.
You can specify --avg to use more checkpoints for model averaging.""",
)
parser.add_argument(
"--iter",
type=int,
default=0,
help="""If positive, --epoch is ignored and it
will use the checkpoint exp_dir/checkpoint-iter.pt.
You can specify --avg to use more checkpoints for model averaging.
""",
)
parser.add_argument(
"--avg",
type=int,
default=15,
help="Number of checkpoints to average. Automatically select "
"consecutive checkpoints before the checkpoint specified by "
"'--epoch' and '--iter'",
)
parser.add_argument(
"--use-averaged-model",
type=str2bool,
default=True,
help="Whether to load averaged model. Currently it only supports "
"using --epoch. If True, it would decode with the averaged model "
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
"Actually only the models with epoch number of `epoch-avg` and "
"`epoch` are loaded for averaging. ",
)
parser.add_argument(
"--exp-dir",
type=str,
default="zipformer/exp",
help="The experiment dir",
)
parser.add_argument(
"--lang-dir",
type=Path,
default="data/lang_char",
help="The lang dir containing word table and LG graph",
)
parser.add_argument(
"--decoding-method",
type=str,
default="ctc-greedy-search",
help="""Decoding method.
Supported values are:
- (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
It needs neither a lexicon nor an n-gram LM.
(2) ctc-prefix-beam-search. Extract n paths with the given beam, the best
path of the n paths is the decoding result.
""",
)
add_model_arguments(parser)
return parser
def get_decoding_params() -> AttributeDict:
"""Parameters for decoding."""
params = AttributeDict(
{
"beam": 4, # for prefix-beam-search
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
batch: dict,
) -> Dict[str, Tuple[List[List[str]], List[List[Tuple[float, float]]]]]:
"""Decode one batch and return the result in a dict. The dict has the
following format:
- key: It indicates the setting used for decoding. For example,
if greedy_search is used, it would be "greedy_search"
If beam search with a beam size of 7 is used, it would be
"beam_7"
- value: It contains the decoding result. `len(value)` equals to
batch size. `value[i]` is the decoding result for the i-th
utterance in the given batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
Returns:
Return the decoding result. See above description for the format of
the returned dict.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
if params.causal:
# this seems to cause insertions at the end of the utterance if used with zipformer.
pad_len = 30
feature_lens += pad_len
feature = torch.nn.functional.pad(
feature,
pad=(0, 0, 0, pad_len),
value=LOG_EPS,
)
x, x_lens = model.encoder_embed(feature, feature_lens)
src_key_padding_mask = make_pad_mask(x_lens)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
ctc_output = model.ctc_output(encoder_out) # (N, T, C)
hyp_tokens = []
hyps = []
if params.decoding_method == "ctc-greedy-search":
hyp_tokens = ctc_greedy_search(
ctc_output=ctc_output,
encoder_out_lens=encoder_out_lens,
)
elif params.decoding_method == "ctc-prefix-beam-search":
hyp_tokens = ctc_prefix_beam_search(
ctc_output=ctc_output,
encoder_out_lens=encoder_out_lens,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
if params.decoding_method == "ctc-greedy-search":
return {"ctc-greedy-search" : hyps}
elif params.decoding_method == "ctc-prefix-beam-search":
return {"ctc-prefix-beam-search" : hyps}
else:
assert False, f"Unsupported decoding method: {params.decoding_method}"
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
lexicon: Lexicon,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
"""Decode dataset.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
Returns:
Return a dict, whose key may be "greedy_search" if greedy search
is used, or it may be "beam_7" if beam size of 7 is used.
Its value is a list of tuples. Each tuple contains 3 elements:
Respectively, they are cut_id, the reference transcript, and the predicted result.
"""
num_cuts = 0
try:
num_batches = len(dl)
except TypeError:
num_batches = "?"
log_interval = 20
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list("".join(text.split())) for text in texts]
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
hyps_dict = decode_one_batch(
params=params,
model=model,
lexicon=lexicon,
batch=batch,
)
for name, hyps in hyps_dict.items():
this_batch = []
assert len(hyps) == len(texts)
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
this_batch.append((cut_id, ref_text, hyp_words))
results[name].extend(this_batch)
num_cuts += len(texts)
if batch_idx % log_interval == 0:
batch_str = f"{batch_idx}/{num_batches}"
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
return results
def save_results(
params: AttributeDict,
test_set_name: str,
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
):
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = (
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results, char_level = True)
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = (
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
compute_CER=True,
)
test_set_wers[key] = wer
logging.info("Wrote detailed error stats to {}".format(errs_filename))
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = (
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_wers:
s += "{}\t{}{}\n".format(key, val, note)
note = ""
logging.info(s)
@torch.no_grad()
def main():
parser = get_parser()
AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
args.lang_dir = Path(args.lang_dir)
params = get_params()
# add decoding params
params.update(get_decoding_params())
params.update(vars(args))
assert params.decoding_method in (
"ctc-greedy-search",
"ctc-prefix-beam-search",
) # support ctc-greedy-search and ctc-prefix-beam-search
params.res_dir = params.exp_dir / params.decoding_method
if params.iter > 0:
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
else:
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
if params.causal:
assert (
"," not in params.chunk_size
), "chunk_size should be one value in decoding."
assert (
"," not in params.left_context_frames
), "left_context_frames should be one value in decoding."
params.suffix += f"-chunk-{params.chunk_size}"
params.suffix += f"-left-context-{params.left_context_frames}"
if "prefix-beam-search" in params.decoding_method:
params.suffix += f"_beam-{params.beam}"
if params.use_averaged_model:
params.suffix += "-use-averaged-model"
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
logging.info("Decoding started")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
params.device = device
logging.info(f"Device: {device}")
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(params)
logging.info("About to create model")
model = get_model(params)
if not params.use_averaged_model:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
elif params.avg == 1:
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
else:
start = params.epoch - params.avg + 1
filenames = []
for i in range(start, params.epoch + 1):
if i >= 1:
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
logging.info(f"averaging {filenames}")
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
else:
if params.iter > 0:
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
: params.avg + 1
]
if len(filenames) == 0:
raise ValueError(
f"No checkpoints found for"
f" --iter {params.iter}, --avg {params.avg}"
)
elif len(filenames) < params.avg + 1:
raise ValueError(
f"Not enough checkpoints ({len(filenames)}) found for"
f" --iter {params.iter}, --avg {params.avg}"
)
filename_start = filenames[-1]
filename_end = filenames[0]
logging.info(
"Calculating the averaged model over iteration checkpoints"
f" from {filename_start} (excluded) to {filename_end}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
else:
assert params.avg > 0, params.avg
start = params.epoch - params.avg
assert start >= 1, start
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
f"Calculating the averaged model over epoch range from "
f"{start} (excluded) to {params.epoch}"
)
model.to(device)
model.load_state_dict(
average_checkpoints_with_averaged_model(
filename_start=filename_start,
filename_end=filename_end,
device=device,
)
)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
# we need cut ids to display recognition results.
args.return_cuts = True
aishell = AishellAsrDataModule(args)
dev_cuts = aishell.valid_cuts()
dev_dl = aishell.valid_dataloaders(dev_cuts)
test_cuts = aishell.test_cuts()
test_dl = aishell.test_dataloaders(test_cuts)
test_sets = ["dev", "test"]
test_dls = [dev_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dls):
results_dict = decode_dataset(
dl=test_dl,
params=params,
model=model,
lexicon=lexicon,
)
save_results(
params=params,
test_set_name=test_set,
results_dict=results_dict,
)
logging.info("Done!")
if __name__ == "__main__":
main()

View File

@ -64,6 +64,7 @@ from asr_datamodule import AishellAsrDataModule
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset import SpecAugment
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import AsrModel
@ -240,6 +241,27 @@ def add_model_arguments(parser: argparse.ArgumentParser):
chunk left-context frames will be chosen randomly from this list; else not relevant.""",
)
parser.add_argument(
"--use-transducer",
type=str2bool,
default=True,
help="If True, use Transducer head.",
)
parser.add_argument(
"--use-ctc",
type=str2bool,
default=False,
help="If True, use CTC head.",
)
parser.add_argument(
"--use-cr-ctc",
type=str2bool,
default=False,
help="If True, use consistency-regularized CTC.",
)
def get_parser():
parser = argparse.ArgumentParser(
@ -380,6 +402,27 @@ def get_parser():
with this parameter before adding to the final loss.""",
)
parser.add_argument(
"--ctc-loss-scale",
type=float,
default=0.2,
help="Scale for CTC loss.",
)
parser.add_argument(
"--cr-loss-scale",
type=float,
default=0.2,
help="Scale for consistency-regularization loss.",
)
parser.add_argument(
"--time-mask-ratio",
type=float,
default=2.5,
help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
)
parser.add_argument(
"--seed",
type=int,
@ -583,8 +626,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_model(params: AttributeDict) -> nn.Module:
encoder_embed = get_encoder_embed(params)
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
if params.use_transducer:
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)
else:
decoder = None
joiner = None
model = AsrModel(
encoder_embed=encoder_embed,
@ -594,9 +642,27 @@ def get_model(params: AttributeDict) -> nn.Module:
encoder_dim=int(max(params.encoder_dim.split(","))),
decoder_dim=params.decoder_dim,
vocab_size=params.vocab_size,
use_transducer=params.use_transducer,
use_ctc=params.use_ctc,
)
return model
def get_spec_augment(params: AttributeDict) -> SpecAugment:
num_frame_masks = int(10 * params.time_mask_ratio)
max_frames_mask_fraction = 0.15 * params.time_mask_ratio
logging.info(
f"num_frame_masks: {num_frame_masks}, "
f"max_frames_mask_fraction: {max_frames_mask_fraction}"
)
spec_augment = SpecAugment(
time_warp_factor=0, # Do time warping in model.py
num_frame_masks=num_frame_masks, # default: 10
features_mask_size=27,
num_feature_masks=2,
frames_mask_size=100,
max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15
)
return spec_augment
def load_checkpoint_if_available(
params: AttributeDict,
@ -723,6 +789,7 @@ def compute_loss(
graph_compiler: CharCtcTrainingGraphCompiler,
batch: dict,
is_training: bool,
spec_augment: Optional[SpecAugment] = None,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute CTC loss given the model and its inputs.
@ -739,6 +806,8 @@ def compute_loss(
True for training. False for validation. When it is True, this
function enables autograd during computation; when it is False, it
disables autograd.
spec_augment:
The SpecAugment instance used only when use_cr_ctc is True.
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
@ -758,6 +827,21 @@ def compute_loss(
y = graph_compiler.texts_to_ids(texts)
y = k2.RaggedTensor(y).to(device)
use_cr_ctc = params.use_cr_ctc
use_spec_aug = use_cr_ctc and is_training
if use_spec_aug:
supervision_intervals = batch["supervisions"]
supervision_segments = torch.stack(
[
supervision_intervals["sequence_idx"],
supervision_intervals["start_frame"],
supervision_intervals["num_frames"],
],
dim=1,
) # shape: (S, 3)
else:
supervision_segments = None
with torch.set_grad_enabled(is_training):
losses = model(
x=feature,
@ -766,25 +850,40 @@ def compute_loss(
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
use_cr_ctc=use_cr_ctc,
use_spec_aug=use_spec_aug,
spec_augment=spec_augment,
supervision_segments=supervision_segments,
time_warp_factor=params.spec_aug_time_warp_factor,
)
simple_loss, pruned_loss = losses[:2]
if params.use_ctc:
simple_loss, pruned_loss, ctc_loss, _, cr_loss = losses[:5]
else:
simple_loss, pruned_loss = losses[:2]
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss = 0.0
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_transducer:
s = params.simple_loss_scale
# take down the scale on the simple loss from 1.0 at the start
# to params.simple_loss scale by warm_step.
simple_loss_scale = (
s
if batch_idx_train >= warm_step
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
)
pruned_loss_scale = (
1.0
if batch_idx_train >= warm_step
else 0.1 + 0.9 * (batch_idx_train / warm_step)
)
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
if params.use_ctc:
loss += params.ctc_loss_scale * ctc_loss
if use_cr_ctc:
loss += params.cr_loss_scale * cr_loss
assert loss.requires_grad == is_training
info = MetricsTracker()
@ -794,8 +893,13 @@ def compute_loss(
# Note: We use reduction=sum while computing the loss.
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_transducer:
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item()
if params.use_cr_ctc:
info["cr_loss"] = cr_loss.detach().cpu().item()
return loss, info
@ -843,6 +947,7 @@ def train_one_epoch(
train_dl: torch.utils.data.DataLoader,
valid_dl: torch.utils.data.DataLoader,
scaler: GradScaler,
spec_augment: Optional[SpecAugment] = None,
model_avg: Optional[nn.Module] = None,
tb_writer: Optional[SummaryWriter] = None,
world_size: int = 1,
@ -869,6 +974,8 @@ def train_one_epoch(
Dataloader for the validation dataset.
scaler:
The scaler used for mix precision training.
spec_augment:
The SpecAugment instance used only when use_cr_ctc is True.
model_avg:
The stored model averaged from the start of training.
tb_writer:
@ -918,6 +1025,7 @@ def train_one_epoch(
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
spec_augment=spec_augment,
)
# summary stats
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@ -1083,6 +1191,9 @@ def run(rank, world_size, args):
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
if not params.use_transducer:
params.ctc_loss_scale = 1.0
logging.info(params)
logging.info("About to create model")
@ -1091,6 +1202,12 @@ def run(rank, world_size, args):
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
if params.use_cr_ctc:
assert not params.enable_spec_aug # we will do spec_augment in model.py
spec_augment = get_spec_augment(params)
else:
spec_augment = None
assert params.save_every_n >= params.average_period
model_avg: Optional[nn.Module] = None
if rank == 0:
@ -1200,6 +1317,7 @@ def run(rank, world_size, args):
optimizer=optimizer,
graph_compiler=graph_compiler,
params=params,
spec_augment=spec_augment,
)
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
@ -1227,6 +1345,7 @@ def run(rank, world_size, args):
train_dl=train_dl,
valid_dl=valid_dl,
scaler=scaler,
spec_augment=spec_augment,
tb_writer=tb_writer,
world_size=world_size,
rank=rank,
@ -1293,6 +1412,7 @@ def scan_pessimistic_batches_for_oom(
optimizer: torch.optim.Optimizer,
graph_compiler: CharCtcTrainingGraphCompiler,
params: AttributeDict,
spec_augment: Optional[SpecAugment] = None,
):
from lhotse.dataset import find_pessimistic_batches
@ -1310,6 +1430,7 @@ def scan_pessimistic_batches_for_oom(
graph_compiler=graph_compiler,
batch=batch,
is_training=True,
spec_augment=spec_augment,
)
loss.backward()
optimizer.zero_grad()