mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
Add cr-ctc loss and ctc-decode in aishell (#1980)
This commit is contained in:
parent
fba5e67d5e
commit
9293edc62f
@ -1,5 +1,63 @@
|
|||||||
## Results
|
## Results
|
||||||
|
|
||||||
|
### Aishell training results (zipformer + CR-CTC)
|
||||||
|
|
||||||
|
See <https://github.com/k2-fsa/icefall/pull/1980> for more details.
|
||||||
|
|
||||||
|
[zipformer](./zipformer)
|
||||||
|
|
||||||
|
#### Non-streaming
|
||||||
|
|
||||||
|
##### medium-scale model, number of model parameters: 66218471, i.e., 66.2 M
|
||||||
|
|
||||||
|
| decoding method | test | dev | comment |
|
||||||
|
|--------------------------------------|------------|------------|---------------------|
|
||||||
|
| ctc-greedy-search | 3.98 | 3.69 | --epoch 60 --avg 28 |
|
||||||
|
| ctc-prefix-beam-search | 3.98 | 3.70 | --epoch 60 --avg 21 |
|
||||||
|
|
||||||
|
The training command using 2 32G-V100 GPUs is:
|
||||||
|
```bash
|
||||||
|
export CUDA_VISIBLE_DEVICES="0,1"
|
||||||
|
./zipformer/train.py \
|
||||||
|
--world-size 2 \
|
||||||
|
--num-epochs 60 \
|
||||||
|
--start-epoch 1 \
|
||||||
|
--use-fp16 1 \
|
||||||
|
--context-size 1 \
|
||||||
|
--enable-musan 0 \
|
||||||
|
--exp-dir zipformer/exp \
|
||||||
|
--max-duration 500 \
|
||||||
|
--base-lr 0.045 \
|
||||||
|
--lr-batches 7500 \
|
||||||
|
--lr-epochs 18 \
|
||||||
|
--spec-aug-time-warp-factor 20 \
|
||||||
|
--use-ctc 1 \
|
||||||
|
--use-cr-ctc 1 \
|
||||||
|
--use-transducer 0 \
|
||||||
|
--enable-spec-aug 0 \
|
||||||
|
--cr-loss-scale 0.2
|
||||||
|
```
|
||||||
|
|
||||||
|
The decoding command is:
|
||||||
|
```bash
|
||||||
|
export CUDA_VISIBLE_DEVICES="0"
|
||||||
|
for m in ctc-greedy-search ctc-prefix-beam-search; do
|
||||||
|
./zipformer/ctc_decode.py \
|
||||||
|
--epoch 60 \
|
||||||
|
--avg 28 \
|
||||||
|
--exp-dir zipformer/exp \
|
||||||
|
--use-cr-ctc 1 \
|
||||||
|
--use-ctc 1 \
|
||||||
|
--use-transducer 0 \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method $m
|
||||||
|
done
|
||||||
|
```
|
||||||
|
|
||||||
|
Pretrained models, training logs, decoding logs, tensorboard and decoding results
|
||||||
|
are available at
|
||||||
|
<https://huggingface.co/MistMoon/icefall-asr-aishell-zipformer-medium-cr-ctc-20250702>
|
||||||
|
|
||||||
### Aishell training results (Fine-tuning Pretrained Models)
|
### Aishell training results (Fine-tuning Pretrained Models)
|
||||||
#### Whisper
|
#### Whisper
|
||||||
[./whisper](./whisper)
|
[./whisper](./whisper)
|
||||||
|
540
egs/aishell/ASR/zipformer/ctc_decode.py
Executable file
540
egs/aishell/ASR/zipformer/ctc_decode.py
Executable file
@ -0,0 +1,540 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
#
|
||||||
|
# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
|
||||||
|
# Liyong Guo,
|
||||||
|
# Quandong Wang,
|
||||||
|
# Zengwei Yao,
|
||||||
|
# Zhifeng Han,)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
|
||||||
|
(1) ctc-greedy-search (with cr-ctc)
|
||||||
|
./zipformer/ctc_decode.py \
|
||||||
|
--epoch 60 \
|
||||||
|
--avg 28 \
|
||||||
|
--exp-dir ./zipformer/exp \
|
||||||
|
--use-cr-ctc 1 \
|
||||||
|
--use-ctc 1 \
|
||||||
|
--use-transducer 0 \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method ctc-greedy-search
|
||||||
|
(2) ctc-prefix-beam-search (with cr-ctc)
|
||||||
|
./zipformer/ctc_decode.py \
|
||||||
|
--epoch 60 \
|
||||||
|
--avg 21 \
|
||||||
|
--exp-dir zipformer/exp \
|
||||||
|
--use-cr-ctc 1 \
|
||||||
|
--use-ctc 1 \
|
||||||
|
--use-transducer 0 \
|
||||||
|
--max-duration 600 \
|
||||||
|
--decoding-method ctc-prefix-beam-search
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from asr_datamodule import AishellAsrDataModule
|
||||||
|
from lhotse.cut import Cut
|
||||||
|
from train import add_model_arguments, get_model, get_params
|
||||||
|
|
||||||
|
from icefall.checkpoint import (
|
||||||
|
average_checkpoints,
|
||||||
|
average_checkpoints_with_averaged_model,
|
||||||
|
find_checkpoints,
|
||||||
|
load_checkpoint,
|
||||||
|
)
|
||||||
|
from icefall.decode import (
|
||||||
|
ctc_greedy_search,
|
||||||
|
ctc_prefix_beam_search,
|
||||||
|
)
|
||||||
|
from icefall.lexicon import Lexicon
|
||||||
|
from icefall.utils import (
|
||||||
|
AttributeDict,
|
||||||
|
make_pad_mask,
|
||||||
|
setup_logger,
|
||||||
|
store_transcripts,
|
||||||
|
str2bool,
|
||||||
|
write_error_stats,
|
||||||
|
)
|
||||||
|
|
||||||
|
LOG_EPS = math.log(1e-10)
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--epoch",
|
||||||
|
type=int,
|
||||||
|
default=30,
|
||||||
|
help="""It specifies the checkpoint to use for decoding.
|
||||||
|
Note: Epoch counts from 1.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--iter",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
help="""If positive, --epoch is ignored and it
|
||||||
|
will use the checkpoint exp_dir/checkpoint-iter.pt.
|
||||||
|
You can specify --avg to use more checkpoints for model averaging.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--avg",
|
||||||
|
type=int,
|
||||||
|
default=15,
|
||||||
|
help="Number of checkpoints to average. Automatically select "
|
||||||
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
|
"'--epoch' and '--iter'",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-averaged-model",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="Whether to load averaged model. Currently it only supports "
|
||||||
|
"using --epoch. If True, it would decode with the averaged model "
|
||||||
|
"over the epoch range from `epoch-avg` (excluded) to `epoch`."
|
||||||
|
"Actually only the models with epoch number of `epoch-avg` and "
|
||||||
|
"`epoch` are loaded for averaging. ",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--exp-dir",
|
||||||
|
type=str,
|
||||||
|
default="zipformer/exp",
|
||||||
|
help="The experiment dir",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--lang-dir",
|
||||||
|
type=Path,
|
||||||
|
default="data/lang_char",
|
||||||
|
help="The lang dir containing word table and LG graph",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--decoding-method",
|
||||||
|
type=str,
|
||||||
|
default="ctc-greedy-search",
|
||||||
|
help="""Decoding method.
|
||||||
|
Supported values are:
|
||||||
|
- (1) ctc-greedy-search. Use CTC greedy search. It uses a sentence piece
|
||||||
|
model, i.e., lang_dir/bpe.model, to convert word pieces to words.
|
||||||
|
It needs neither a lexicon nor an n-gram LM.
|
||||||
|
(2) ctc-prefix-beam-search. Extract n paths with the given beam, the best
|
||||||
|
path of the n paths is the decoding result.
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
|
||||||
|
add_model_arguments(parser)
|
||||||
|
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_decoding_params() -> AttributeDict:
|
||||||
|
"""Parameters for decoding."""
|
||||||
|
params = AttributeDict(
|
||||||
|
{
|
||||||
|
"beam": 4, # for prefix-beam-search
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return params
|
||||||
|
|
||||||
|
def decode_one_batch(
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
lexicon: Lexicon,
|
||||||
|
batch: dict,
|
||||||
|
) -> Dict[str, Tuple[List[List[str]], List[List[Tuple[float, float]]]]]:
|
||||||
|
"""Decode one batch and return the result in a dict. The dict has the
|
||||||
|
following format:
|
||||||
|
|
||||||
|
- key: It indicates the setting used for decoding. For example,
|
||||||
|
if greedy_search is used, it would be "greedy_search"
|
||||||
|
If beam search with a beam size of 7 is used, it would be
|
||||||
|
"beam_7"
|
||||||
|
- value: It contains the decoding result. `len(value)` equals to
|
||||||
|
batch size. `value[i]` is the decoding result for the i-th
|
||||||
|
utterance in the given batch.
|
||||||
|
Args:
|
||||||
|
params:
|
||||||
|
It's the return value of :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
batch:
|
||||||
|
It is the return value from iterating
|
||||||
|
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
|
||||||
|
for the format of the `batch`.
|
||||||
|
Returns:
|
||||||
|
Return the decoding result. See above description for the format of
|
||||||
|
the returned dict.
|
||||||
|
"""
|
||||||
|
device = next(model.parameters()).device
|
||||||
|
feature = batch["inputs"]
|
||||||
|
assert feature.ndim == 3
|
||||||
|
|
||||||
|
feature = feature.to(device)
|
||||||
|
# at entry, feature is (N, T, C)
|
||||||
|
|
||||||
|
supervisions = batch["supervisions"]
|
||||||
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
|
if params.causal:
|
||||||
|
# this seems to cause insertions at the end of the utterance if used with zipformer.
|
||||||
|
pad_len = 30
|
||||||
|
feature_lens += pad_len
|
||||||
|
feature = torch.nn.functional.pad(
|
||||||
|
feature,
|
||||||
|
pad=(0, 0, 0, pad_len),
|
||||||
|
value=LOG_EPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
x, x_lens = model.encoder_embed(feature, feature_lens)
|
||||||
|
|
||||||
|
src_key_padding_mask = make_pad_mask(x_lens)
|
||||||
|
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
|
||||||
|
|
||||||
|
encoder_out, encoder_out_lens = model.encoder(x, x_lens, src_key_padding_mask)
|
||||||
|
encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
|
||||||
|
|
||||||
|
ctc_output = model.ctc_output(encoder_out) # (N, T, C)
|
||||||
|
|
||||||
|
hyp_tokens = []
|
||||||
|
hyps = []
|
||||||
|
|
||||||
|
if params.decoding_method == "ctc-greedy-search":
|
||||||
|
hyp_tokens = ctc_greedy_search(
|
||||||
|
ctc_output=ctc_output,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
|
elif params.decoding_method == "ctc-prefix-beam-search":
|
||||||
|
hyp_tokens = ctc_prefix_beam_search(
|
||||||
|
ctc_output=ctc_output,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(encoder_out.size(0)):
|
||||||
|
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
|
||||||
|
|
||||||
|
if params.decoding_method == "ctc-greedy-search":
|
||||||
|
return {"ctc-greedy-search" : hyps}
|
||||||
|
elif params.decoding_method == "ctc-prefix-beam-search":
|
||||||
|
return {"ctc-prefix-beam-search" : hyps}
|
||||||
|
else:
|
||||||
|
assert False, f"Unsupported decoding method: {params.decoding_method}"
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
|
||||||
|
dl: torch.utils.data.DataLoader,
|
||||||
|
params: AttributeDict,
|
||||||
|
model: nn.Module,
|
||||||
|
lexicon: Lexicon,
|
||||||
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
|
"""Decode dataset.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
dl:
|
||||||
|
PyTorch's dataloader containing the dataset to decode.
|
||||||
|
params:
|
||||||
|
It is returned by :func:`get_params`.
|
||||||
|
model:
|
||||||
|
The neural model.
|
||||||
|
Returns:
|
||||||
|
Return a dict, whose key may be "greedy_search" if greedy search
|
||||||
|
is used, or it may be "beam_7" if beam size of 7 is used.
|
||||||
|
Its value is a list of tuples. Each tuple contains 3 elements:
|
||||||
|
Respectively, they are cut_id, the reference transcript, and the predicted result.
|
||||||
|
"""
|
||||||
|
num_cuts = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
num_batches = len(dl)
|
||||||
|
except TypeError:
|
||||||
|
num_batches = "?"
|
||||||
|
|
||||||
|
log_interval = 20
|
||||||
|
|
||||||
|
results = defaultdict(list)
|
||||||
|
for batch_idx, batch in enumerate(dl):
|
||||||
|
texts = batch["supervisions"]["text"]
|
||||||
|
texts = [list("".join(text.split())) for text in texts]
|
||||||
|
cut_ids = [cut.id for cut in batch["supervisions"]["cut"]]
|
||||||
|
|
||||||
|
hyps_dict = decode_one_batch(
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
lexicon=lexicon,
|
||||||
|
batch=batch,
|
||||||
|
)
|
||||||
|
for name, hyps in hyps_dict.items():
|
||||||
|
this_batch = []
|
||||||
|
assert len(hyps) == len(texts)
|
||||||
|
for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts):
|
||||||
|
this_batch.append((cut_id, ref_text, hyp_words))
|
||||||
|
results[name].extend(this_batch)
|
||||||
|
|
||||||
|
num_cuts += len(texts)
|
||||||
|
|
||||||
|
if batch_idx % log_interval == 0:
|
||||||
|
batch_str = f"{batch_idx}/{num_batches}"
|
||||||
|
|
||||||
|
logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def save_results(
|
||||||
|
params: AttributeDict,
|
||||||
|
test_set_name: str,
|
||||||
|
results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]],
|
||||||
|
):
|
||||||
|
test_set_wers = dict()
|
||||||
|
for key, results in results_dict.items():
|
||||||
|
recog_path = (
|
||||||
|
params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
results = sorted(results)
|
||||||
|
store_transcripts(filename=recog_path, texts=results, char_level = True)
|
||||||
|
logging.info(f"The transcripts are stored in {recog_path}")
|
||||||
|
|
||||||
|
# The following prints out WERs, per-word error statistics and aligned
|
||||||
|
# ref/hyp pairs.
|
||||||
|
errs_filename = (
|
||||||
|
params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
with open(errs_filename, "w") as f:
|
||||||
|
wer = write_error_stats(
|
||||||
|
f,
|
||||||
|
f"{test_set_name}-{key}",
|
||||||
|
results,
|
||||||
|
enable_log=True,
|
||||||
|
compute_CER=True,
|
||||||
|
)
|
||||||
|
test_set_wers[key] = wer
|
||||||
|
|
||||||
|
logging.info("Wrote detailed error stats to {}".format(errs_filename))
|
||||||
|
|
||||||
|
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
|
||||||
|
errs_info = (
|
||||||
|
params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt"
|
||||||
|
)
|
||||||
|
with open(errs_info, "w") as f:
|
||||||
|
print("settings\tWER", file=f)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
print("{}\t{}".format(key, val), file=f)
|
||||||
|
|
||||||
|
s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
|
||||||
|
note = "\tbest for {}".format(test_set_name)
|
||||||
|
for key, val in test_set_wers:
|
||||||
|
s += "{}\t{}{}\n".format(key, val, note)
|
||||||
|
note = ""
|
||||||
|
logging.info(s)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def main():
|
||||||
|
parser = get_parser()
|
||||||
|
AishellAsrDataModule.add_arguments(parser)
|
||||||
|
args = parser.parse_args()
|
||||||
|
args.exp_dir = Path(args.exp_dir)
|
||||||
|
args.lang_dir = Path(args.lang_dir)
|
||||||
|
|
||||||
|
params = get_params()
|
||||||
|
# add decoding params
|
||||||
|
params.update(get_decoding_params())
|
||||||
|
params.update(vars(args))
|
||||||
|
|
||||||
|
assert params.decoding_method in (
|
||||||
|
"ctc-greedy-search",
|
||||||
|
"ctc-prefix-beam-search",
|
||||||
|
) # support ctc-greedy-search and ctc-prefix-beam-search
|
||||||
|
params.res_dir = params.exp_dir / params.decoding_method
|
||||||
|
|
||||||
|
if params.iter > 0:
|
||||||
|
params.suffix = f"iter-{params.iter}-avg-{params.avg}"
|
||||||
|
else:
|
||||||
|
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
|
||||||
|
|
||||||
|
if params.causal:
|
||||||
|
assert (
|
||||||
|
"," not in params.chunk_size
|
||||||
|
), "chunk_size should be one value in decoding."
|
||||||
|
assert (
|
||||||
|
"," not in params.left_context_frames
|
||||||
|
), "left_context_frames should be one value in decoding."
|
||||||
|
params.suffix += f"-chunk-{params.chunk_size}"
|
||||||
|
params.suffix += f"-left-context-{params.left_context_frames}"
|
||||||
|
|
||||||
|
if "prefix-beam-search" in params.decoding_method:
|
||||||
|
params.suffix += f"_beam-{params.beam}"
|
||||||
|
|
||||||
|
if params.use_averaged_model:
|
||||||
|
params.suffix += "-use-averaged-model"
|
||||||
|
|
||||||
|
setup_logger(f"{params.res_dir}/log-decode-{params.suffix}")
|
||||||
|
logging.info("Decoding started")
|
||||||
|
|
||||||
|
device = torch.device("cpu")
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
device = torch.device("cuda", 0)
|
||||||
|
params.device = device
|
||||||
|
|
||||||
|
logging.info(f"Device: {device}")
|
||||||
|
|
||||||
|
lexicon = Lexicon(params.lang_dir)
|
||||||
|
|
||||||
|
params.blank_id = lexicon.token_table["<blk>"]
|
||||||
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
logging.info(params)
|
||||||
|
|
||||||
|
logging.info("About to create model")
|
||||||
|
model = get_model(params)
|
||||||
|
|
||||||
|
if not params.use_averaged_model:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
elif params.avg == 1:
|
||||||
|
load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
|
||||||
|
else:
|
||||||
|
start = params.epoch - params.avg + 1
|
||||||
|
filenames = []
|
||||||
|
for i in range(start, params.epoch + 1):
|
||||||
|
if i >= 1:
|
||||||
|
filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
|
||||||
|
logging.info(f"averaging {filenames}")
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(average_checkpoints(filenames, device=device))
|
||||||
|
else:
|
||||||
|
if params.iter > 0:
|
||||||
|
filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[
|
||||||
|
: params.avg + 1
|
||||||
|
]
|
||||||
|
if len(filenames) == 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"No checkpoints found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
elif len(filenames) < params.avg + 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Not enough checkpoints ({len(filenames)}) found for"
|
||||||
|
f" --iter {params.iter}, --avg {params.avg}"
|
||||||
|
)
|
||||||
|
filename_start = filenames[-1]
|
||||||
|
filename_end = filenames[0]
|
||||||
|
logging.info(
|
||||||
|
"Calculating the averaged model over iteration checkpoints"
|
||||||
|
f" from {filename_start} (excluded) to {filename_end}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert params.avg > 0, params.avg
|
||||||
|
start = params.epoch - params.avg
|
||||||
|
assert start >= 1, start
|
||||||
|
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
|
||||||
|
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
|
||||||
|
logging.info(
|
||||||
|
f"Calculating the averaged model over epoch range from "
|
||||||
|
f"{start} (excluded) to {params.epoch}"
|
||||||
|
)
|
||||||
|
model.to(device)
|
||||||
|
model.load_state_dict(
|
||||||
|
average_checkpoints_with_averaged_model(
|
||||||
|
filename_start=filename_start,
|
||||||
|
filename_end=filename_end,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
model.to(device)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
# we need cut ids to display recognition results.
|
||||||
|
args.return_cuts = True
|
||||||
|
aishell = AishellAsrDataModule(args)
|
||||||
|
|
||||||
|
dev_cuts = aishell.valid_cuts()
|
||||||
|
dev_dl = aishell.valid_dataloaders(dev_cuts)
|
||||||
|
|
||||||
|
test_cuts = aishell.test_cuts()
|
||||||
|
test_dl = aishell.test_dataloaders(test_cuts)
|
||||||
|
|
||||||
|
test_sets = ["dev", "test"]
|
||||||
|
test_dls = [dev_dl, test_dl]
|
||||||
|
|
||||||
|
for test_set, test_dl in zip(test_sets, test_dls):
|
||||||
|
results_dict = decode_dataset(
|
||||||
|
dl=test_dl,
|
||||||
|
params=params,
|
||||||
|
model=model,
|
||||||
|
lexicon=lexicon,
|
||||||
|
)
|
||||||
|
|
||||||
|
save_results(
|
||||||
|
params=params,
|
||||||
|
test_set_name=test_set,
|
||||||
|
results_dict=results_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
@ -64,6 +64,7 @@ from asr_datamodule import AishellAsrDataModule
|
|||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
from lhotse.cut import Cut
|
from lhotse.cut import Cut
|
||||||
|
from lhotse.dataset import SpecAugment
|
||||||
from lhotse.dataset.sampling.base import CutSampler
|
from lhotse.dataset.sampling.base import CutSampler
|
||||||
from lhotse.utils import fix_random_seed
|
from lhotse.utils import fix_random_seed
|
||||||
from model import AsrModel
|
from model import AsrModel
|
||||||
@ -240,6 +241,27 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
chunk left-context frames will be chosen randomly from this list; else not relevant.""",
|
chunk left-context frames will be chosen randomly from this list; else not relevant.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-transducer",
|
||||||
|
type=str2bool,
|
||||||
|
default=True,
|
||||||
|
help="If True, use Transducer head.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-ctc",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="If True, use CTC head.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--use-cr-ctc",
|
||||||
|
type=str2bool,
|
||||||
|
default=False,
|
||||||
|
help="If True, use consistency-regularized CTC.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
@ -380,6 +402,27 @@ def get_parser():
|
|||||||
with this parameter before adding to the final loss.""",
|
with this parameter before adding to the final loss.""",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--ctc-loss-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.2,
|
||||||
|
help="Scale for CTC loss.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--cr-loss-scale",
|
||||||
|
type=float,
|
||||||
|
default=0.2,
|
||||||
|
help="Scale for consistency-regularization loss.",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--time-mask-ratio",
|
||||||
|
type=float,
|
||||||
|
default=2.5,
|
||||||
|
help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--seed",
|
"--seed",
|
||||||
type=int,
|
type=int,
|
||||||
@ -583,8 +626,13 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
|
|||||||
def get_model(params: AttributeDict) -> nn.Module:
|
def get_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder_embed = get_encoder_embed(params)
|
encoder_embed = get_encoder_embed(params)
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
|
||||||
joiner = get_joiner_model(params)
|
if params.use_transducer:
|
||||||
|
decoder = get_decoder_model(params)
|
||||||
|
joiner = get_joiner_model(params)
|
||||||
|
else:
|
||||||
|
decoder = None
|
||||||
|
joiner = None
|
||||||
|
|
||||||
model = AsrModel(
|
model = AsrModel(
|
||||||
encoder_embed=encoder_embed,
|
encoder_embed=encoder_embed,
|
||||||
@ -594,9 +642,27 @@ def get_model(params: AttributeDict) -> nn.Module:
|
|||||||
encoder_dim=int(max(params.encoder_dim.split(","))),
|
encoder_dim=int(max(params.encoder_dim.split(","))),
|
||||||
decoder_dim=params.decoder_dim,
|
decoder_dim=params.decoder_dim,
|
||||||
vocab_size=params.vocab_size,
|
vocab_size=params.vocab_size,
|
||||||
|
use_transducer=params.use_transducer,
|
||||||
|
use_ctc=params.use_ctc,
|
||||||
)
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
def get_spec_augment(params: AttributeDict) -> SpecAugment:
|
||||||
|
num_frame_masks = int(10 * params.time_mask_ratio)
|
||||||
|
max_frames_mask_fraction = 0.15 * params.time_mask_ratio
|
||||||
|
logging.info(
|
||||||
|
f"num_frame_masks: {num_frame_masks}, "
|
||||||
|
f"max_frames_mask_fraction: {max_frames_mask_fraction}"
|
||||||
|
)
|
||||||
|
spec_augment = SpecAugment(
|
||||||
|
time_warp_factor=0, # Do time warping in model.py
|
||||||
|
num_frame_masks=num_frame_masks, # default: 10
|
||||||
|
features_mask_size=27,
|
||||||
|
num_feature_masks=2,
|
||||||
|
frames_mask_size=100,
|
||||||
|
max_frames_mask_fraction=max_frames_mask_fraction, # default: 0.15
|
||||||
|
)
|
||||||
|
return spec_augment
|
||||||
|
|
||||||
def load_checkpoint_if_available(
|
def load_checkpoint_if_available(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
@ -723,6 +789,7 @@ def compute_loss(
|
|||||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
is_training: bool,
|
is_training: bool,
|
||||||
|
spec_augment: Optional[SpecAugment] = None,
|
||||||
) -> Tuple[Tensor, MetricsTracker]:
|
) -> Tuple[Tensor, MetricsTracker]:
|
||||||
"""
|
"""
|
||||||
Compute CTC loss given the model and its inputs.
|
Compute CTC loss given the model and its inputs.
|
||||||
@ -739,6 +806,8 @@ def compute_loss(
|
|||||||
True for training. False for validation. When it is True, this
|
True for training. False for validation. When it is True, this
|
||||||
function enables autograd during computation; when it is False, it
|
function enables autograd during computation; when it is False, it
|
||||||
disables autograd.
|
disables autograd.
|
||||||
|
spec_augment:
|
||||||
|
The SpecAugment instance used only when use_cr_ctc is True.
|
||||||
warmup: a floating point value which increases throughout training;
|
warmup: a floating point value which increases throughout training;
|
||||||
values >= 1.0 are fully warmed up and have all modules present.
|
values >= 1.0 are fully warmed up and have all modules present.
|
||||||
"""
|
"""
|
||||||
@ -758,6 +827,21 @@ def compute_loss(
|
|||||||
y = graph_compiler.texts_to_ids(texts)
|
y = graph_compiler.texts_to_ids(texts)
|
||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
|
use_cr_ctc = params.use_cr_ctc
|
||||||
|
use_spec_aug = use_cr_ctc and is_training
|
||||||
|
if use_spec_aug:
|
||||||
|
supervision_intervals = batch["supervisions"]
|
||||||
|
supervision_segments = torch.stack(
|
||||||
|
[
|
||||||
|
supervision_intervals["sequence_idx"],
|
||||||
|
supervision_intervals["start_frame"],
|
||||||
|
supervision_intervals["num_frames"],
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
) # shape: (S, 3)
|
||||||
|
else:
|
||||||
|
supervision_segments = None
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
losses = model(
|
losses = model(
|
||||||
x=feature,
|
x=feature,
|
||||||
@ -766,25 +850,40 @@ def compute_loss(
|
|||||||
prune_range=params.prune_range,
|
prune_range=params.prune_range,
|
||||||
am_scale=params.am_scale,
|
am_scale=params.am_scale,
|
||||||
lm_scale=params.lm_scale,
|
lm_scale=params.lm_scale,
|
||||||
|
use_cr_ctc=use_cr_ctc,
|
||||||
|
use_spec_aug=use_spec_aug,
|
||||||
|
spec_augment=spec_augment,
|
||||||
|
supervision_segments=supervision_segments,
|
||||||
|
time_warp_factor=params.spec_aug_time_warp_factor,
|
||||||
)
|
)
|
||||||
simple_loss, pruned_loss = losses[:2]
|
if params.use_ctc:
|
||||||
|
simple_loss, pruned_loss, ctc_loss, _, cr_loss = losses[:5]
|
||||||
|
else:
|
||||||
|
simple_loss, pruned_loss = losses[:2]
|
||||||
|
|
||||||
s = params.simple_loss_scale
|
loss = 0.0
|
||||||
# take down the scale on the simple loss from 1.0 at the start
|
|
||||||
# to params.simple_loss scale by warm_step.
|
|
||||||
simple_loss_scale = (
|
|
||||||
s
|
|
||||||
if batch_idx_train >= warm_step
|
|
||||||
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
|
||||||
)
|
|
||||||
pruned_loss_scale = (
|
|
||||||
1.0
|
|
||||||
if batch_idx_train >= warm_step
|
|
||||||
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
if params.use_transducer:
|
||||||
|
s = params.simple_loss_scale
|
||||||
|
# take down the scale on the simple loss from 1.0 at the start
|
||||||
|
# to params.simple_loss scale by warm_step.
|
||||||
|
simple_loss_scale = (
|
||||||
|
s
|
||||||
|
if batch_idx_train >= warm_step
|
||||||
|
else 1.0 - (batch_idx_train / warm_step) * (1.0 - s)
|
||||||
|
)
|
||||||
|
pruned_loss_scale = (
|
||||||
|
1.0
|
||||||
|
if batch_idx_train >= warm_step
|
||||||
|
else 0.1 + 0.9 * (batch_idx_train / warm_step)
|
||||||
|
)
|
||||||
|
loss += simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
|
||||||
|
|
||||||
|
if params.use_ctc:
|
||||||
|
loss += params.ctc_loss_scale * ctc_loss
|
||||||
|
if use_cr_ctc:
|
||||||
|
loss += params.cr_loss_scale * cr_loss
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
info = MetricsTracker()
|
info = MetricsTracker()
|
||||||
@ -794,8 +893,13 @@ def compute_loss(
|
|||||||
|
|
||||||
# Note: We use reduction=sum while computing the loss.
|
# Note: We use reduction=sum while computing the loss.
|
||||||
info["loss"] = loss.detach().cpu().item()
|
info["loss"] = loss.detach().cpu().item()
|
||||||
info["simple_loss"] = simple_loss.detach().cpu().item()
|
if params.use_transducer:
|
||||||
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
info["simple_loss"] = simple_loss.detach().cpu().item()
|
||||||
|
info["pruned_loss"] = pruned_loss.detach().cpu().item()
|
||||||
|
if params.use_ctc:
|
||||||
|
info["ctc_loss"] = ctc_loss.detach().cpu().item()
|
||||||
|
if params.use_cr_ctc:
|
||||||
|
info["cr_loss"] = cr_loss.detach().cpu().item()
|
||||||
|
|
||||||
return loss, info
|
return loss, info
|
||||||
|
|
||||||
@ -843,6 +947,7 @@ def train_one_epoch(
|
|||||||
train_dl: torch.utils.data.DataLoader,
|
train_dl: torch.utils.data.DataLoader,
|
||||||
valid_dl: torch.utils.data.DataLoader,
|
valid_dl: torch.utils.data.DataLoader,
|
||||||
scaler: GradScaler,
|
scaler: GradScaler,
|
||||||
|
spec_augment: Optional[SpecAugment] = None,
|
||||||
model_avg: Optional[nn.Module] = None,
|
model_avg: Optional[nn.Module] = None,
|
||||||
tb_writer: Optional[SummaryWriter] = None,
|
tb_writer: Optional[SummaryWriter] = None,
|
||||||
world_size: int = 1,
|
world_size: int = 1,
|
||||||
@ -869,6 +974,8 @@ def train_one_epoch(
|
|||||||
Dataloader for the validation dataset.
|
Dataloader for the validation dataset.
|
||||||
scaler:
|
scaler:
|
||||||
The scaler used for mix precision training.
|
The scaler used for mix precision training.
|
||||||
|
spec_augment:
|
||||||
|
The SpecAugment instance used only when use_cr_ctc is True.
|
||||||
model_avg:
|
model_avg:
|
||||||
The stored model averaged from the start of training.
|
The stored model averaged from the start of training.
|
||||||
tb_writer:
|
tb_writer:
|
||||||
@ -918,6 +1025,7 @@ def train_one_epoch(
|
|||||||
graph_compiler=graph_compiler,
|
graph_compiler=graph_compiler,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
|
spec_augment=spec_augment,
|
||||||
)
|
)
|
||||||
# summary stats
|
# summary stats
|
||||||
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
|
||||||
@ -1083,6 +1191,9 @@ def run(rank, world_size, args):
|
|||||||
params.blank_id = lexicon.token_table["<blk>"]
|
params.blank_id = lexicon.token_table["<blk>"]
|
||||||
params.vocab_size = max(lexicon.tokens) + 1
|
params.vocab_size = max(lexicon.tokens) + 1
|
||||||
|
|
||||||
|
if not params.use_transducer:
|
||||||
|
params.ctc_loss_scale = 1.0
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
@ -1091,6 +1202,12 @@ def run(rank, world_size, args):
|
|||||||
num_param = sum([p.numel() for p in model.parameters()])
|
num_param = sum([p.numel() for p in model.parameters()])
|
||||||
logging.info(f"Number of model parameters: {num_param}")
|
logging.info(f"Number of model parameters: {num_param}")
|
||||||
|
|
||||||
|
if params.use_cr_ctc:
|
||||||
|
assert not params.enable_spec_aug # we will do spec_augment in model.py
|
||||||
|
spec_augment = get_spec_augment(params)
|
||||||
|
else:
|
||||||
|
spec_augment = None
|
||||||
|
|
||||||
assert params.save_every_n >= params.average_period
|
assert params.save_every_n >= params.average_period
|
||||||
model_avg: Optional[nn.Module] = None
|
model_avg: Optional[nn.Module] = None
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
@ -1200,6 +1317,7 @@ def run(rank, world_size, args):
|
|||||||
optimizer=optimizer,
|
optimizer=optimizer,
|
||||||
graph_compiler=graph_compiler,
|
graph_compiler=graph_compiler,
|
||||||
params=params,
|
params=params,
|
||||||
|
spec_augment=spec_augment,
|
||||||
)
|
)
|
||||||
|
|
||||||
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
|
||||||
@ -1227,6 +1345,7 @@ def run(rank, world_size, args):
|
|||||||
train_dl=train_dl,
|
train_dl=train_dl,
|
||||||
valid_dl=valid_dl,
|
valid_dl=valid_dl,
|
||||||
scaler=scaler,
|
scaler=scaler,
|
||||||
|
spec_augment=spec_augment,
|
||||||
tb_writer=tb_writer,
|
tb_writer=tb_writer,
|
||||||
world_size=world_size,
|
world_size=world_size,
|
||||||
rank=rank,
|
rank=rank,
|
||||||
@ -1293,6 +1412,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
optimizer: torch.optim.Optimizer,
|
optimizer: torch.optim.Optimizer,
|
||||||
graph_compiler: CharCtcTrainingGraphCompiler,
|
graph_compiler: CharCtcTrainingGraphCompiler,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
|
spec_augment: Optional[SpecAugment] = None,
|
||||||
):
|
):
|
||||||
from lhotse.dataset import find_pessimistic_batches
|
from lhotse.dataset import find_pessimistic_batches
|
||||||
|
|
||||||
@ -1310,6 +1430,7 @@ def scan_pessimistic_batches_for_oom(
|
|||||||
graph_compiler=graph_compiler,
|
graph_compiler=graph_compiler,
|
||||||
batch=batch,
|
batch=batch,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
|
spec_augment=spec_augment,
|
||||||
)
|
)
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user