root 2024-01-22 08:10:26 +00:00
parent b623c3be15
commit 8d9ab308af
10 changed files with 257 additions and 229 deletions

View File

@@ -24,3 +24,10 @@ The following table lists the differences among them.
 The decoder in `transducer_stateless` is modified from the paper
 [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
 We place an additional Conv1d layer right after the input embedding layer.
+
+# Whisper
+
+Recipe to fine-tune large pretrained models.
+| | Encoder | Decoder | Comment |
+|-----------|-------------|-------------|--------------------------------------|
+| `whisper` | Transformer | Transformer | supports fine-tuning using DeepSpeed |
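For readers skimming past the `transducer_stateless` note above: the "stateless" decoder replaces the recurrent prediction network with an embedding followed by a Conv1d over a small, fixed label context. Below is a minimal sketch of that idea; all names, dimensions, and the context size are illustrative, not the recipe's actual code.

```python
# Illustrative sketch of a "stateless" transducer decoder: embedding + Conv1d.
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int = 512, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # A Conv1d over the previous `context_size` labels stands in for RNN state.
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (batch, num_labels) -> (batch, num_labels - context_size + 1, embed_dim)
        emb = self.embedding(y).permute(0, 2, 1)
        return self.conv(emb).permute(0, 2, 1)
```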

View File

@@ -77,7 +77,7 @@ It's reworked Zipformer with Pruned RNNT loss.
 Command for training is:

 ```bash
 ./prepare.sh

 export CUDA_VISIBLE_DEVICES="0,1"
@@ -142,7 +142,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
   --feedforward-dim 512,768,768,768,768,768 \
   --encoder-dim 192,256,256,256,256,256 \
   --encoder-unmasked-dim 192,192,192,192,192,192 \
   --max-duration 1200
 ```

 Command for decoding is:
@@ -192,7 +192,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
   --feedforward-dim 512,768,1536,2048,1536,768 \
   --encoder-dim 192,256,512,768,512,256 \
   --encoder-unmasked-dim 192,192,256,320,256,192 \
   --max-duration 800
 ```

 Command for decoding is:
@@ -208,7 +208,7 @@ for m in greedy_search modified_beam_search fast_beam_search ; do
     --num-encoder-layers 2,2,4,5,4,2 \
     --feedforward-dim 512,768,1536,2048,1536,768 \
     --encoder-dim 192,256,512,768,512,256 \
     --encoder-unmasked-dim 192,192,256,320,256,192
 done
 ```

View File

@@ -29,7 +29,14 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor, str2bool
@@ -42,7 +49,9 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)

-def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
+def compute_fbank_aishell(
+    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
@@ -69,7 +78,9 @@ def compute_fbank_aishell(
         dataset_parts,
     )
     if whisper_fbank:
-        extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
     else:
         extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@@ -84,7 +95,7 @@ def compute_fbank_aishell(
                 supervisions=m["supervisions"],
             )
             if "train" in partition and perturb_speed:
-                logging.info(f"Doing speed perturb")
+                logging.info("Doing speed perturb")
                 cut_set = (
                     cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                 )
@@ -129,5 +140,7 @@ if __name__ == "__main__":
     args = get_args()
     compute_fbank_aishell(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
     )

View File

@@ -387,4 +387,4 @@ if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
     ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
     touch data/fbank/.aishell.whisper.done
   fi
 fi

View File

@@ -2,6 +2,7 @@
 # Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
 #                                            Fangjun Kuang,
 #                                            Wei Kang)
+#           2024 Yuekai Zhang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -42,44 +43,37 @@ python3 ./whisper/decode.py \
 import argparse
 import logging
+import re
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

-import whisper
-from whisper.normalizers import BasicTextNormalizer
 import k2
 import torch
 import torch.nn as nn
+import whisper
 from asr_datamodule import AishellAsrDataModule
-#from model import load_model
+from tn.chinese.normalizer import Normalizer
+from whisper.normalizers import BasicTextNormalizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
-from icefall.decode import (
-    get_lattice,
-    nbest_decoding,
-    nbest_oracle,
-    one_best_decoding,
-    rescore_with_attention_decoder,
-)
+from zhconv import convert
+
+from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
 from icefall.env import get_env_info
-from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    get_texts,
     setup_logger,
     store_transcripts,
-    write_error_stats,
     str2bool,
+    write_error_stats,
 )
-from zhconv import convert
-from tn.chinese.normalizer import Normalizer
-import re
 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
 ) -> dict:
     """Average a list of checkpoints.
+    The function is mainly used for averaging deepspeed-converted checkpoints,
+    which contain only the model state_dict.

     Args:
       filenames:
@@ -94,9 +88,9 @@ def average_checkpoints(
     n = len(filenames)

     if "model" in torch.load(filenames[0], map_location=device):
         avg = torch.load(filenames[0], map_location=device)["model"]
     else:
         avg = torch.load(filenames[0], map_location=device)

     # Identify shared parameters. Two parameters are said to be shared
     # if they have the same data_ptr
@@ -112,9 +106,9 @@ def average_checkpoints(
     for i in range(1, n):
         if "model" in torch.load(filenames[i], map_location=device):
             state_dict = torch.load(filenames[i], map_location=device)["model"]
         else:
             state_dict = torch.load(filenames[i], map_location=device)
         for k in uniqued_names:
             avg[k] += state_dict[k]
@@ -126,33 +120,48 @@ def average_checkpoints(
     return avg

 def remove_punctuation(text: str or List[str]):
-    # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
-    punctuation = '!,.;:?、!,。;:?《》 '
+    """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
+
+    Args:
+      text: It can be a string or a list of strings.
+    Returns:
+      Return a string or a list of strings without any punctuation.
+    """
+    punctuation = "!,.;:?、!,。;:?《》 "
     if isinstance(text, str):
-        text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
+        text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
         return text
     elif isinstance(text, list):
         result_text = []
         for t in text:
-            t = re.sub(r'[{}]+'.format(punctuation), '', t).strip()
+            t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
             result_text.append(t)
         return result_text
     else:
-        raise Exception(f'Not support type {type(text)}')
+        raise Exception(f"Unsupported type {type(text)}")

 def to_simple(text: str or List[str]):
+    """Convert traditional Chinese to simplified Chinese.
+
+    Args:
+      text: It can be a string or a list of strings.
+    Returns:
+      Return a string or a list of strings converted to simplified Chinese.
+    """
     if isinstance(text, str):
-        text = convert(text, 'zh-cn')
+        text = convert(text, "zh-cn")
         return text
     elif isinstance(text, list):
         result_text = []
         for t in text:
-            t = convert(t, 'zh-cn')
+            t = convert(t, "zh-cn")
             result_text.append(t)
         return result_text
     else:
-        raise Exception(f'Not support type{type(text)}')
+        raise Exception(f"Unsupported type {type(text)}")

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -214,7 +223,7 @@ def get_parser():
         default=True,
         help="replace whisper encoder forward method to remove input length restriction",
     )

     return parser
@@ -226,6 +235,7 @@ def get_params() -> AttributeDict:
     )
     return params

+
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -234,42 +244,17 @@ def decode_one_batch(
     """Decode one batch and return the result in a dict. The dict has the
     following format:

-      - key: It indicates the setting used for decoding. For example,
-             if decoding method is 1best, the key is the string `no_rescore`.
-             If attention rescoring is used, the key is the string
-             `ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the
-             value of `lm_scale` and `attention_scale`. An example key is
-             `ngram_lm_scale_0.7_attention_scale_0.5`
-      - value: It contains the decoding result. `len(value)` equals to
-               batch size. `value[i]` is the decoding result for the i-th
-               utterance in the given batch.
+      - key: "beam-search"
+      - value: A list of lists. Each sublist is a list of token IDs.

     Args:
       params:
-        It's the return value of :func:`get_params`.
-
-        - params.method is "1best", it uses 1best decoding without LM rescoring.
-        - params.method is "nbest", it uses nbest decoding without LM rescoring.
-        - params.method is "attention-decoder", it uses attention rescoring.
-
+        It is returned by :func:`get_params`.
       model:
         The neural model.
-      HLG:
-        The decoding graph. Used when params.method is NOT ctc-decoding.
-      H:
-        The ctc topo. Used only when params.method is ctc-decoding.
       batch:
-        It is the return value from iterating
-        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
-        for the format of the `batch`.
-      lexicon:
-        It contains the token symbol table and the word symbol table.
-      sos_id:
-        The token ID of the SOS.
-      eos_id:
-        The token ID of the EOS.
+        It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
     Returns:
-      Return the decoding result. See above description for the format of
-      the returned dict.
+      Return a dict, whose key may be "beam-search".
     """
     dtype = torch.float16
     device = torch.device("cuda")
@@ -280,22 +265,27 @@ def decode_one_batch(
     if not params.remove_whisper_encoder_input_length_restriction:
         T = 3000
         if feature.shape[2] < T:
-            feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
+            feature = torch.cat(
+                [
+                    feature,
+                    torch.zeros(
+                        feature.shape[0], feature.shape[1], T - feature.shape[2]
+                    ).to(device, dtype=dtype),
+                ],
+                2,
+            )

     supervisions = batch["supervisions"]
     feature_len = supervisions["num_frames"]
     feature_len = feature_len.to(device, dtype=dtype)
     results = model.decode(feature, params.decoding_options)
     hyps = [result.text for result in results]

     hyps = remove_punctuation(hyps)
     hyps = to_simple(hyps)
     hyps = [params.normalizer.normalize(hyp) for hyp in hyps]

-    key = "beam-search"
-    return {key: hyps}
+    return {"beam-search": hyps}

 def decode_dataset(
@@ -306,28 +296,14 @@ def decode_dataset(
     """Decode dataset.

     Args:
       dl:
-        PyTorch's dataloader containing the dataset to decode.
+        The dataloader.
       params:
         It is returned by :func:`get_params`.
       model:
         The neural model.
-      HLG:
-        The decoding graph. Used when params.method is NOT ctc-decoding.
-      H:
-        The ctc topo. Used only when params.method is ctc-decoding.
-      lexicon:
-        It contains the token symbol table and the word symbol table.
-      sos_id:
-        The token ID for SOS.
-      eos_id:
-        The token ID for EOS.
     Returns:
-      Return a dict, whose key may be "no-rescore" if the decoding method is
-      1best or it may be "ngram_lm_scale_0.7_attention_scale_0.5" if attention
-      rescoring is used. Its value is a list of tuples. Each tuple contains two
-      elements: The first is the reference transcript, and the second is the
-      predicted result.
+      Return a dict, whose key may be "beam-search".
     """

     results = []
@@ -376,7 +352,9 @@ def save_results(
     enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        recog_path = (
+            params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -384,7 +362,9 @@ def save_results(
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        errs_filename = (
+            params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
@@ -423,13 +403,20 @@ def main():
     params = get_params()
     params.update(vars(args))

     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    setup_logger(f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}")
+    setup_logger(
+        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+    )

-    options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=params.beam_size)
+    options = whisper.DecodingOptions(
+        task="transcribe",
+        language="zh",
+        without_timestamps=True,
+        beam_size=params.beam_size,
+    )
     params.decoding_options = options
     params.cleaner = BasicTextNormalizer()
     params.normalizer = Normalizer()

     logging.info("Decoding started")
     logging.info(params)
@@ -441,39 +428,47 @@ def main():
     if params.remove_whisper_encoder_input_length_restriction:
         replace_whisper_encoder_forward()

-    model = whisper.load_model(params.model_name, 'cpu')
+    model = whisper.load_model(params.model_name, "cpu")
     if params.epoch > 0:
         if params.avg > 1:
             start = params.epoch - params.avg
             assert start >= 1, start
-            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
-            if 'model' not in checkpoint:
-                filenames = [f"{params.exp_dir}/epoch-{epoch}.pt" for epoch in range(start, params.epoch + 1)]
+            checkpoint = torch.load(
+                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+            )
+            if "model" not in checkpoint:
+                # deepspeed converted checkpoint only contains model state_dict
+                filenames = [
+                    f"{params.exp_dir}/epoch-{epoch}.pt"
+                    for epoch in range(start, params.epoch + 1)
+                ]
                 model.load_state_dict(average_checkpoints(filenames))
             else:
                 filename_start = f"{params.exp_dir}/epoch-{start}.pt"
                 filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
                 logging.info(
                     f"Calculating the averaged model over epoch range from "
                     f"{start} (excluded) to {params.epoch}"
                 )
                 model.to(device)
                 model.load_state_dict(
                     average_checkpoints_with_averaged_model(
                         filename_start=filename_start,
                         filename_end=filename_end,
                         device=device,
                     )
                 )
             # save checkpoints
             filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
             torch.save(model.state_dict(), filename)
         else:
-            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
-            if 'model' not in checkpoint:
+            checkpoint = torch.load(
+                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+            )
+            if "model" not in checkpoint:
                 model.load_state_dict(checkpoint, strict=True)
             else:
                 load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)

     model.to(device)
     model.eval()

     num_param = sum([p.numel() for p in model.parameters()])
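
The checkpoint-averaging logic above can be hard to read through the diff markers. Here is a self-contained sketch of the same idea, under the assumption stated in the docstring: deepspeed-converted checkpoints hold a bare state_dict, while regular ones nest it under a "model" key. The real function additionally de-duplicates parameters that share storage, which is omitted here.

```python
import torch


def average_state_dicts(paths, device="cpu"):
    """Element-wise mean of model weights saved at several epochs (sketch)."""

    def load_sd(path):
        ckpt = torch.load(path, map_location=device)
        # deepspeed-converted checkpoints are a bare state_dict
        return ckpt["model"] if "model" in ckpt else ckpt

    avg = load_sd(paths[0])
    for path in paths[1:]:
        sd = load_sd(path)
        for k in avg:
            avg[k] += sd[k]
    for k in avg:
        if avg[k].is_floating_point():
            avg[k] /= len(paths)
        else:
            avg[k] //= len(paths)  # integer buffers, e.g. counters
    return avg
```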

View File

@@ -35,4 +35,4 @@
   "steps_per_print": 50,
   "train_micro_batch_size_per_gpu": 1,
   "wall_clock_breakdown": false
 }

View File

@@ -7,4 +7,4 @@ librosa
 git+https://github.com/yuekaizhang/whisper.git
 zhconv
 WeTextProcessing
 deepspeed

View File

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
+#           2024 Yuekai Zhang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -41,44 +42,37 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import deepspeed
-from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 import k2
 import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from typing import List
+import whisper
 from asr_datamodule import AishellAsrDataModule
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from label_smoothing import LabelSmoothingLoss
 from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.functional import pad as pad_tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.checkpoint import (
-    save_checkpoint_with_global_batch_idx,
-    update_averaged_model,
-)
-from icefall.dist import cleanup_dist, setup_dist, get_world_size, get_rank, get_local_rank
+from icefall.checkpoint import update_averaged_model
+from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
 from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
-from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
@@ -87,10 +81,6 @@ from icefall.utils import (
     str2bool,
 )

-import whisper
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from label_smoothing import LabelSmoothingLoss

 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@@ -102,6 +92,7 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
         if hasattr(module, "batch_count"):
             module.batch_count = batch_count

+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -247,39 +238,17 @@ def get_params() -> AttributeDict:
     Explanation of options saved in `params`:

-    - best_train_loss: Best training loss so far. It is used to select
-                       the model that has the lowest training loss. It is
-                       updated during the training.
-    - best_valid_loss: Best validation loss so far. It is used to select
-                       the model that has the lowest validation loss. It is
-                       updated during the training.
-    - best_train_epoch: It is the epoch that has the best training loss.
-    - best_valid_epoch: It is the epoch that has the best validation loss.
-    - batch_idx_train: Used for writing statistics to tensorboard. It
-                       contains the number of batches trained so far across
-                       epochs.
-    - log_interval: Print training loss if batch_idx % log_interval is 0
-    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
-    - valid_interval: Run validation if batch_idx % valid_interval is 0
-    - feature_dim: The model input dim. It has to match the one used
-                   in computing features.
-    - subsampling_factor: The subsampling factor for the model.
-    - encoder_dim: Hidden dim for multi-head attention model.
-    - num_decoder_layers: Number of decoder layers of transformer decoder.
-    - warm_step: The warmup period that dictates the decay of the
-                 scale on "simple" (un-pruned) loss.
+    - frame_shift_ms: The frame shift in milliseconds.
+    - allowed_excess_duration_ratio: The allowed excess duration ratio.
+    - best_train_loss: The best training loss so far.
+    - best_valid_loss: The best validation loss so far.
+    - best_train_epoch: The epoch where the best training loss is achieved.
+    - best_valid_epoch: The epoch where the best validation loss is achieved.
+    - batch_idx_train: The batch index of the current batch.
+    - log_interval: Log training stats every `log_interval` batches.
+    - reset_interval: Reset the stats every `reset_interval` batches.
+    - valid_interval: Run validation every `valid_interval` batches.
+    - env_info: The environment information.
     """
     params = AttributeDict(
         {
@@ -292,13 +261,14 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 9999999,
+            "valid_interval": 5000,
             "env_info": get_env_info(),
         }
     )

     return params
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -414,6 +384,7 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)

+
 def compute_loss(
     params: AttributeDict,
     tokenizer: whisper.tokenizer.Tokenizer,
@@ -422,22 +393,21 @@ def compute_loss(
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute RNN-T loss given the model and its inputs.
+    Compute the loss for the given batch.

     Args:
       params:
-        Parameters for training. See :func:`get_params`.
+        It is returned by :func:`get_params`.
+      tokenizer:
+        The tokenizer used to encode the text.
       model:
-        The model for training. It is an instance of Zipformer in our case.
+        The model for training.
       batch:
         A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
         for the content in it.
       is_training:
-        True for training. False for validation. When it is True, this
-        function enables autograd during computation; when it is False, it
-        disables autograd.
-      warmup: a floating point value which increases throughout training;
-        values >= 1.0 are fully warmed up and have all modules present.
+        Whether it is training.
+    Returns:
+      Return a tuple of two elements. The first element is the loss tensor.
     """
     # For the uneven-sized batch, the total duration after padding would possibly
     # cause OOM. Hence, for each batch, which is sorted descendingly by length,
@@ -449,6 +419,7 @@ def compute_loss(
     if isinstance(model, DDP):
         # get underlying nn.Module
         model = model.module

+
     def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
         padding_size = max(tensor.shape[0] for tensor in tensors)
         dims = len(tensors[0].shape)
@@ -479,9 +450,16 @@ def compute_loss(
     # remove spaces in texts
     texts = [text.replace(" ", "") for text in texts]
-    text_tokens_list = [list(tokenizer.sot_sequence_including_notimestamps) + tokenizer.encode(text) + [tokenizer.eot] for text in texts]
+    text_tokens_list = [
+        list(tokenizer.sot_sequence_including_notimestamps)
+        + tokenizer.encode(text)
+        + [tokenizer.eot]
+        for text in texts
+    ]
     # convert it to torch tensor
-    text_tokens_list = [torch.LongTensor(text_tokens) for text_tokens in text_tokens_list]
+    text_tokens_list = [
+        torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
+    ]

     # 50256 is the index of <pad> for all whisper models
     prev_outputs_tokens = _batch_tensors(
@@ -494,9 +472,9 @@ def compute_loss(
         [tokens.shape[0] - 1 for tokens in text_tokens_list]
     )

-    decoder_criterion = LabelSmoothingLoss(ignore_index=50256, label_smoothing=0.1, reduction="sum")
+    decoder_criterion = LabelSmoothingLoss(
+        ignore_index=50256, label_smoothing=0.1, reduction="sum"
+    )

-    # ignore the first 3 tokens, which are always <sos>, <lang_id>, <transcribe>
+    # ignore the first 3 tokens, which are always <|lang_id|>, <|transcribe|>, <|notimestamps|>
     ignore_prefix_size = 3
     with torch.set_grad_enabled(is_training):
         encoder_out = model.encoder(feature)
@@ -623,7 +603,7 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )

         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -687,16 +667,24 @@ def train_one_epoch(
         if batch_idx % params.log_interval == 0:
             try:
                 cur_lr = scheduler.get_last_lr()[0]
-            except:
+            except:  # noqa
                 cur_lr = 0.0
-            cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0
+            cur_grad_scale = (
+                scaler._scale.item()
+                if (params.use_fp16 and not params.deepspeed)
+                else 1.0
+            )

             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if (params.use_fp16 and not params.deepspeed) else "")
+                + (
+                    f"grad_scale: {scaler._scale.item()}"
+                    if (params.use_fp16 and not params.deepspeed)
+                    else ""
+                )
             )

             if tb_writer is not None:
@@ -715,7 +703,6 @@ def train_one_epoch(
                     params.batch_idx_train,
                 )

     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
@@ -744,15 +731,18 @@ def run(rank, world_size, args):
     logging.info(params)

     logging.info("About to create model")

     replace_whisper_encoder_forward()
-    model = whisper.load_model(params.model_name, 'cpu')
+    model = whisper.load_model(params.model_name, "cpu")
     del model.alignment_heads

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

     tokenizer = whisper.tokenizer.get_tokenizer(
-        model.is_multilingual, num_languages=model.num_languages, language="zh", task="transcribe"
+        model.is_multilingual,
+        num_languages=model.num_languages,
+        language="zh",
+        task="transcribe",
     )

     model_avg: Optional[nn.Module] = None
@@ -791,7 +781,8 @@ def run(rank, world_size, args):
     if params.deepspeed:
         logging.info("Using DeepSpeed")
         model, optimizer, _, scheduler = deepspeed.initialize(
-            args=params, model=model, model_parameters=model.parameters())
+            args=params, model=model, model_parameters=model.parameters()
+        )
     else:
         logging.info("Using DDP")
         setup_dist(use_ddp_launch=True)
@@ -860,13 +851,17 @@ def run(rank, world_size, args):
             break

         if params.deepspeed:
-            model.save_checkpoint(save_dir=params.exp_dir,
-                                  tag=f"epoch-{params.cur_epoch}",
-                                  client_state={})
+            model.save_checkpoint(
+                save_dir=params.exp_dir,
+                tag=f"epoch-{params.cur_epoch}",
+                client_state={},
+            )
             if rank == 0:
                 convert_zero_checkpoint_to_fp32_state_dict(
-                    params.exp_dir, f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
-                    tag=f"epoch-{params.cur_epoch}")
+                    params.exp_dir,
+                    f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
+                    tag=f"epoch-{params.cur_epoch}",
+                )
         else:
             save_checkpoint(
                 params=params,
@@ -924,5 +919,6 @@ def main():
     torch.set_num_interop_threads(1)
     run(rank=rank, world_size=world_size, args=args)

+
 if __name__ == "__main__":
     main()
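
As a worked illustration of the target construction in `compute_loss` above (all token IDs below are made up except 50256, which the diff identifies as the padding index):

```python
import torch

# A toy tokenized utterance: special prompt tokens, two text tokens, then an
# end-of-text token. IDs are illustrative, not real whisper vocabulary IDs.
sot_sequence = [1, 2, 3, 4]  # sot, language, task, no-timestamps
tokens = torch.LongTensor(sot_sequence + [100, 101] + [999])

prev_outputs_tokens = tokens[:-1]  # decoder input (teacher forcing)
target_tokens = tokens[1:]  # labels, shifted left by one position

# With ignore_prefix_size = 3, the first three label positions (the special
# prompt tokens) are excluded from the loss; padded positions are excluded
# via ignore_index=50256 in the label-smoothing criterion.
print(prev_outputs_tokens.tolist())  # [1, 2, 3, 4, 100, 101]
print(target_tokens.tolist())        # [2, 3, 4, 100, 101, 999]
```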

View File

@@ -1,6 +1,8 @@
 import torch
+import torch.nn.functional as F
 import whisper

+
 def forward(self, x: torch.Tensor):
     """
     x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
@@ -10,7 +12,7 @@ def forward(self, x: torch.Tensor):
     x = F.gelu(self.conv2(x))
     x = x.permute(0, 2, 1)

-    x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype)
+    x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)

     for block in self.blocks:
         x = block(x)
@@ -18,6 +20,7 @@ def forward(self, x: torch.Tensor):
     x = self.ln_post(x)
     return x

+
 def replace_whisper_encoder_forward():
     """
     This function monkey patches the forward method of the whisper encoder.
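
The body of `replace_whisper_encoder_forward` is cut off in this view. A patch like this typically amounts to a single class-level assignment; here is a sketch under that assumption (the exact attribute path is not confirmed by the diff):

```python
import whisper


def replace_whisper_encoder_forward():
    # Rebind forward at the class level so every AudioEncoder instance uses
    # the patched function defined above in this file. The attribute path is
    # an assumption; verify it against the installed whisper version.
    whisper.model.AudioEncoder.forward = forward
```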

View File

@@ -22,16 +22,25 @@ It looks for manifests in the directory data/manifests.
 The generated fbank features are saved in data/fbank.
 """

+import argparse
 import logging
 import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter, MonoCut, combine
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    MonoCut,
+    WhisperFbank,
+    WhisperFbankConfig,
+    combine,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

-from icefall.utils import get_executor
+from icefall.utils import get_executor, str2bool

 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slows things down.
@@ -81,7 +90,9 @@ def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):
     logging.info("Extracting features for Musan")

     if whisper_fbank:
-        extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
     else:
         extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@@ -103,6 +114,7 @@ def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):
     )
     musan_cuts.to_file(musan_cuts_path)

+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -119,10 +131,12 @@ def get_args():
     )
     return parser.parse_args()

 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)
+    args = get_args()
     compute_fbank_musan(
         num_mel_bins=args.num_mel_bins, whisper_fbank=args.whisper_fbank
     )