diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
index c68538e1f..f2c85f2fa 100755
--- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
+++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py
@@ -26,8 +26,8 @@ from lhotse import (
CutSet,
WhisperFbank,
WhisperFbankConfig,
- KaldifeatWhisperFbank,
- KaldifeatWhisperFbankConfig,
+ # KaldifeatWhisperFbank,
+ # KaldifeatWhisperFbankConfig,
KaldifeatFbank,
KaldifeatFbankConfig,
LilcomChunkyWriter,
diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh
index 876daea11..6fe03ed93 100755
--- a/egs/wenetspeech/ASR/prepare.sh
+++ b/egs/wenetspeech/ASR/prepare.sh
@@ -211,29 +211,13 @@ if [ $stage -le 130 ] && [ $stop_stage -ge 130 ]; then
fi
if [ $stage -le 131 ] && [ $stop_stage -ge 131 ]; then
- log "Stage 131: test"
-
- python3 ./local/compute_fbank_wenetspeech_splits.py \
- --training-subset L \
- --num-workers 8 \
- --batch-duration 1000 \
- --start 48 \
- --stop 68 \
- --num-mel-bins ${whisper_mel_bins} --whisper-fbank true \
- --num-splits $num_splits
+ log "Stage 131: concat feats into train set"
+ if [ ! -f data/fbank/cuts_L.jsonl.gz ]; then
+ pieces=$(find data/fbank/L_split_1000 -name "cuts_L.*.jsonl.gz")
+ lhotse combine $pieces data/fbank/cuts_L.jsonl.gz
+ fi
fi
-if [ $stage -le 132 ] && [ $stop_stage -ge 132 ]; then
- log "Stage 132: test"
-
- python3 ./local/compute_fbank_wenetspeech_splits.py \
- --training-subset L \
- --num-workers 8 \
- --batch-duration 1000 \
- --start 68 \
- --num-mel-bins ${whisper_mel_bins} --whisper-fbank true \
- --num-splits $num_splits
-fi
if [ $stage -le 14 ] && [ $stop_stage -ge 14 ]; then
log "Stage 14: Compute fbank for musan"
diff --git a/egs/wenetspeech/ASR/whisper/decode.py b/egs/wenetspeech/ASR/whisper/decode.py
index 55f89ddb1..292e162af 100755
--- a/egs/wenetspeech/ASR/whisper/decode.py
+++ b/egs/wenetspeech/ASR/whisper/decode.py
@@ -2,6 +2,7 @@
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
# Fangjun Kuang,
# Wei Kang)
+# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -16,47 +17,64 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+"""
+Usage:
+# Command for decoding using fine-tuned models:
+git lfs install
+git clone https://huggingface.co/yuekai/icefall_asr_aishell_whisper
+ln -s icefall_asr_aishell_whisper/exp_large_v2/epoch-10-avg6.pt whisper/exp_large_v2/epoch-999.pt
+
+python3 ./whisper/decode.py \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --epoch 999 --avg 1 \
+ --beam-size 10 --max-duration 50
+
+# Command for decoding using pretrained models (before fine-tuning):
+
+python3 ./whisper/decode.py \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --epoch -1 --avg 1 \
+ --remove-whisper-encoder-input-length-restriction False \
+ --beam-size 10 --max-duration 50
+
+"""
import argparse
import logging
+import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
-import whisper
-from whisper.normalizers import BasicTextNormalizer
import k2
import torch
import torch.nn as nn
-from asr_datamodule import WenetSpeechAsrDataModule
-from model import load_model
+import whisper
-from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
-from icefall.decode import (
- get_lattice,
- nbest_decoding,
- nbest_oracle,
- one_best_decoding,
- rescore_with_attention_decoder,
-)
-from lhotse.cut import Cut
+from asr_datamodule import WenetSpeechAsrDataModule
+from tn.chinese.normalizer import Normalizer
+from whisper.normalizers import BasicTextNormalizer
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+from zhconv import convert
+
+from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
-from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
- get_texts,
setup_logger,
store_transcripts,
+ str2bool,
write_error_stats,
)
-from zhconv import convert
-from tn.chinese.normalizer import Normalizer
-import re
+
def average_checkpoints(
filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
"""Average a list of checkpoints.
+ The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict.
Args:
filenames:
@@ -71,9 +89,9 @@ def average_checkpoints(
n = len(filenames)
if "model" in torch.load(filenames[0], map_location=device):
- avg = torch.load(filenames[0], map_location=device)["model"]
+ avg = torch.load(filenames[0], map_location=device)["model"]
else:
- avg = torch.load(filenames[0], map_location=device)
+ avg = torch.load(filenames[0], map_location=device)
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
@@ -89,9 +107,9 @@ def average_checkpoints(
for i in range(1, n):
if "model" in torch.load(filenames[i], map_location=device):
- state_dict = torch.load(filenames[i], map_location=device)["model"]
+ state_dict = torch.load(filenames[i], map_location=device)["model"]
else:
- state_dict = torch.load(filenames[i], map_location=device)
+ state_dict = torch.load(filenames[i], map_location=device)
for k in uniqued_names:
avg[k] += state_dict[k]
@@ -103,33 +121,48 @@ def average_checkpoints(
return avg
+
def remove_punctuation(text: str or List[str]):
- # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
- punctuation = '!,.;:?、!,。;:?'
+ """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
+
+ Args:
+ text: It can be a string or a list of strings.
+ Returns:
+ Return a string or a list of strings without any punctuation.
+ """
+ punctuation = "!,.;:?、!,。;:?《》 "
if isinstance(text, str):
- text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
+ text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
return text
elif isinstance(text, list):
result_text = []
for t in text:
- t = re.sub(r'[{}]+'.format(punctuation), '', t).strip()
+ t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
result_text.append(t)
return result_text
else:
- raise Exception(f'Not support type {type(text)}')
+ raise Exception(f"Not support type {type(text)}")
+
def to_simple(text: str or List[str]):
+ """Convert traditional Chinese to simplified Chinese.
+ Args:
+ text: It can be a string or a list of strings.
+ Returns:
+ Return a string or a list of strings converted to simplified Chinese.
+ """
if isinstance(text, str):
- text = convert(text, 'zh-cn')
+ text = convert(text, "zh-cn")
return text
elif isinstance(text, list):
result_text = []
for t in text:
- t = convert(t, 'zh-cn')
+ t = convert(t, "zh-cn")
result_text.append(t)
return result_text
else:
- raise Exception(f'Not support type{type(text)}')
+ raise Exception(f"Not support type{type(text)}")
+
def get_parser():
parser = argparse.ArgumentParser(
@@ -184,7 +217,14 @@ def get_parser():
help="""The model name to use.
""",
)
-
+
+ parser.add_argument(
+ "--remove-whisper-encoder-input-length-restriction",
+ type=str2bool,
+ default=True,
+ help="replace whisper encoder forward method to remove input length restriction",
+ )
+
return parser
@@ -196,6 +236,7 @@ def get_params() -> AttributeDict:
)
return params
+
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
@@ -204,42 +245,17 @@ def decode_one_batch(
"""Decode one batch and return the result in a dict. The dict has the
following format:
- - key: It indicates the setting used for decoding. For example,
- if decoding method is 1best, the key is the string `no_rescore`.
- If attention rescoring is used, the key is the string
- `ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the
- value of `lm_scale` and `attention_scale`. An example key is
- `ngram_lm_scale_0.7_attention_scale_0.5`
- - value: It contains the decoding result. `len(value)` equals to
- batch size. `value[i]` is the decoding result for the i-th
- utterance in the given batch.
+ - key: "beam-search"
+ - value: A list of lists. Each sublist is a list of token IDs.
Args:
- params:
- It's the return value of :func:`get_params`.
-
- - params.method is "1best", it uses 1best decoding without LM rescoring.
- - params.method is "nbest", it uses nbest decoding without LM rescoring.
- - params.method is "attention-decoder", it uses attention rescoring.
-
- model:
- The neural model.
- HLG:
- The decoding graph. Used when params.method is NOT ctc-decoding.
- H:
- The ctc topo. Used only when params.method is ctc-decoding.
- batch:
- It is the return value from iterating
- `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
- for the format of the `batch`.
- lexicon:
- It contains the token symbol table and the word symbol table.
- sos_id:
- The token ID of the SOS.
- eos_id:
- The token ID of the EOS.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
+ batch:
+ It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
Returns:
- Return the decoding result. See above description for the format of
- the returned dict.
+ Return a dict, whose key may be "beam-search".
"""
dtype = torch.float16
device = torch.device("cuda")
@@ -247,21 +263,30 @@ def decode_one_batch(
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device, dtype=dtype).transpose(1, 2)
+ if not params.remove_whisper_encoder_input_length_restriction:
+ T = 3000
+ if feature.shape[2] < T:
+ feature = torch.cat(
+ [
+ feature,
+ torch.zeros(
+ feature.shape[0], feature.shape[1], T - feature.shape[2]
+ ).to(device, dtype=dtype),
+ ],
+ 2,
+ )
supervisions = batch["supervisions"]
feature_len = supervisions["num_frames"]
feature_len = feature_len.to(device, dtype=dtype)
results = model.decode(feature, params.decoding_options)
hyps = [result.text for result in results]
-
+
hyps = remove_punctuation(hyps)
hyps = to_simple(hyps)
-
hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
- print(hyps)
- key = "beam-search"
- return {key: hyps}
+ return {"beam-search": hyps}
def decode_dataset(
@@ -272,28 +297,14 @@ def decode_dataset(
"""Decode dataset.
Args:
- dl:
- PyTorch's dataloader containing the dataset to decode.
- params:
- It is returned by :func:`get_params`.
- model:
- The neural model.
- HLG:
- The decoding graph. Used when params.method is NOT ctc-decoding.
- H:
- The ctc topo. Used only when params.method is ctc-decoding.
- lexicon:
- It contains the token symbol table and the word symbol table.
- sos_id:
- The token ID for SOS.
- eos_id:
- The token ID for EOS.
+ dl:
+ The dataloader.
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The neural model.
Returns:
- Return a dict, whose key may be "no-rescore" if the decoding method is
- 1best or it may be "ngram_lm_scale_0.7_attention_scale_0.5" if attention
- rescoring is used. Its value is a list of tuples. Each tuple contains two
- elements: The first is the reference transcript, and the second is the
- predicted result.
+ Return a dict, whose key may be "beam-search".
"""
results = []
@@ -342,7 +353,9 @@ def save_results(
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
- recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ recog_path = (
+ params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
@@ -350,7 +363,9 @@ def save_results(
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
- errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ errs_filename = (
+ params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+ )
# we compute CER for aishell dataset.
results_char = []
for res in results:
@@ -382,20 +397,27 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
- WenetSpeechAsrDataModule.add_arguments(parser)
+ AishellAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
- setup_logger(f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}")
+ setup_logger(
+ f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+ )
- options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=params.beam_size)
+ options = whisper.DecodingOptions(
+ task="transcribe",
+ language="zh",
+ without_timestamps=True,
+ beam_size=params.beam_size,
+ )
params.decoding_options = options
params.cleaner = BasicTextNormalizer()
params.normalizer = Normalizer()
-
+
logging.info("Decoding started")
logging.info(params)
@@ -405,39 +427,49 @@ def main():
logging.info(f"device: {device}")
- model = load_model(params.model_name)
+ if params.remove_whisper_encoder_input_length_restriction:
+ replace_whisper_encoder_forward()
+ model = whisper.load_model(params.model_name, "cpu")
if params.epoch > 0:
- if params.avg > 1:
- start = params.epoch - params.avg
- assert start >= 1, start
- checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
- if 'model' not in checkpoint:
- filenames = [f"{params.exp_dir}/epoch-{epoch}.pt" for epoch in range(start, params.epoch + 1)]
- model.load_state_dict(average_checkpoints(filenames))
- else:
- filename_start = f"{params.exp_dir}/epoch-{start}.pt"
- filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
- logging.info(
- f"Calculating the averaged model over epoch range from "
- f"{start} (excluded) to {params.epoch}"
+ if params.avg > 1:
+ start = params.epoch - params.avg
+ assert start >= 1, start
+ checkpoint = torch.load(
+ f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
)
- model.to(device)
- model.load_state_dict(
- average_checkpoints_with_averaged_model(
- filename_start=filename_start,
- filename_end=filename_end,
- device=device,
+ if "model" not in checkpoint:
+ # deepspeed converted checkpoint only contains model state_dict
+ filenames = [
+ f"{params.exp_dir}/epoch-{epoch}.pt"
+ for epoch in range(start, params.epoch + 1)
+ ]
+ model.load_state_dict(average_checkpoints(filenames))
+ else:
+ filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+ filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+ logging.info(
+ f"Calculating the averaged model over epoch range from "
+ f"{start} (excluded) to {params.epoch}"
)
- )
- # save checkpoints
- filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
- torch.save(model.state_dict(), filename)
- else:
- checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
- if 'model' not in checkpoint:
- model.load_state_dict(checkpoint, strict=True)
+ model.to(device)
+ model.load_state_dict(
+ average_checkpoints_with_averaged_model(
+ filename_start=filename_start,
+ filename_end=filename_end,
+ device=device,
+ )
+ )
+ # save checkpoints
+ filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+ torch.save(model.state_dict(), filename)
else:
- load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+ checkpoint = torch.load(
+ f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+ )
+ if "model" not in checkpoint:
+ model.load_state_dict(checkpoint, strict=True)
+ else:
+ load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])
@@ -446,25 +478,13 @@ def main():
# we need cut ids to display recognition results.
args.return_cuts = True
wenetspeech = WenetSpeechAsrDataModule(args)
+ dev_cuts = wenetspeech.valid_cuts()
+ dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
- def remove_short_utt(c: Cut):
- T = ((c.num_frames - 7) // 2 + 1) // 2
- if T <= 0:
- logging.warning(
- f"Exclude cut with ID {c.id} from decoding, num_frames : {c.num_frames}."
- )
- return T > 0
-
- # dev_cuts = wenetspeech.valid_cuts()
- # dev_cuts = dev_cuts.filter(remove_short_utt)
- # dev_dl = wenetspeech.valid_dataloaders(dev_cuts)
-
- # test_net_cuts = wenetspeech.test_net_cuts()
- # test_net_cuts = test_net_cuts.filter(remove_short_utt)
- # test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
+ test_net_cuts = wenetspeech.test_net_cuts()
+ test_net_dl = wenetspeech.test_dataloaders(test_net_cuts)
test_meeting_cuts = wenetspeech.test_meeting_cuts()
- test_meeting_cuts = test_meeting_cuts.filter(remove_short_utt)
test_meeting_dl = wenetspeech.test_dataloaders(test_meeting_cuts)
# test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
diff --git a/egs/wenetspeech/ASR/whisper/ds_config_zero1.json b/egs/wenetspeech/ASR/whisper/ds_config_zero1.json
new file mode 100644
index 000000000..bf8cc0452
--- /dev/null
+++ b/egs/wenetspeech/ASR/whisper/ds_config_zero1.json
@@ -0,0 +1,38 @@
+{
+ "fp16": {
+ "enabled": true,
+ "loss_scale": 0,
+ "loss_scale_window": 100,
+ "initial_scale_power": 16,
+ "hysteresis": 2,
+ "min_loss_scale": 0.01
+ },
+ "zero_optimization": {
+ "stage": 1,
+ "allgather_partitions": true,
+ "allgather_bucket_size": 2e8,
+ "overlap_comm": true,
+ "reduce_scatter": true,
+ "reduce_bucket_size": 2e8,
+ "contiguous_gradients": true
+ },
+ "optimizer": {
+ "type": "Adam",
+ "params": {
+ "lr": 1e-5
+ }
+ },
+ "scheduler": {
+ "type": "WarmupLR",
+ "params": {
+ "warmup_min_lr": 0,
+ "warmup_max_lr": 1e-5,
+ "warmup_num_steps": 100
+ }
+ },
+ "gradient_accumulation_steps": 1,
+ "gradient_clipping": 5,
+ "steps_per_print": 50,
+ "train_micro_batch_size_per_gpu": 1,
+ "wall_clock_breakdown": false
+}
diff --git a/egs/wenetspeech/ASR/whisper/label_smoothing.py b/egs/wenetspeech/ASR/whisper/label_smoothing.py
new file mode 100644
index 000000000..52d2eda3b
--- /dev/null
+++ b/egs/wenetspeech/ASR/whisper/label_smoothing.py
@@ -0,0 +1,109 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+class LabelSmoothingLoss(torch.nn.Module):
+ """
+ Implement the LabelSmoothingLoss proposed in the following paper
+ https://arxiv.org/pdf/1512.00567.pdf
+ (Rethinking the Inception Architecture for Computer Vision)
+
+ """
+
+ def __init__(
+ self,
+ ignore_index: int = -1,
+ label_smoothing: float = 0.1,
+ reduction: str = "sum",
+ ) -> None:
+ """
+ Args:
+ ignore_index:
+ ignored class id
+ label_smoothing:
+ smoothing rate (0.0 means the conventional cross entropy loss)
+ reduction:
+ It has the same meaning as the reduction in
+ `torch.nn.CrossEntropyLoss`. It can be one of the following three
+ values: (1) "none": No reduction will be applied. (2) "mean": the
+ mean of the output is taken. (3) "sum": the output will be summed.
+ """
+ super().__init__()
+ assert 0.0 <= label_smoothing < 1.0, f"{label_smoothing}"
+ assert reduction in ("none", "sum", "mean"), reduction
+ self.ignore_index = ignore_index
+ self.label_smoothing = label_smoothing
+ self.reduction = reduction
+
+ def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ """
+ Compute loss between x and target.
+
+ Args:
+ x:
+ prediction of dimension
+ (batch_size, input_length, number_of_classes).
+ target:
+ target masked with self.ignore_index of
+ dimension (batch_size, input_length).
+
+ Returns:
+ A scalar tensor containing the loss without normalization.
+ """
+ assert x.ndim == 3
+ assert target.ndim == 2
+ assert x.shape[:2] == target.shape
+ num_classes = x.size(-1)
+ x = x.reshape(-1, num_classes)
+ # Now x is of shape (N*T, C)
+
+ # We don't want to change target in-place below,
+ # so we make a copy of it here
+ target = target.clone().reshape(-1)
+
+ ignored = target == self.ignore_index
+
+ # See https://github.com/k2-fsa/icefall/issues/240
+ # and https://github.com/k2-fsa/icefall/issues/297
+ # for why we don't use target[ignored] = 0 here
+ target = torch.where(ignored, torch.zeros_like(target), target)
+
+ true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x)
+
+ true_dist = (
+ true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes
+ )
+
+ # Set the value of ignored indexes to 0
+ #
+ # See https://github.com/k2-fsa/icefall/issues/240
+ # and https://github.com/k2-fsa/icefall/issues/297
+ # for why we don't use true_dist[ignored] = 0 here
+ true_dist = torch.where(
+ ignored.unsqueeze(1).repeat(1, true_dist.shape[1]),
+ torch.zeros_like(true_dist),
+ true_dist,
+ )
+
+ loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
+ if self.reduction == "sum":
+ return loss.sum()
+ elif self.reduction == "mean":
+ return loss.sum() / (~ignored).sum()
+ else:
+ return loss.sum(dim=-1)
diff --git a/egs/wenetspeech/ASR/whisper/optim.py b/egs/wenetspeech/ASR/whisper/optim.py
new file mode 100644
index 000000000..714d8db9a
--- /dev/null
+++ b/egs/wenetspeech/ASR/whisper/optim.py
@@ -0,0 +1,1248 @@
+# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
+#
+# See ../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import contextlib
+import logging
+import random
+from collections import defaultdict
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from lhotse.utils import fix_random_seed
+from torch import Tensor, nn
+from torch.optim import Optimizer
+
+
+class BatchedOptimizer(Optimizer):
+ """
+ This class adds to class Optimizer the capability to optimize parameters in batches:
+ it will stack the parameters and their grads for you so the optimizer can work
+ on tensors with an extra leading dimension. This is intended for speed with GPUs,
+ as it reduces the number of kernels launched in the optimizer.
+
+ Args:
+ params:
+ """
+
+ def __init__(self, params, defaults):
+ super(BatchedOptimizer, self).__init__(params, defaults)
+
+ @contextlib.contextmanager
+ def batched_params(self, param_group, group_params_names):
+ """
+ This function returns (technically, yields) a list of
+ of tuples (p, state), where
+ p is a `fake` parameter that is stacked (over axis 0) from real parameters
+ that share the same shape, and its gradient is also stacked;
+ `state` is the state corresponding to this batch of parameters
+ (it will be physically located in the "state" for one of the real
+ parameters, the last one that has any particular shape and dtype).
+
+ This function is decorated as a context manager so that it can
+ write parameters back to their "real" locations.
+
+ The idea is, instead of doing:
+
+ for p in group["params"]:
+ state = self.state[p]
+ ...
+
+ you can do:
+
+ with self.batched_params(group["params"]) as batches:
+ for p, state, p_names in batches:
+ ...
+
+
+ Args:
+ group: a parameter group, which is a list of parameters; should be
+ one of self.param_groups.
+ group_params_names: name for each parameter in group,
+ which is List[str].
+ """
+ batches = defaultdict(
+ list
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
+ batches_names = defaultdict(
+ list
+ ) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
+
+ assert len(param_group) == len(group_params_names)
+ for p, named_p in zip(param_group, group_params_names):
+ key = (str(p.dtype), *p.shape)
+ batches[key].append(p)
+ batches_names[key].append(named_p)
+
+ batches_names_keys = list(batches_names.keys())
+ sorted_idx = sorted(
+ range(len(batches_names)), key=lambda i: batches_names_keys[i]
+ )
+ batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
+ batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
+
+ stacked_params_dict = dict()
+
+ # turn batches into a list, in deterministic order.
+ # tuples will contain tuples of (stacked_param, state, stacked_params_names),
+ # one for each batch in `batches`.
+ tuples = []
+
+ for batch, batch_names in zip(batches, batches_names):
+ p = batch[0]
+ # we arbitrarily store the state in the
+ # state corresponding to the 1st parameter in the
+ # group. class Optimizer will take care of saving/loading state.
+ state = self.state[p]
+ p_stacked = torch.stack(batch)
+ grad = torch.stack(
+ [torch.zeros_like(p) if p.grad is None else p.grad for p in batch]
+ )
+ p_stacked.grad = grad
+ stacked_params_dict[key] = p_stacked
+ tuples.append((p_stacked, state, batch_names))
+
+ yield tuples # <-- calling code will do the actual optimization here!
+
+ for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
+ for i, p in enumerate(batch): # batch is list of Parameter
+ p.copy_(stacked_params[i])
+
+
+class ScaledAdam(BatchedOptimizer):
+ """
+ Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update
+ proportional to the norm of that parameter; and also learn the scale of the parameter,
+ in log space, subject to upper and lower limits (as if we had factored each parameter as
+ param = underlying_param * log_scale.exp())
+
+
+ Args:
+ params: The parameters or param_groups to optimize (like other Optimizer subclasses)
+ Unlike common optimizers, which accept model.parameters() or groups of parameters(),
+ this optimizer could accept model.named_parameters() or groups of named_parameters().
+ See comments of function _get_names_of_parameters for its 4 possible cases.
+ lr: The learning rate. We will typically use a learning rate schedule that starts
+ at 0.03 and decreases over time, i.e. much higher than other common
+ optimizers.
+ clipping_scale: (e.g. 2.0)
+ A scale for gradient-clipping: if specified, the normalized gradients
+ over the whole model will be clipped to have 2-norm equal to
+ `clipping_scale` times the median 2-norm over the most recent period
+ of `clipping_update_period` minibatches. By "normalized gradients",
+ we mean after multiplying by the rms parameter value for this tensor
+ [for non-scalars]; this is appropriate because our update is scaled
+ by this quantity.
+ betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad.
+ Must satisfy 0 < beta <= beta2 < 1.
+ scalar_lr_scale: A scaling factor on the learning rate, that we use to update the
+ scale of each parameter tensor and scalar parameters of the mode..
+ If each parameter were decomposed
+ as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale
+ would be a the scaling factor on the learning rate of p_scale.
+ eps: A general-purpose epsilon to prevent division by zero
+ param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
+ parameter tensor to be >= this value)
+ param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of
+ learning the scale on the parameters (we'll constrain the rms of each non-scalar
+ parameter tensor to be <= this value)
+ scalar_max: Maximum absolute value for scalar parameters (applicable if your
+ model has any parameters with numel() == 1).
+ size_update_period: The periodicity, in steps, with which we update the size (scale)
+ of the parameter tensor. This is provided to save a little time
+ in the update.
+ clipping_update_period: if clipping_scale is specified, this is the period
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=3e-02,
+ clipping_scale=None,
+ betas=(0.9, 0.98),
+ scalar_lr_scale=0.1,
+ eps=1.0e-08,
+ param_min_rms=1.0e-05,
+ param_max_rms=3.0,
+ scalar_max=10.0,
+ size_update_period=4,
+ clipping_update_period=100,
+ ):
+
+ defaults = dict(
+ lr=lr,
+ clipping_scale=clipping_scale,
+ betas=betas,
+ scalar_lr_scale=scalar_lr_scale,
+ eps=eps,
+ param_min_rms=param_min_rms,
+ param_max_rms=param_max_rms,
+ scalar_max=scalar_max,
+ size_update_period=size_update_period,
+ clipping_update_period=clipping_update_period,
+ )
+
+ # If params only contains parameters or group of parameters,
+ # i.e when parameter names are not given,
+ # this flag will be set to False in funciton _get_names_of_parameters.
+ self.show_dominant_parameters = True
+ param_groups, parameters_names = self._get_names_of_parameters(params)
+ super(ScaledAdam, self).__init__(param_groups, defaults)
+ assert len(self.param_groups) == len(parameters_names)
+ self.parameters_names = parameters_names
+
+ def _get_names_of_parameters(
+ self, params_or_named_params
+ ) -> Tuple[List[Dict], List[List[str]]]:
+ """
+ Args:
+ params_or_named_params: according to the way ScaledAdam is initialized in train.py,
+ this argument could be one of following 4 cases,
+ case 1, a generator of parameter, e.g.:
+ optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+ case 2, a list of parameter groups with different config, e.g.:
+ model_param_groups = [
+ {'params': model.encoder.parameters(), 'lr': 0.05},
+ {'params': model.decoder.parameters(), 'lr': 0.01},
+ {'params': model.joiner.parameters(), 'lr': 0.03},
+ ]
+ optimizer = ScaledAdam(model_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+ case 3, a generator of named_parameter, e.g.:
+ optimizer = ScaledAdam(model.named_parameters(), lr=params.base_lr, clipping_scale=3.0)
+
+ case 4, a list of named_parameter groups with different config, e.g.:
+ model_named_param_groups = [
+ {'named_params': model.encoder.named_parameters(), 'lr': 0.05},
+ {'named_params': model.decoder.named_parameters(), 'lr': 0.01},
+ {'named_params': model.joiner.named_parameters(), 'lr': 0.03},
+ ]
+ optimizer = ScaledAdam(model_named_param_groups, lr=params.base_lr, clipping_scale=3.0)
+
+ For case 1 and case 2, input params is used to initialize the underlying torch.optimizer.
+ For case 3 and case 4, firstly, names and params are extracted from input named_params,
+ then, these extracted params are used to initialize the underlying torch.optimizer,
+ and these extracted names are mainly used by function
+ `_show_gradient_dominating_parameter`
+
+ Returns:
+ Returns a tuple containing 2 elements:
+ - `param_groups` with type List[Dict], each Dict element is a parameter group.
+ An example of `param_groups` could be:
+ [
+ {'params': `one iterable of Parameter`, 'lr': 0.05},
+ {'params': `another iterable of Parameter`, 'lr': 0.08},
+ {'params': `a third iterable of Parameter`, 'lr': 0.1},
+ ]
+ - `param_gruops_names` with type List[List[str]],
+ each `List[str]` is for a group['params'] in param_groups,
+ and each `str` is the name of a parameter.
+ A dummy name "foo" is related to each parameter,
+ if input are params without names, i.e. case 1 or case 2.
+ """
+ # variable naming convention in this function:
+ # p is short for param.
+ # np is short for named_param.
+ # p_or_np is short for param_or_named_param.
+ # cur is short for current.
+ # group is a dict, e.g. {'params': iterable of parameter, 'lr': 0.05, other fields}.
+ # groups is a List[group]
+
+ iterable_or_groups = list(params_or_named_params)
+ if len(iterable_or_groups) == 0:
+ raise ValueError("optimizer got an empty parameter list")
+
+ # The first value of returned tuple. A list of dicts containing at
+ # least 'params' as a key.
+ param_groups = []
+
+ # The second value of returned tuple,
+ # a List[List[str]], each sub-List is for a group.
+ param_groups_names = []
+
+ if not isinstance(iterable_or_groups[0], dict):
+ # case 1 or case 3,
+ # the input is an iterable of parameter or named parameter.
+ param_iterable_cur_group = []
+ param_names_cur_group = []
+ for p_or_np in iterable_or_groups:
+ if isinstance(p_or_np, tuple):
+ # case 3
+ name, param = p_or_np
+ else:
+ # case 1
+ assert isinstance(p_or_np, torch.Tensor)
+ param = p_or_np
+ # Assign a dummy name as a placeholder
+ name = "foo"
+ self.show_dominant_parameters = False
+ param_iterable_cur_group.append(param)
+ param_names_cur_group.append(name)
+ param_groups.append({"params": param_iterable_cur_group})
+ param_groups_names.append(param_names_cur_group)
+ else:
+ # case 2 or case 4
+ # the input is groups of parameter or named parameter.
+ for cur_group in iterable_or_groups:
+ assert "named_params" in cur_group
+ name_list = [x[0] for x in cur_group["named_params"]]
+ p_list = [x[1] for x in cur_group["named_params"]]
+ del cur_group["named_params"]
+ cur_group["params"] = p_list
+ param_groups.append(cur_group)
+ param_groups_names.append(name_list)
+
+ return param_groups, param_groups_names
+
+ def __setstate__(self, state):
+ super(ScaledAdam, self).__setstate__(state)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ batch = True
+
+ for group, group_params_names in zip(self.param_groups, self.parameters_names):
+
+ with self.batched_params(group["params"], group_params_names) as batches:
+
+ # batches is list of pairs (stacked_param, state). stacked_param is like
+ # a regular parameter, and will have a .grad, but the 1st dim corresponds to
+ # a stacking dim, it is not a real dim.
+
+ if (
+ len(batches[0][1]) == 0
+ ): # if len(first state) == 0: not yet initialized
+ clipping_scale = 1
+ else:
+ clipping_scale = self._get_clipping_scale(group, batches)
+
+ for p, state, _ in batches:
+ # Perform optimization step.
+ # grad is not going to be None, we handled that when creating the batches.
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError(
+ "ScaledAdam optimizer does not support sparse gradients"
+ )
+ # State initialization
+ if len(state) == 0:
+ self._init_state(group, p, state)
+
+ self._step_one_batch(group, p, state, clipping_scale)
+
+ return loss
+
+ def _init_state(self, group: dict, p: Tensor, state: dict):
+ """
+ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p
+ is actually the batch dimension, corresponding to batched-together
+ parameters of a given shape.
+
+
+ Args:
+ group: Dict to look up configuration values.
+ p: The parameter that we are initializing the state for
+ state: Dict from string to whatever state we are initializing
+ """
+ size_update_period = group["size_update_period"]
+
+ state["step"] = 0
+
+ kwargs = {"device": p.device, "dtype": p.dtype}
+
+ # 'delta' implements conventional momentum. There are
+ # several different kinds of update going on, so rather than
+ # compute "exp_avg" like in Adam, we store and decay a
+ # parameter-change "delta", which combines all forms of
+ # update. this is equivalent to how it's done in Adam,
+ # except for the first few steps.
+ state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+ batch_size = p.shape[0]
+ numel = p.numel() // batch_size
+
+ if numel > 1:
+ # "param_rms" just periodically records the scalar root-mean-square value of
+ # the parameter tensor.
+ # it has a shape like (batch_size, 1, 1, 1, 1)
+ param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
+ state["param_rms"] = param_rms
+
+ state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
+ state["scale_grads"] = torch.zeros(
+ size_update_period, *param_rms.shape, **kwargs
+ )
+
+ # exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
+ state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+
+ def _get_clipping_scale(
+ self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]
+ ) -> float:
+ """
+ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
+ by this amount before applying the rest of the update.
+
+ Args:
+ group: the parameter group, an item in self.param_groups
+ tuples: a list of tuples of (param, state, param_names)
+ where param is a batched set of parameters,
+ with a .grad (1st dim is batch dim)
+ and state is the state-dict where optimization parameters are kept.
+ param_names is a List[str] while each str is name for a parameter
+ in batched set of parameters "param".
+ """
+ assert len(tuples) >= 1
+ clipping_scale = group["clipping_scale"]
+ (first_p, first_state, _) = tuples[0]
+ step = first_state["step"]
+ if clipping_scale is None or step == 0:
+ # no clipping. return early on step == 0 because the other
+ # parameters' state won't have been initialized yet.
+ return 1.0
+ clipping_update_period = group["clipping_update_period"]
+ scalar_lr_scale = group["scalar_lr_scale"]
+
+ tot_sumsq = torch.tensor(0.0, device=first_p.device)
+ for (p, state, param_names) in tuples:
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError(
+ "ScaledAdam optimizer does not support sparse gradients"
+ )
+ if p.numel() == p.shape[0]: # a batch of scalars
+ tot_sumsq += (grad**2).sum() * (
+ scalar_lr_scale**2
+ ) # sum() to change shape [1] to []
+ else:
+ tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
+
+ tot_norm = tot_sumsq.sqrt()
+ if "model_norms" not in first_state:
+ first_state["model_norms"] = torch.zeros(
+ clipping_update_period, device=p.device
+ )
+ first_state["model_norms"][step % clipping_update_period] = tot_norm
+
+ irregular_estimate_steps = [
+ i for i in [10, 20, 40] if i < clipping_update_period
+ ]
+ if step % clipping_update_period == 0 or step in irregular_estimate_steps:
+ # Print some stats.
+ # We don't reach here if step == 0 because we would have returned
+ # above.
+ sorted_norms = first_state["model_norms"].sort()[0].to("cpu")
+ if step in irregular_estimate_steps:
+ sorted_norms = sorted_norms[-step:]
+ num_norms = sorted_norms.numel()
+ quartiles = []
+ for n in range(0, 5):
+ index = min(num_norms - 1, (num_norms // 4) * n)
+ quartiles.append(sorted_norms[index].item())
+
+ median = quartiles[2]
+ if median - median != 0:
+ raise RuntimeError("Too many grads were not finite")
+ threshold = clipping_scale * median
+ if step in irregular_estimate_steps:
+ # use larger thresholds on first few steps of estimating threshold,
+ # as norm may be changing rapidly.
+ threshold = threshold * 2.0
+ first_state["model_norm_threshold"] = threshold
+ percent_clipped = (
+ first_state["num_clipped"] * 100.0 / num_norms
+ if "num_clipped" in first_state
+ else 0.0
+ )
+ first_state["num_clipped"] = 0
+ quartiles = " ".join(["%.3e" % x for x in quartiles])
+ logging.warn(
+ f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
+ f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
+ )
+
+ try:
+ model_norm_threshold = first_state["model_norm_threshold"]
+ except KeyError:
+ return 1.0 # threshold has not yet been set.
+
+ ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
+ if ans != ans: # e.g. ans is nan
+ ans = 0.0
+ if ans < 1.0:
+ first_state["num_clipped"] += 1
+ if ans < 0.1:
+ logging.warn(
+ f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
+ )
+ if self.show_dominant_parameters:
+ assert p.shape[0] == len(param_names)
+ self._show_gradient_dominating_parameter(
+ tuples, tot_sumsq, group["scalar_lr_scale"]
+ )
+
+ if ans == 0.0:
+ for (p, state, param_names) in tuples:
+ p.grad.zero_() # get rid of infinity()
+
+ return ans
+
+ def _show_gradient_dominating_parameter(
+ self,
+ tuples: List[Tuple[Tensor, dict, List[str]]],
+ tot_sumsq: Tensor,
+ scalar_lr_scale: float,
+ ):
+ """
+ Show information of parameter which dominates tot_sumsq.
+
+ Args:
+ tuples: a list of tuples of (param, state, param_names)
+ where param is a batched set of parameters,
+ with a .grad (1st dim is batch dim)
+ and state is the state-dict where optimization parameters are kept.
+ param_names is a List[str] while each str is name for a parameter
+ in batched set of parameters "param".
+ tot_sumsq: sumsq of all parameters. Though it's could be calculated
+ from tuples, we still pass it to save some time.
+ """
+ all_sumsq_orig = {}
+ for (p, state, batch_param_names) in tuples:
+ # p is a stacked batch parameters.
+ batch_grad = p.grad
+ if p.numel() == p.shape[0]: # a batch of scalars
+ # Dummy values used by following `zip` statement.
+ batch_rms_orig = torch.full(
+ p.shape, scalar_lr_scale, device=batch_grad.device
+ )
+ else:
+ batch_rms_orig = state["param_rms"]
+ batch_sumsq_orig = (batch_grad * batch_rms_orig) ** 2
+ if batch_grad.ndim > 1:
+ # need to guard it with if-statement because sum() sums over
+ # all dims if dim == ().
+ batch_sumsq_orig = batch_sumsq_orig.sum(
+ dim=list(range(1, batch_grad.ndim))
+ )
+ for name, sumsq_orig, rms, grad in zip(
+ batch_param_names, batch_sumsq_orig, batch_rms_orig, batch_grad
+ ):
+
+ proportion_orig = sumsq_orig / tot_sumsq
+ all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
+
+ sorted_by_proportion = {
+ k: v
+ for k, v in sorted(
+ all_sumsq_orig.items(), key=lambda item: item[1][0], reverse=True
+ )
+ }
+ dominant_param_name = next(iter(sorted_by_proportion))
+ (
+ dominant_proportion,
+ dominant_sumsq,
+ dominant_rms,
+ dominant_grad,
+ ) = sorted_by_proportion[dominant_param_name]
+ logging.warn(
+ f"Parameter dominating tot_sumsq {dominant_param_name}"
+ f" with proportion {dominant_proportion:.2f},"
+ f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
+ f"={dominant_sumsq:.3e},"
+ f" grad_sumsq={(dominant_grad**2).sum():.3e},"
+ f" orig_rms_sq={(dominant_rms**2).item():.3e}"
+ )
+
+ def _step_one_batch(
+ self, group: dict, p: Tensor, state: dict, clipping_scale: float
+ ):
+ """
+ Do the step for one parameter, which is actually going to be a batch of
+ `real` parameters, with dim 0 as the batch dim.
+ Args:
+ group: dict to look up configuration values
+ p: parameter to update (actually multiple parameters stacked together
+ as a batch)
+ state: state-dict for p, to look up the optimizer state
+ """
+ lr = group["lr"]
+ size_update_period = group["size_update_period"]
+ beta1 = group["betas"][0]
+
+ grad = p.grad
+ if clipping_scale != 1.0:
+ grad *= clipping_scale
+ step = state["step"]
+ delta = state["delta"]
+
+ delta.mul_(beta1)
+ batch_size = p.shape[0]
+ numel = p.numel() // batch_size
+ if numel > 1:
+ # Update the size/scale of p, and set param_rms
+ scale_grads = state["scale_grads"]
+ scale_grads[step % size_update_period] = (p * grad).sum(
+ dim=list(range(1, p.ndim)), keepdim=True
+ )
+ if step % size_update_period == size_update_period - 1:
+ param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
+ param_rms.copy_(
+ (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
+ )
+ if step > 0:
+ # self._size_update() learns the overall scale on the
+ # parameter, by shrinking or expanding it.
+ self._size_update(group, scale_grads, p, state)
+
+ if numel == 1:
+ # For parameters with 1 element we just use regular Adam.
+ # Updates delta.
+ self._step_scalar(group, p, state)
+ else:
+ self._step(group, p, state)
+
+ state["step"] = step + 1
+
+ def _size_update(
+ self, group: dict, scale_grads: Tensor, p: Tensor, state: dict
+ ) -> None:
+ """
+ Called only where p.numel() > 1, this updates the scale of the parameter.
+ If we imagine: p = underlying_param * scale.exp(), and we are doing
+ gradient descent on underlying param and on scale, this function does the update
+ on `scale`.
+
+ Args:
+ group: dict to look up configuration values
+ scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing
+ grads w.r.t. the scales.
+ p: The parameter to update
+ state: The state-dict of p
+ """
+
+ param_rms = state["param_rms"]
+ beta1, beta2 = group["betas"]
+ size_lr = group["lr"] * group["scalar_lr_scale"]
+ param_min_rms = group["param_min_rms"]
+ param_max_rms = group["param_max_rms"]
+ eps = group["eps"]
+ step = state["step"]
+ batch_size = p.shape[0]
+
+ size_update_period = scale_grads.shape[0]
+ # correct beta2 for the size update period: we will have
+ # faster decay at this level.
+ beta2_corr = beta2**size_update_period
+
+ scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
+ scale_exp_avg_sq.mul_(beta2_corr).add_(
+ (scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
+ alpha=1 - beta2_corr,
+ ) # shape is (batch_size, 1, 1, ...)
+
+ # The 1st time we reach here is when size_step == 1.
+ size_step = (step + 1) // size_update_period
+ bias_correction2 = 1 - beta2_corr**size_step
+ # we don't bother with bias_correction1; this will help prevent divergence
+ # at the start of training.
+
+ denom = scale_exp_avg_sq.sqrt() + eps
+
+ scale_step = (
+ -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
+ )
+
+ is_too_small = param_rms < param_min_rms
+
+ # when the param gets too small, just don't shrink it any further.
+ scale_step.masked_fill_(is_too_small, 0.0)
+
+ # and ensure the parameter rms after update never exceeds param_max_rms.
+ # We have to look at the trained model for parameters at or around the
+ # param_max_rms, because sometimes they can indicate a problem with the
+ # topology or settings.
+ scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
+
+ delta = state["delta"]
+ # the factor of (1-beta1) relates to momentum.
+ delta.add_(p * scale_step, alpha=(1 - beta1))
+
+ def _step(self, group: dict, p: Tensor, state: dict):
+ """
+ This function does the core update of self.step(), in the case where the members of
+ the batch have more than 1 element.
+
+ Args:
+ group: A dict which will be used to look up configuration values
+ p: The parameter to be updated
+ grad: The grad of p
+ state: The state-dict corresponding to parameter p
+
+ This function modifies p.
+ """
+ grad = p.grad
+ lr = group["lr"]
+ beta1, beta2 = group["betas"]
+ eps = group["eps"]
+ param_min_rms = group["param_min_rms"]
+ step = state["step"]
+
+ exp_avg_sq = state["exp_avg_sq"]
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
+
+ this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
+ bias_correction2 = 1 - beta2 ** (this_step + 1)
+ if bias_correction2 < 0.99:
+ # note: not in-place.
+ exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
+
+ denom = exp_avg_sq.sqrt()
+ denom += eps
+ grad = grad / denom
+
+ alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms)
+
+ delta = state["delta"]
+ delta.add_(grad * alpha)
+ p.add_(delta)
+
+ def _step_scalar(self, group: dict, p: Tensor, state: dict):
+ """
+ A simplified form of the core update for scalar tensors, where we cannot get a good
+ estimate of the parameter rms.
+ """
+ beta1, beta2 = group["betas"]
+ scalar_max = group["scalar_max"]
+ eps = group["eps"]
+ lr = group["lr"] * group["scalar_lr_scale"]
+ grad = p.grad
+
+ exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+
+ # bias_correction2 is like in Adam. Don't bother with bias_correction1;
+ # slower update at the start will help stability anyway.
+ bias_correction2 = 1 - beta2 ** (state["step"] + 1)
+ denom = (exp_avg_sq / bias_correction2).sqrt() + eps
+
+ delta = state["delta"]
+ delta.add_(grad / denom, alpha=-lr * (1 - beta1))
+ p.clamp_(min=-scalar_max, max=scalar_max)
+ p.add_(delta)
+
+
+class LRScheduler(object):
+ """
+ Base-class for learning rate schedulers where the learning-rate depends on both the
+ batch and the epoch.
+ """
+
+ def __init__(self, optimizer: Optimizer, verbose: bool = False):
+ # Attach optimizer
+ if not isinstance(optimizer, Optimizer):
+ raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
+ self.optimizer = optimizer
+ self.verbose = verbose
+
+ for group in optimizer.param_groups:
+ group.setdefault("base_lr", group["lr"])
+
+ self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]
+
+ self.epoch = 0
+ self.batch = 0
+
+ def state_dict(self):
+ """Returns the state of the scheduler as a :class:`dict`.
+
+ It contains an entry for every variable in self.__dict__ which
+ is not the optimizer.
+ """
+ return {
+ "base_lrs": self.base_lrs,
+ "epoch": self.epoch,
+ "batch": self.batch,
+ }
+
+ def load_state_dict(self, state_dict):
+ """Loads the schedulers state.
+
+ Args:
+ state_dict (dict): scheduler state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ self.__dict__.update(state_dict)
+
+ def get_last_lr(self) -> List[float]:
+ """Return last computed learning rate by current scheduler. Will be a list of float."""
+ return self._last_lr
+
+ def get_lr(self):
+ # Compute list of learning rates from self.epoch and self.batch and
+ # self.base_lrs; this must be overloaded by the user.
+ # e.g. return [some_formula(self.batch, self.epoch, base_lr) for base_lr in self.base_lrs ]
+ raise NotImplementedError
+
+ def step_batch(self, batch: Optional[int] = None) -> None:
+ # Step the batch index, or just set it. If `batch` is specified, it
+ # must be the batch index from the start of training, i.e. summed over
+ # all epochs.
+ # You can call this in any order; if you don't provide 'batch', it should
+ # of course be called once per batch.
+ if batch is not None:
+ self.batch = batch
+ else:
+ self.batch = self.batch + 1
+ self._set_lrs()
+
+ def step_epoch(self, epoch: Optional[int] = None):
+ # Step the epoch index, or just set it. If you provide the 'epoch' arg,
+ # you should call this at the start of the epoch; if you don't provide the 'epoch'
+ # arg, you should call it at the end of the epoch.
+ if epoch is not None:
+ self.epoch = epoch
+ else:
+ self.epoch = self.epoch + 1
+ self._set_lrs()
+
+ def _set_lrs(self):
+ values = self.get_lr()
+ assert len(values) == len(self.optimizer.param_groups)
+
+ for i, data in enumerate(zip(self.optimizer.param_groups, values)):
+ param_group, lr = data
+ param_group["lr"] = lr
+ self.print_lr(self.verbose, i, lr)
+ self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
+
+ def print_lr(self, is_verbose, group, lr):
+ """Display the current learning rate."""
+ if is_verbose:
+ logging.warn(
+ f"Epoch={self.epoch}, batch={self.batch}: adjusting learning rate"
+ f" of group {group} to {lr:.4e}."
+ )
+
+
+class Eden(LRScheduler):
+ """
+ Eden scheduler.
+ The basic formula (before warmup) is:
+ lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
+ (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
+ where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
+ and then stays constant at 1.
+
+ If you don't have the concept of epochs, or one epoch takes a very long time,
+ you can replace the notion of 'epoch' with some measure of the amount of data
+ processed, e.g. hours of data or frames of data, with 'lr_epochs' being set to
+ some measure representing "quite a lot of data": say, one fifth or one third
+ of an entire training run, but it doesn't matter much. You could also use
+ Eden2 which has only the notion of batches.
+
+ We suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
+
+ Args:
+ optimizer: the optimizer to change the learning rates on
+ lr_batches: the number of batches after which we start significantly
+ decreasing the learning rate, suggest 5000.
+ lr_epochs: the number of epochs after which we start significantly
+ decreasing the learning rate, suggest 6 if you plan to do e.g.
+ 20 to 40 epochs, but may need smaller number if dataset is huge
+ and you will do few epochs.
+ """
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ lr_batches: Union[int, float],
+ lr_epochs: Union[int, float],
+ warmup_batches: Union[int, float] = 500.0,
+ warmup_start: float = 0.5,
+ verbose: bool = False,
+ ):
+ super(Eden, self).__init__(optimizer, verbose)
+ self.lr_batches = lr_batches
+ self.lr_epochs = lr_epochs
+ self.warmup_batches = warmup_batches
+
+ assert 0.0 <= warmup_start <= 1.0, warmup_start
+ self.warmup_start = warmup_start
+
+ def get_lr(self):
+ factor = (
+ (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
+ ) ** -0.25 * (
+ ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
+ )
+ warmup_factor = (
+ 1.0
+ if self.batch >= self.warmup_batches
+ else self.warmup_start
+ + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+ # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
+ )
+
+ return [x * factor * warmup_factor for x in self.base_lrs]
+
+
+class Eden2(LRScheduler):
+ """
+ Eden2 scheduler, simpler than Eden because it does not use the notion of epoch,
+ only batches.
+
+ The basic formula (before warmup) is:
+ lr = base_lr * ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.5) * warmup
+
+ where `warmup` increases from linearly 0.5 to 1 over `warmup_batches` batches
+ and then stays constant at 1.
+
+
+ E.g. suggest base_lr = 0.04 (passed to optimizer) if used with ScaledAdam
+
+ Args:
+ optimizer: the optimizer to change the learning rates on
+ lr_batches: the number of batches after which we start significantly
+ decreasing the learning rate, suggest 5000.
+ """
+
+ def __init__(
+ self,
+ optimizer: Optimizer,
+ lr_batches: Union[int, float],
+ warmup_batches: Union[int, float] = 500.0,
+ warmup_start: float = 0.5,
+ verbose: bool = False,
+ ):
+ super().__init__(optimizer, verbose)
+ self.lr_batches = lr_batches
+ self.warmup_batches = warmup_batches
+
+ assert 0.0 <= warmup_start <= 1.0, warmup_start
+ self.warmup_start = warmup_start
+
+ def get_lr(self):
+ factor = (
+ (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
+ ) ** -0.5
+ warmup_factor = (
+ 1.0
+ if self.batch >= self.warmup_batches
+ else self.warmup_start
+ + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+ # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
+ )
+
+ return [x * factor * warmup_factor for x in self.base_lrs]
+
+
+def _test_eden():
+ m = torch.nn.Linear(100, 100)
+ optim = ScaledAdam(m.parameters(), lr=0.03)
+
+ scheduler = Eden(optim, lr_batches=100, lr_epochs=2, verbose=True)
+
+ for epoch in range(10):
+ scheduler.step_epoch(epoch) # sets epoch to `epoch`
+
+ for step in range(20):
+ x = torch.randn(200, 100).detach()
+ x.requires_grad = True
+ y = m(x)
+ dy = torch.randn(200, 100).detach()
+ f = (y * dy).sum()
+ f.backward()
+
+ optim.step()
+ scheduler.step_batch()
+ optim.zero_grad()
+
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
+ logging.info(f"state dict = {scheduler.state_dict()}")
+
+
+# This is included mostly as a baseline for ScaledAdam.
+class Eve(Optimizer):
+ """
+ Implements Eve algorithm. This is a modified version of AdamW with a special
+ way of setting the weight-decay / shrinkage-factor, which is designed to make the
+ rms of the parameters approach a particular target_rms (default: 0.1). This is
+ for use with networks with 'scaled' versions of modules (see scaling.py), which
+ will be close to invariant to the absolute scale on the parameter matrix.
+
+ The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
+ The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
+ Eve is unpublished so far.
+
+ Arguments:
+ params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ lr (float, optional): learning rate (default: 1e-3)
+ betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.999))
+ eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+ weight_decay (float, optional): weight decay coefficient (default: 3e-4;
+ this value means that the weight would decay significantly after
+ about 3k minibatches. Is not multiplied by learning rate, but
+ is conditional on RMS-value of parameter being > target_rms.
+ target_rms (float, optional): target root-mean-square value of
+ parameters, if they fall below this we will stop applying weight decay.
+
+
+ .. _Adam: A Method for Stochastic Optimization:
+ https://arxiv.org/abs/1412.6980
+ .. _Decoupled Weight Decay Regularization:
+ https://arxiv.org/abs/1711.05101
+ .. _On the Convergence of Adam and Beyond:
+ https://openreview.net/forum?id=ryQu7f-RZ
+ """
+
+ def __init__(
+ self,
+ params,
+ lr=1e-3,
+ betas=(0.9, 0.98),
+ eps=1e-8,
+ weight_decay=1e-3,
+ target_rms=0.1,
+ ):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+ if not 0 <= weight_decay <= 0.1:
+ raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+ if not 0 < target_rms <= 10.0:
+ raise ValueError("Invalid target_rms value: {}".format(target_rms))
+ defaults = dict(
+ lr=lr,
+ betas=betas,
+ eps=eps,
+ weight_decay=weight_decay,
+ target_rms=target_rms,
+ )
+ super(Eve, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(Eve, self).__setstate__(state)
+
+ @torch.no_grad()
+ def step(self, closure=None):
+ """Performs a single optimization step.
+
+ Arguments:
+ closure (callable, optional): A closure that reevaluates the model
+ and returns the loss.
+ """
+ loss = None
+ if closure is not None:
+ with torch.enable_grad():
+ loss = closure()
+
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.grad is None:
+ continue
+
+ # Perform optimization step
+ grad = p.grad
+ if grad.is_sparse:
+ raise RuntimeError("AdamW does not support sparse gradients")
+
+ state = self.state[p]
+
+ # State initialization
+ if len(state) == 0:
+ state["step"] = 0
+ # Exponential moving average of gradient values
+ state["exp_avg"] = torch.zeros_like(
+ p, memory_format=torch.preserve_format
+ )
+ # Exponential moving average of squared gradient values
+ state["exp_avg_sq"] = torch.zeros_like(
+ p, memory_format=torch.preserve_format
+ )
+
+ exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
+
+ beta1, beta2 = group["betas"]
+
+ state["step"] += 1
+ bias_correction1 = 1 - beta1 ** state["step"]
+ bias_correction2 = 1 - beta2 ** state["step"]
+
+ # Decay the first and second moment running average coefficient
+ exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_(
+ group["eps"]
+ )
+
+ step_size = group["lr"] / bias_correction1
+ target_rms = group["target_rms"]
+ weight_decay = group["weight_decay"]
+
+ if p.numel() > 1:
+ # avoid applying this weight-decay on "scaling factors"
+ # (which are scalar).
+ is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5))
+ p.mul_(1 - (weight_decay * is_above_target_rms))
+
+ p.addcdiv_(exp_avg, denom, value=-step_size)
+
+ if random.random() < 0.0005:
+ step = (exp_avg / denom) * step_size
+ logging.info(
+ f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}"
+ )
+
+ return loss
+
+
+def _test_scaled_adam(hidden_dim: int):
+ import timeit
+
+ from scaling import ScaledLinear
+
+ E = 100
+ B = 4
+ T = 2
+ logging.info("in test_eve_cain")
+ # device = torch.device('cuda')
+ device = torch.device("cpu")
+ dtype = torch.float32
+
+ fix_random_seed(42)
+ # these input_magnitudes and output_magnitudes are to test that
+ # Abel is working as we expect and is able to adjust scales of
+ # different dims differently.
+ input_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+ output_magnitudes = (1.0 * torch.randn(E, dtype=dtype, device=device)).exp()
+
+ for iter in [1, 0]:
+ fix_random_seed(42)
+ Linear = torch.nn.Linear if iter == 0 else ScaledLinear
+
+ m = torch.nn.Sequential(
+ Linear(E, hidden_dim),
+ torch.nn.PReLU(),
+ Linear(hidden_dim, hidden_dim),
+ torch.nn.PReLU(),
+ Linear(hidden_dim, E),
+ ).to(device)
+
+ train_pairs = [
+ (
+ 100.0
+ * torch.randn(B, T, E, device=device, dtype=dtype)
+ * input_magnitudes,
+ torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes,
+ )
+ for _ in range(20)
+ ]
+
+ if iter == 0:
+ optim = Eve(m.parameters(), lr=0.003)
+ elif iter == 1:
+ optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0)
+ scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False)
+
+ start = timeit.default_timer()
+ avg_loss = 0.0
+ for epoch in range(180):
+ scheduler.step_epoch()
+ # if epoch == 100 and iter in [2,3]:
+ # optim.reset_speedup() # check it doesn't crash.
+
+ # if epoch == 130:
+ # opts = diagnostics.TensorDiagnosticOptions(
+ # 512
+ # ) # allow 4 megabytes per sub-module
+ # diagnostic = diagnostics.attach_diagnostics(m, opts)
+
+ for n, (x, y) in enumerate(train_pairs):
+ y_out = m(x)
+ loss = ((y_out - y) ** 2).mean() * 100.0
+ if epoch == 0 and n == 0:
+ avg_loss = loss.item()
+ else:
+ avg_loss = 0.98 * avg_loss + 0.02 * loss.item()
+ if n == 0 and epoch % 5 == 0:
+ # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item()
+ # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item()
+ # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item()
+ # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item()
+ # scale1 = '%.2e' % (m[0].weight_scale.exp().item())
+ # scale1b = '%.2e' % (m[0].bias_scale.exp().item())
+ # scale2 = '%.2e' % (m[2].weight_scale.exp().item())
+ # scale2b = '%.2e' % (m[2].bias_scale.exp().item())
+ lr = scheduler.get_last_lr()[0]
+ logging.info(
+ f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}"
+ ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b}
+ loss.log().backward()
+ optim.step()
+ optim.zero_grad()
+ scheduler.step_batch()
+
+ # diagnostic.print_diagnostics()
+
+ stop = timeit.default_timer()
+ logging.info(f"Iter={iter}, Time taken: {stop - start}")
+
+ logging.info(f"last lr = {scheduler.get_last_lr()}")
+ # logging.info("state dict = ", scheduler.state_dict())
+ # logging.info("optim state_dict = ", optim.state_dict())
+ logging.info(f"input_magnitudes = {input_magnitudes}")
+ logging.info(f"output_magnitudes = {output_magnitudes}")
+
+
+if __name__ == "__main__":
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ logging.getLogger().setLevel(logging.INFO)
+ import subprocess
+
+ s = subprocess.check_output(
+ "git status -uno .; git log -1; git diff HEAD .", shell=True
+ )
+ logging.info(s)
+ import sys
+
+ if len(sys.argv) > 1:
+ hidden_dim = int(sys.argv[1])
+ else:
+ hidden_dim = 200
+
+ _test_scaled_adam(hidden_dim)
+ _test_eden()
diff --git a/egs/wenetspeech/ASR/whisper/train.py b/egs/wenetspeech/ASR/whisper/train.py
new file mode 100644
index 000000000..07de35dd5
--- /dev/null
+++ b/egs/wenetspeech/ASR/whisper/train.py
@@ -0,0 +1,924 @@
+#!/usr/bin/env python3
+# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+# 2024 Yuekai Zhang
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+Usage:
+
+#fine-tuning with deepspeed zero stage 1
+torchrun --nproc-per-node 8 ./whisper/train.py \
+ --max-duration 200 \
+ --exp-dir whisper/exp_large_v2 \
+ --model-name large-v2 \
+ --deepspeed \
+ --deepspeed_config ./whisper/ds_config_zero1.json
+
+# fine-tuning with ddp
+torchrun --nproc-per-node 8 ./whisper/train.py \
+ --max-duration 200 \
+ --exp-dir whisper/exp_medium \
+ --base-lr 1e-5 \
+ --model-name medium
+"""
+
+
+import argparse
+import copy
+import logging
+import random
+import warnings
+from pathlib import Path
+from shutil import copyfile
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import deepspeed
+import k2
+import optim
+import torch
+import torch.multiprocessing as mp
+import torch.nn as nn
+import whisper
+from asr_datamodule import WenetSpeechAsrDataModule
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from label_smoothing import LabelSmoothingLoss
+from lhotse import CutSet, load_manifest
+from lhotse.cut import Cut
+from lhotse.dataset.sampling.base import CutSampler
+from lhotse.utils import fix_random_seed
+from optim import Eden, ScaledAdam
+from torch import Tensor
+from torch.cuda.amp import GradScaler
+from torch.nn.functional import pad as pad_tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.tensorboard import SummaryWriter
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
+
+from icefall import diagnostics
+from icefall.checkpoint import load_checkpoint, remove_checkpoints
+from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
+from icefall.checkpoint import update_averaged_model
+from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
+from icefall.env import get_env_info
+from icefall.hooks import register_inf_check_hooks
+from icefall.utils import (
+ AttributeDict,
+ MetricsTracker,
+ filter_uneven_sized_batch,
+ setup_logger,
+ str2bool,
+)
+
+LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
+
+
+def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+ for module in model.modules():
+ if hasattr(module, "batch_count"):
+ module.batch_count = batch_count
+
+
+def get_parser():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+
+ parser.add_argument(
+ "--tensorboard",
+ type=str2bool,
+ default=True,
+ help="Should various information be logged in tensorboard.",
+ )
+
+ parser.add_argument(
+ "--num-epochs",
+ type=int,
+ default=10,
+ help="Number of epochs to train.",
+ )
+
+ parser.add_argument(
+ "--start-epoch",
+ type=int,
+ default=1,
+ help="""Resume training from this epoch. It should be positive.
+ If larger than 1, it will load checkpoint from
+ exp-dir/epoch-{start_epoch-1}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--start-batch",
+ type=int,
+ default=0,
+ help="""If positive, --start-epoch is ignored and
+ it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt
+ """,
+ )
+
+ parser.add_argument(
+ "--exp-dir",
+ type=str,
+ default="pruned_transducer_stateless7/exp",
+ help="""The experiment dir.
+ It specifies the directory where all training related
+ files, e.g., checkpoints, log, etc, are saved
+ """,
+ )
+
+ parser.add_argument(
+ "--model-name",
+ type=str,
+ default="large-v2",
+ choices=["large-v2", "large-v3", "medium", "small", "tiny"],
+ help="""The model name to use.
+ """,
+ )
+
+ parser.add_argument(
+ "--base-lr", type=float, default=1e-5, help="The base learning rate."
+ )
+
+ parser.add_argument(
+ "--lr-batches",
+ type=float,
+ default=5000,
+ help="""Number of steps that affects how rapidly the learning rate
+ decreases. We suggest not to change this.""",
+ )
+
+ parser.add_argument(
+ "--lr-epochs",
+ type=float,
+ default=6,
+ help="""Number of epochs that affects how rapidly the learning rate decreases.
+ """,
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="The seed for random generators intended for reproducibility",
+ )
+
+ parser.add_argument(
+ "--print-diagnostics",
+ type=str2bool,
+ default=False,
+ help="Accumulate stats on activations, print them and exit.",
+ )
+
+ parser.add_argument(
+ "--inf-check",
+ type=str2bool,
+ default=False,
+ help="Add hooks to check for infinite module outputs and gradients.",
+ )
+
+ parser.add_argument(
+ "--keep-last-k",
+ type=int,
+ default=30,
+ help="""Only keep this number of checkpoints on disk.
+ For instance, if it is 3, there are only 3 checkpoints
+ in the exp-dir with filenames `checkpoint-xxx.pt`.
+ It does not affect checkpoints with name `epoch-xxx.pt`.
+ """,
+ )
+
+ parser.add_argument(
+ "--average-period",
+ type=int,
+ default=200,
+ help="""Update the averaged model, namely `model_avg`, after processing
+ this number of batches. `model_avg` is a separate version of model,
+ in which each floating-point parameter is the average of all the
+ parameters from the start of training. Each time we take the average,
+ we do: `model_avg = model * (average_period / batch_idx_train) +
+ model_avg * ((batch_idx_train - average_period) / batch_idx_train)`.
+ """,
+ )
+
+ parser.add_argument(
+ "--use-fp16",
+ type=str2bool,
+ default=True,
+ help="Whether to use half precision training.",
+ )
+
+ parser = deepspeed.add_config_arguments(parser)
+
+ return parser
+
+
+def get_params() -> AttributeDict:
+ """Return a dict containing training parameters.
+
+ All training related parameters that are not passed from the commandline
+ are saved in the variable `params`.
+
+ Commandline options are merged into `params` after they are parsed, so
+ you can also access them via `params`.
+
+ Explanation of options saved in `params`:
+
+ - frame_shift_ms: The frame shift in milliseconds.
+ - allowed_excess_duration_ratio: The allowed excess duration ratio.
+ - best_train_loss: The best training loss so far.
+ - best_valid_loss: The best validation loss so far.
+ - best_train_epoch: The epoch where the best training loss is achieved.
+ - best_valid_epoch: The epoch where the best validation loss is achieved.
+ - batch_idx_train: The batch index of the current batch.
+ - log_interval: Log training stats every `log_interval` batches.
+ - reset_interval: Reset the stats every `reset_interval` batches.
+ - valid_interval: Run validation every `valid_interval` batches.
+ - env_info: The environment information.
+ """
+ params = AttributeDict(
+ {
+ "frame_shift_ms": 10.0,
+ "allowed_excess_duration_ratio": 0.1,
+ "best_train_loss": float("inf"),
+ "best_valid_loss": float("inf"),
+ "best_train_epoch": -1,
+ "best_valid_epoch": -1,
+ "batch_idx_train": 0,
+ "log_interval": 50,
+ "reset_interval": 200,
+ "valid_interval": 5000,
+ "env_info": get_env_info(),
+ }
+ )
+
+ return params
+
+
+def load_checkpoint_if_available(
+ params: AttributeDict,
+ model: nn.Module,
+ model_avg: nn.Module = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+) -> Optional[Dict[str, Any]]:
+ """Load checkpoint from file.
+
+ If params.start_batch is positive, it will load the checkpoint from
+ `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if
+ params.start_epoch is larger than 1, it will load the checkpoint from
+ `params.start_epoch - 1`.
+
+ Apart from loading state dict for `model` and `optimizer` it also updates
+ `best_train_epoch`, `best_train_loss`, `best_valid_epoch`,
+ and `best_valid_loss` in `params`.
+
+ Args:
+ params:
+ The return value of :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer that we are using.
+ scheduler:
+ The scheduler that we are using.
+ Returns:
+ Return a dict containing previously saved training info.
+ """
+ if params.start_batch > 0:
+ filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt"
+ elif params.start_epoch > 1:
+ filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
+ else:
+ return None
+
+ assert filename.is_file(), f"{filename} does not exist!"
+
+ saved_params = load_checkpoint(
+ filename,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ )
+
+ keys = [
+ "best_train_epoch",
+ "best_valid_epoch",
+ "batch_idx_train",
+ "best_train_loss",
+ "best_valid_loss",
+ ]
+ for k in keys:
+ params[k] = saved_params[k]
+
+ if params.start_batch > 0:
+ if "cur_epoch" in saved_params:
+ params["start_epoch"] = saved_params["cur_epoch"]
+
+ return saved_params
+
+
+def save_checkpoint(
+ params: AttributeDict,
+ model: Union[nn.Module, DDP],
+ model_avg: Optional[nn.Module] = None,
+ optimizer: Optional[torch.optim.Optimizer] = None,
+ scheduler: Optional[LRSchedulerType] = None,
+ sampler: Optional[CutSampler] = None,
+ scaler: Optional[GradScaler] = None,
+ rank: int = 0,
+) -> None:
+ """Save model, optimizer, scheduler and training stats to file.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The training model.
+ model_avg:
+ The stored model averaged from the start of training.
+ optimizer:
+ The optimizer used in the training.
+ sampler:
+ The sampler for the training dataset.
+ scaler:
+ The scaler used for mix precision training.
+ """
+ if rank != 0:
+ return
+ filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt"
+ save_checkpoint_impl(
+ filename=filename,
+ model=model,
+ model_avg=model_avg,
+ params=params,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ if params.best_train_epoch == params.cur_epoch:
+ best_train_filename = params.exp_dir / "best-train-loss.pt"
+ copyfile(src=filename, dst=best_train_filename)
+
+ if params.best_valid_epoch == params.cur_epoch:
+ best_valid_filename = params.exp_dir / "best-valid-loss.pt"
+ copyfile(src=filename, dst=best_valid_filename)
+
+
+def compute_loss(
+ params: AttributeDict,
+ tokenizer: whisper.tokenizer.Tokenizer,
+ model: Union[nn.Module, DDP],
+ batch: dict,
+ is_training: bool,
+) -> Tuple[Tensor, MetricsTracker]:
+ """
+ Compute the loss for the given batch.
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ tokenizer:
+ The tokenizer used to encode the text.
+ model:
+ The model for training.
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ is_training:
+ Whether it is training.
+ Returns:
+ Return a tuple of two elements. The first element is the loss tensor.
+ """
+ # For the uneven-sized batch, the total duration after padding would possibly
+ # cause OOM. Hence, for each batch, which is sorted descendingly by length,
+ # we simply drop the last few shortest samples, so that the retained total frames
+ # (after padding) would not exceed `allowed_max_frames`:
+ # `allowed_max_frames = int(max_frames * (1.0 + allowed_excess_duration_ratio))`,
+ # where `max_frames = max_duration * 1000 // frame_shift_ms`.
+ # We set allowed_excess_duration_ratio=0.1.
+ if isinstance(model, DDP):
+ # get underlying nn.Module
+ model = model.module
+
+ def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
+ padding_size = max(tensor.shape[0] for tensor in tensors)
+ dims = len(tensors[0].shape)
+ padded_tensors = []
+ for tensor in tensors:
+ padding = [0] * 2 * dims
+ padding[-1] = padding_size - tensor.shape[0]
+ padded_tensors.append(pad_tensor(tensor, padding, "constant", pad_value))
+ return torch.stack([tensor for tensor in padded_tensors], dim=0)
+
+ max_frames = params.max_duration * 1000 // params.frame_shift_ms
+ allowed_max_frames = int(max_frames * (1.0 + params.allowed_excess_duration_ratio))
+ batch = filter_uneven_sized_batch(batch, allowed_max_frames)
+
+ device = model.device if isinstance(model, DDP) else next(model.parameters()).device
+ feature = batch["inputs"]
+
+ assert feature.ndim == 3
+ feature = feature.to(device)
+ feature = feature.transpose(1, 2) # (N, C, T)
+
+ supervisions = batch["supervisions"]
+ feature_lens = supervisions["num_frames"].to(device)
+
+ batch_idx_train = params.batch_idx_train
+
+ texts = batch["supervisions"]["text"]
+ # remove spaces in texts
+ texts = [text.replace(" ", "") for text in texts]
+
+ text_tokens_list = [
+ list(tokenizer.sot_sequence_including_notimestamps)
+ + tokenizer.encode(text)
+ + [tokenizer.eot]
+ for text in texts
+ ]
+ # convert it to torch tensor
+ text_tokens_list = [
+ torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
+ ]
+
+ # 50256 is the index of for all whisper models
+ prev_outputs_tokens = _batch_tensors(
+ [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
+ )
+ target_tokens = _batch_tensors(
+ [tokens[1:] for tokens in text_tokens_list], pad_value=50256
+ )
+ target_lengths = torch.LongTensor(
+ [tokens.shape[0] - 1 for tokens in text_tokens_list]
+ )
+
+ decoder_criterion = LabelSmoothingLoss(
+ ignore_index=50256, label_smoothing=0.1, reduction="sum"
+ )
+
+ # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|>
+ ignore_prefix_size = 3
+ with torch.set_grad_enabled(is_training):
+ encoder_out = model.encoder(feature)
+ text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
+ text_logits = text_logits[:, ignore_prefix_size:, :]
+ target_tokens = target_tokens[:, ignore_prefix_size:]
+ loss = decoder_criterion(text_logits, target_tokens.to(device))
+
+ assert loss.requires_grad == is_training
+
+ info = MetricsTracker()
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ info["frames"] = (feature_lens // params.subsampling_factor).sum().item()
+
+ # Note: We use reduction=sum while computing the loss.
+ info["loss"] = loss.detach().cpu().item()
+
+ return loss, info
+
+
+def compute_validation_loss(
+ params: AttributeDict,
+ tokenizer: whisper.tokenizer.Tokenizer,
+ model: Union[nn.Module, DDP],
+ valid_dl: torch.utils.data.DataLoader,
+ world_size: int = 1,
+) -> MetricsTracker:
+ """Run the validation process."""
+ model.eval()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(valid_dl):
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ batch=batch,
+ is_training=False,
+ )
+ assert loss.requires_grad is False
+ tot_loss = tot_loss + loss_info
+
+ if world_size > 1:
+ tot_loss.reduce(loss.device)
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ if loss_value < params.best_valid_loss:
+ params.best_valid_epoch = params.cur_epoch
+ params.best_valid_loss = loss_value
+
+ return tot_loss
+
+
+def train_one_epoch(
+ params: AttributeDict,
+ tokenizer: whisper.tokenizer.Tokenizer,
+ model: Union[nn.Module, DDP],
+ optimizer: torch.optim.Optimizer,
+ scheduler: LRSchedulerType,
+ train_dl: torch.utils.data.DataLoader,
+ valid_dl: torch.utils.data.DataLoader,
+ scaler: GradScaler,
+ model_avg: Optional[nn.Module] = None,
+ tb_writer: Optional[SummaryWriter] = None,
+ world_size: int = 1,
+ rank: int = 0,
+) -> None:
+ """Train the model for one epoch.
+
+ The training loss from the mean of all frames is saved in
+ `params.train_loss`. It runs the validation process every
+ `params.valid_interval` batches.
+
+ Args:
+ params:
+ It is returned by :func:`get_params`.
+ model:
+ The model for training.
+ optimizer:
+ The optimizer we are using.
+ scheduler:
+ The learning rate scheduler, we call step() every step.
+ train_dl:
+ Dataloader for the training dataset.
+ valid_dl:
+ Dataloader for the validation dataset.
+ scaler:
+ The scaler used for mix precision training.
+ model_avg:
+ The stored model averaged from the start of training.
+ tb_writer:
+ Writer to write log messages to tensorboard.
+ world_size:
+ Number of nodes in DDP training. If it is 1, DDP is disabled.
+ rank:
+ The rank of the node in DDP training. If no DDP is used, it should
+ be set to 0.
+ """
+ model.train()
+
+ tot_loss = MetricsTracker()
+
+ for batch_idx, batch in enumerate(train_dl):
+ params.batch_idx_train += 1
+ batch_size = len(batch["supervisions"]["text"])
+ if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+ logging.info("Computing validation loss")
+ valid_info = compute_validation_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ valid_dl=valid_dl,
+ world_size=world_size,
+ )
+ model.train()
+ logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+ logging.info(
+ f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+ )
+ if tb_writer is not None:
+ valid_info.write_summary(
+ tb_writer, "train/valid_", params.batch_idx_train
+ )
+
+ try:
+ with torch.cuda.amp.autocast(enabled=params.use_fp16):
+ loss, loss_info = compute_loss(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ batch=batch,
+ is_training=True,
+ )
+ # summary stats
+ tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
+
+ # NOTE: We use reduction==sum and loss is computed over utterances
+ # in the batch and there is no normalization to it so far.
+ if params.deepspeed:
+ # deepspeed's backward() is different from torch's backward()
+ # in that it does not accept a loss tensor as input.
+ # It computes the loss internally.
+ model.backward(loss)
+ model.step()
+ else:
+ scaler.scale(loss).backward()
+ set_batch_count(model, params.batch_idx_train)
+ scheduler.step_batch(params.batch_idx_train)
+
+ scaler.step(optimizer)
+ scaler.update()
+ optimizer.zero_grad()
+ except: # noqa
+ display_and_save_batch(batch, params=params)
+ raise
+
+ if params.print_diagnostics and batch_idx == 5:
+ return
+
+ if (
+ rank == 0
+ and params.batch_idx_train > 0
+ and params.batch_idx_train % params.average_period == 0
+ and not params.deepspeed
+ ):
+ update_averaged_model(
+ params=params,
+ model_cur=model,
+ model_avg=model_avg,
+ )
+
+ if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed:
+ # If the grad scale was less than 1, try increasing it. The _growth_interval
+ # of the grad scaler is configurable, but we can't configure it to have different
+ # behavior depending on the current grad scale.
+ cur_grad_scale = scaler._scale.item()
+ if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0):
+ scaler.update(cur_grad_scale * 2.0)
+ if cur_grad_scale < 0.01:
+ logging.warning(f"Grad scale is small: {cur_grad_scale}")
+ if cur_grad_scale < 1.0e-05:
+ raise RuntimeError(
+ f"grad_scale is too small, exiting: {cur_grad_scale}"
+ )
+ if batch_idx % params.log_interval == 0:
+ try:
+ cur_lr = scheduler.get_last_lr()[0]
+ except: # noqa
+ cur_lr = 0.0
+ cur_grad_scale = (
+ scaler._scale.item()
+ if (params.use_fp16 and not params.deepspeed)
+ else 1.0
+ )
+
+ logging.info(
+ f"Epoch {params.cur_epoch}, "
+ f"batch {batch_idx}, loss[{loss_info}], "
+ f"tot_loss[{tot_loss}], batch size: {batch_size}, "
+ f"lr: {cur_lr:.2e}, "
+ + (
+ f"grad_scale: {scaler._scale.item()}"
+ if (params.use_fp16 and not params.deepspeed)
+ else ""
+ )
+ )
+
+ if tb_writer is not None:
+ tb_writer.add_scalar(
+ "train/learning_rate", cur_lr, params.batch_idx_train
+ )
+
+ loss_info.write_summary(
+ tb_writer, "train/current_", params.batch_idx_train
+ )
+ tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
+ if params.use_fp16:
+ tb_writer.add_scalar(
+ "train/grad_scale",
+ cur_grad_scale,
+ params.batch_idx_train,
+ )
+
+ loss_value = tot_loss["loss"] / tot_loss["frames"]
+ params.train_loss = loss_value
+ if params.train_loss < params.best_train_loss:
+ params.best_train_epoch = params.cur_epoch
+ params.best_train_loss = params.train_loss
+
+
+def run(rank, world_size, args):
+ """
+ Args:
+ rank:
+ It is a value between 0 and `world_size-1`, which is
+ passed automatically by `mp.spawn()` in :func:`main`.
+ The node with rank 0 is responsible for saving checkpoint.
+ world_size:
+ Number of GPUs for DDP training.
+ args:
+ The return value of get_parser().parse_args()
+ """
+ params = get_params()
+ params.update(vars(args))
+
+ fix_random_seed(params.seed)
+
+ setup_logger(f"{params.exp_dir}/log/log-train")
+ logging.info(params)
+
+ logging.info("About to create model")
+
+ replace_whisper_encoder_forward()
+ model = whisper.load_model(params.model_name, "cpu")
+ del model.alignment_heads
+ num_param = sum([p.numel() for p in model.parameters()])
+ logging.info(f"Number of model parameters: {num_param}")
+
+ tokenizer = whisper.tokenizer.get_tokenizer(
+ model.is_multilingual,
+ num_languages=model.num_languages,
+ language="zh",
+ task="transcribe",
+ )
+
+ model_avg: Optional[nn.Module] = None
+ if rank == 0:
+ # model_avg is only used with rank 0
+ model_avg = copy.deepcopy(model).to(torch.float64)
+
+ assert params.start_epoch > 0, params.start_epoch
+ checkpoints = load_checkpoint_if_available(
+ params=params, model=model, model_avg=model_avg
+ )
+
+ if torch.cuda.is_available():
+ device = torch.device("cuda", rank)
+ else:
+ device = torch.device("cpu")
+ logging.info(f"Device: {device}")
+ model.to(device)
+
+ optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
+ scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+
+ if checkpoints and "optimizer" in checkpoints:
+ logging.info("Loading optimizer state dict")
+ optimizer.load_state_dict(checkpoints["optimizer"])
+
+ if (
+ checkpoints
+ and "scheduler" in checkpoints
+ and checkpoints["scheduler"] is not None
+ ):
+ logging.info("Loading scheduler state dict")
+ scheduler.load_state_dict(checkpoints["scheduler"])
+
+ if world_size > 1:
+ if params.deepspeed:
+ logging.info("Using DeepSpeed")
+ model, optimizer, _, scheduler = deepspeed.initialize(
+ args=params, model=model, model_parameters=model.parameters()
+ )
+ else:
+ logging.info("Using DDP")
+ setup_dist(use_ddp_launch=True)
+ model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
+ if params.print_diagnostics:
+ opts = diagnostics.TensorDiagnosticOptions(
+ 2**22
+ ) # allow 4 megabytes per sub-module
+ diagnostic = diagnostics.attach_diagnostics(model, opts)
+
+ if params.inf_check:
+ register_inf_check_hooks(model)
+
+ wenetspeech = WenetSpeechAsrDataModule(args)
+
+ if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
+ # We only load the sampler's state dict when it loads a checkpoint
+ # saved in the middle of an epoch
+ sampler_state_dict = checkpoints["sampler"]
+ else:
+ sampler_state_dict = None
+
+ train_dl = wenetspeech.train_dataloaders(wenetspeech.train_cuts())
+ valid_dl = wenetspeech.valid_dataloaders(wenetspeech.valid_cuts())
+
+ scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
+ if checkpoints and "grad_scaler" in checkpoints:
+ logging.info("Loading grad scaler state dict")
+ scaler.load_state_dict(checkpoints["grad_scaler"])
+
+ if args.tensorboard and rank == 0:
+ tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+ else:
+ tb_writer = None
+
+ logging.info(f"start training from epoch {params.start_epoch}")
+ for epoch in range(params.start_epoch, params.num_epochs + 1):
+ if not params.deepspeed:
+ scheduler.step_epoch(epoch - 1)
+ fix_random_seed(params.seed + epoch - 1)
+ train_dl.sampler.set_epoch(epoch - 1)
+
+ if tb_writer is not None:
+ tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train)
+
+ params.cur_epoch = epoch
+
+ train_one_epoch(
+ params=params,
+ tokenizer=tokenizer,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ train_dl=train_dl,
+ valid_dl=valid_dl,
+ scaler=scaler,
+ tb_writer=tb_writer,
+ world_size=world_size,
+ rank=rank,
+ )
+
+ if params.print_diagnostics:
+ diagnostic.print_diagnostics()
+ break
+
+ if params.deepspeed:
+ model.save_checkpoint(
+ save_dir=params.exp_dir,
+ tag=f"epoch-{params.cur_epoch}",
+ client_state={},
+ )
+ if rank == 0:
+ convert_zero_checkpoint_to_fp32_state_dict(
+ params.exp_dir,
+ f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
+ tag=f"epoch-{params.cur_epoch}",
+ )
+ else:
+ save_checkpoint(
+ params=params,
+ model=model,
+ model_avg=model_avg,
+ optimizer=optimizer,
+ scheduler=scheduler,
+ sampler=train_dl.sampler,
+ scaler=scaler,
+ rank=rank,
+ )
+
+ logging.info("Done!")
+
+ if world_size > 1 and not params.deepspeed:
+ torch.distributed.barrier()
+ cleanup_dist()
+
+
+def display_and_save_batch(
+ batch: dict,
+ params: AttributeDict,
+) -> None:
+ """Display the batch statistics and save the batch into disk.
+
+ Args:
+ batch:
+ A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+ for the content in it.
+ params:
+ Parameters for training. See :func:`get_params`.
+ """
+ from lhotse.utils import uuid4
+
+ filename = f"{params.exp_dir}/batch-{uuid4()}.pt"
+ logging.info(f"Saving batch to {filename}")
+ torch.save(batch, filename)
+
+ supervisions = batch["supervisions"]
+ features = batch["inputs"]
+
+ logging.info(f"features shape: {features.shape}")
+
+
+def main():
+ parser = get_parser()
+ AishellAsrDataModule.add_arguments(parser)
+ args = parser.parse_args()
+ args.exp_dir = Path(args.exp_dir)
+
+ world_size = get_world_size()
+ rank = get_rank()
+
+ torch.set_num_threads(1)
+ torch.set_num_interop_threads(1)
+ run(rank=rank, world_size=world_size, args=args)
+
+
+if __name__ == "__main__":
+ main()