This commit is contained in:
root 2024-01-22 08:10:26 +00:00
parent b623c3be15
commit 8d9ab308af
10 changed files with 257 additions and 229 deletions

View File

@ -24,3 +24,10 @@ The following table lists the differences among them.
The decoder in `transducer_stateless` is modified from the paper
[Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
We place an additional Conv1d layer right after the input embedding layer.
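A minimal sketch of such a stateless decoder (module and dimension names here are illustrative, not the recipe's exact code):

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Embedding over the most recent output tokens followed by a Conv1d,
    instead of a recurrent prediction network."""

    def __init__(self, vocab_size: int, embed_dim: int = 512, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # The Conv1d limits the decoder to a fixed left context of tokens.
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (batch, num_tokens)
        emb = self.embedding(y).permute(0, 2, 1)  # (batch, embed_dim, num_tokens)
        out = self.conv(emb)  # (batch, embed_dim, num_tokens - context_size + 1)
        return out.permute(0, 2, 1)
```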
# Whisper
Recipe to fine-tune large pretrained models
| | Encoder | Decoder | Comment |
|------------------------------------|-----------|--------------------|-----------------------------------------------------------------------------------|
| `whisper`                          | Transformer | Transformer        | support fine-tuning using deepspeed                                                 |

View File

@ -77,7 +77,7 @@ It's reworked Zipformer with Pruned RNNT loss.
Command for training is:
```bash
./prepare.sh
export CUDA_VISIBLE_DEVICES="0,1"
@ -142,7 +142,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--max-duration 1200
```
Command for decoding is:
@ -192,7 +192,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--max-duration 800
```
Command for decoding is:
@ -208,7 +208,7 @@ for m in greedy_search modified_beam_search fast_beam_search ; do
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192
done
```

View File

@ -29,7 +29,14 @@ import os
from pathlib import Path
import torch
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
WhisperFbank,
WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
@ -42,7 +49,9 @@ torch.set_num_threads(1)
torch.set_num_interop_threads(1)
def compute_fbank_aishell(
num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
):
src_dir = Path("data/manifests")
output_dir = Path("data/fbank")
num_jobs = min(15, os.cpu_count())
@ -69,7 +78,9 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, w
dataset_parts,
)
if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -84,7 +95,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, w
supervisions=m["supervisions"],
)
if "train" in partition and perturb_speed:
logging.info(f"Doing speed perturb")
logging.info("Doing speed perturb")
cut_set = (
cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
)
@ -129,5 +140,7 @@ if __name__ == "__main__":
args = get_args()
compute_fbank_aishell(
num_mel_bins=args.num_mel_bins,
perturb_speed=args.perturb_speed,
whisper_fbank=args.whisper_fbank,
)
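# Equivalent direct call for Whisper-style features (a sketch; 80 mel bins
# matches Whisper's frontend, the other argument values are illustrative):
#   compute_fbank_aishell(num_mel_bins=80, perturb_speed=True, whisper_fbank=True)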

View File

@ -387,4 +387,4 @@ if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
touch data/fbank/.aishell.whisper.done
fi
fi

View File

@ -2,6 +2,7 @@
# Copyright 2021 Xiaomi Corporation (Author: Liyong Guo,
# Fangjun Kuang,
# Wei Kang)
# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -42,44 +43,37 @@ python3 ./whisper/decode.py \
import argparse
import logging
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import k2
import torch
import torch.nn as nn
import whisper
from asr_datamodule import AishellAsrDataModule
from tn.chinese.normalizer import Normalizer
from whisper.normalizers import BasicTextNormalizer
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from zhconv import convert

from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
from icefall.env import get_env_info
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    get_texts,
    setup_logger,
    store_transcripts,
    str2bool,
    write_error_stats,
)
def average_checkpoints(
filenames: List[Path], device: torch.device = torch.device("cpu")
) -> dict:
"""Average a list of checkpoints.
The function is mainly used for averaging deepspeed-converted checkpoints, which only include the model state_dict.
Args:
filenames:
@ -94,9 +88,9 @@ def average_checkpoints(
n = len(filenames)
if "model" in torch.load(filenames[0], map_location=device):
avg = torch.load(filenames[0], map_location=device)["model"]
else:
avg = torch.load(filenames[0], map_location=device)
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
@ -112,9 +106,9 @@ def average_checkpoints(
for i in range(1, n):
if "model" in torch.load(filenames[i], map_location=device):
state_dict = torch.load(filenames[i], map_location=device)["model"]
else:
state_dict = torch.load(filenames[i], map_location=device)
for k in uniqued_names:
avg[k] += state_dict[k]
@ -126,33 +120,48 @@ def average_checkpoints(
return avg
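# Hypothetical usage sketch, mirroring main() below (paths are illustrative):
#   filenames = [Path(f"whisper/exp/epoch-{i}.pt") for i in range(8, 11)]
#   model.load_state_dict(average_checkpoints(filenames))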
def remove_punctuation(text: Union[str, List[str]]):
"""Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
Args:
text: It can be a string or a list of strings.
Returns:
Return a string or a list of strings without any punctuation.
"""
punctuation = "!,.;:?、!,。;:?《》 "
if isinstance(text, str):
text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
return text
elif isinstance(text, list):
result_text = []
for t in text:
t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
result_text.append(t)
return result_text
else:
raise Exception(f"Unsupported type: {type(text)}")
def to_simple(text: Union[str, List[str]]):
"""Convert traditional Chinese to simplified Chinese.
Args:
text: It can be a string or a list of strings.
Returns:
Return a string or a list of strings converted to simplified Chinese.
"""
if isinstance(text, str):
text = convert(text, "zh-cn")
return text
elif isinstance(text, list):
result_text = []
for t in text:
t = convert(t, "zh-cn")
result_text.append(t)
return result_text
else:
raise Exception(f"Unsupported type: {type(text)}")
def get_parser():
parser = argparse.ArgumentParser(
@ -214,7 +223,7 @@ def get_parser():
default=True,
help="replace whisper encoder forward method to remove input length restriction",
)
return parser
@ -226,6 +235,7 @@ def get_params() -> AttributeDict:
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
@ -234,42 +244,17 @@ def decode_one_batch(
"""Decode one batch and return the result in a dict. The dict has the
following format:
    - key: "beam-search"
    - value: A list of strings. Each string is the decoded transcript of
      the corresponding utterance in the batch.
Args:
  params:
    It is returned by :func:`get_params`.
  model:
    The neural model.
  batch:
    It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
Returns:
  Return a dict, whose key may be "beam-search".
"""
dtype = torch.float16
device = torch.device("cuda")
@ -280,22 +265,27 @@ def decode_one_batch(
if not params.remove_whisper_encoder_input_length_restriction:
T = 3000
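# 3000 feature frames = 30 s of audio at a 10 ms frame shift, Whisper's
# fixed input length; shorter inputs are zero-padded below.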
if feature.shape[2] < T:
feature = torch.cat(
[
feature,
torch.zeros(
feature.shape[0], feature.shape[1], T - feature.shape[2]
).to(device, dtype=dtype),
],
2,
)
supervisions = batch["supervisions"]
feature_len = supervisions["num_frames"]
feature_len = feature_len.to(device, dtype=dtype)
results = model.decode(feature, params.decoding_options)
hyps = [result.text for result in results]
hyps = remove_punctuation(hyps)
hyps = to_simple(hyps)
hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
key = "beam-search"
return {key: hyps}
return {"beam-search": hyps}
def decode_dataset(
@ -306,28 +296,14 @@ def decode_dataset(
"""Decode dataset.
Args:
dl:
The dataloader.
params:
It is returned by :func:`get_params`.
model:
The neural model.
Returns:
Return a dict, whose key may be "no-rescore" if the decoding method is
1best or it may be "ngram_lm_scale_0.7_attention_scale_0.5" if attention
rescoring is used. Its value is a list of tuples. Each tuple contains two
elements: The first is the reference transcript, and the second is the
predicted result.
Return a dict, whose key may be "beam-search".
"""
results = []
@ -376,7 +352,9 @@ def save_results(
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
recog_path = (
params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
)
results = sorted(results)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
@ -384,7 +362,9 @@ def save_results(
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
errs_filename = (
params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
)
# we compute CER for aishell dataset.
results_char = []
for res in results:
@ -423,13 +403,20 @@ def main():
params = get_params()
params.update(vars(args))
params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
setup_logger(f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}")
setup_logger(
f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
)
options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=params.beam_size)
options = whisper.DecodingOptions(
task="transcribe",
language="zh",
without_timestamps=True,
beam_size=params.beam_size,
)
params.decoding_options = options
params.cleaner = BasicTextNormalizer()
params.normalizer = Normalizer()
logging.info("Decoding started")
logging.info(params)
@ -441,39 +428,47 @@ def main():
if params.remove_whisper_encoder_input_length_restriction:
replace_whisper_encoder_forward()
model = whisper.load_model(params.model_name, "cpu")
if params.epoch > 0:
    if params.avg > 1:
        start = params.epoch - params.avg
        assert start >= 1, start
        checkpoint = torch.load(
            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
        )
        if "model" not in checkpoint:
            # deepspeed converted checkpoint only contains model state_dict
            filenames = [
                f"{params.exp_dir}/epoch-{epoch}.pt"
                for epoch in range(start, params.epoch + 1)
            ]
            model.load_state_dict(average_checkpoints(filenames))
        else:
            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
            logging.info(
                f"Calculating the averaged model over epoch range from "
                f"{start} (excluded) to {params.epoch}"
            )
            model.to(device)
            model.load_state_dict(
                average_checkpoints_with_averaged_model(
                    filename_start=filename_start,
                    filename_end=filename_end,
                    device=device,
                )
            )
        # save checkpoints
        filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
        torch.save(model.state_dict(), filename)
    else:
        checkpoint = torch.load(
            f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
        )
        if "model" not in checkpoint:
            model.load_state_dict(checkpoint, strict=True)
        else:
            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
model.to(device)
model.eval()
num_param = sum([p.numel() for p in model.parameters()])

View File

@ -35,4 +35,4 @@
"steps_per_print": 50,
"train_micro_batch_size_per_gpu": 1,
"wall_clock_breakdown": false
}

View File

@ -7,4 +7,4 @@ librosa
git+https://github.com/yuekaizhang/whisper.git
zhconv
WeTextProcessing
deepspeed

View File

@ -1,5 +1,6 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Xiaoyu Yang)
# 2024 Yuekai Zhang
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@ -41,44 +42,37 @@ import random
import warnings
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union

import deepspeed
import k2
import optim
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import whisper
from asr_datamodule import AishellAsrDataModule
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
from label_smoothing import LabelSmoothingLoss
from lhotse import CutSet, load_manifest
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from optim import Eden, ScaledAdam
from torch import Tensor
from torch.cuda.amp import GradScaler
from torch.nn.functional import pad as pad_tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter
from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
from icefall import diagnostics
from icefall.checkpoint import load_checkpoint, remove_checkpoints
from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
from icefall.checkpoint import update_averaged_model
from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
from icefall.env import get_env_info
from icefall.hooks import register_inf_check_hooks
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
MetricsTracker,
@ -87,10 +81,6 @@ from icefall.utils import (
str2bool,
)
LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
@ -102,6 +92,7 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
if hasattr(module, "batch_count"):
module.batch_count = batch_count
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@ -247,39 +238,17 @@ def get_params() -> AttributeDict:
Explanation of options saved in `params`:
- best_train_loss: The best training loss so far.
- best_valid_loss: The best validation loss so far.
- best_train_epoch: The epoch where the best training loss is achieved.
- best_valid_epoch: The epoch where the best validation loss is achieved.
- batch_idx_train: The batch index of the current batch.
- log_interval: Log training stats every `log_interval` batches.
- reset_interval: Reset the stats every `reset_interval` batches.
- valid_interval: Run validation every `valid_interval` batches.
- env_info: The environment information.
"""
params = AttributeDict(
{
@ -292,13 +261,14 @@ def get_params() -> AttributeDict:
"batch_idx_train": 0,
"log_interval": 50,
"reset_interval": 200,
"valid_interval": 9999999,
"valid_interval": 5000,
"env_info": get_env_info(),
}
)
return params
def load_checkpoint_if_available(
params: AttributeDict,
model: nn.Module,
@ -414,6 +384,7 @@ def save_checkpoint(
best_valid_filename = params.exp_dir / "best-valid-loss.pt"
copyfile(src=filename, dst=best_valid_filename)
def compute_loss(
params: AttributeDict,
tokenizer: whisper.tokenizer.Tokenizer,
@ -422,22 +393,21 @@ def compute_loss(
is_training: bool,
) -> Tuple[Tensor, MetricsTracker]:
"""
Compute the loss for the given batch.
Args:
params:
It is returned by :func:`get_params`.
tokenizer:
The tokenizer used to encode the text.
model:
The model for training.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
for the content in it.
is_training:
Whether it is training.
Returns:
Return a tuple of two elements. The first element is the loss tensor.
"""
# For the uneven-sized batch, the total duration after padding would possibly
# cause OOM. Hence, for each batch, which is sorted descendingly by length,
@ -449,6 +419,7 @@ def compute_loss(
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
padding_size = max(tensor.shape[0] for tensor in tensors)
dims = len(tensors[0].shape)
@ -479,9 +450,16 @@ def compute_loss(
# remove spaces in texts
texts = [text.replace(" ", "") for text in texts]
text_tokens_list = [
list(tokenizer.sot_sequence_including_notimestamps)
+ tokenizer.encode(text)
+ [tokenizer.eot]
for text in texts
]
# convert it to torch tensor
text_tokens_list = [
torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
]
# 50256 is the index of <pad> for all whisper models
prev_outputs_tokens = _batch_tensors(
@ -494,9 +472,11 @@ def compute_loss(
[tokens.shape[0] - 1 for tokens in text_tokens_list]
)
decoder_criterion = LabelSmoothingLoss(
ignore_index=50256, label_smoothing=0.1, reduction="sum"
)
# ignore the first 3 target tokens, which are always <|lang_id|>, <|transcribe|>, <|notimestamps|>
ignore_prefix_size = 3
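# Illustration of the shifted input/target pair for language "zh"
# (token names only; t1 ... tn are the encoded transcript tokens):
#   decoder input: <|startoftranscript|> <|zh|> <|transcribe|> <|notimestamps|> t1 ... tn
#   target:        <|zh|> <|transcribe|> <|notimestamps|> t1 ... tn <|endoftext|>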
with torch.set_grad_enabled(is_training):
encoder_out = model.encoder(feature)
@ -623,7 +603,7 @@ def train_one_epoch(
valid_info.write_summary(
tb_writer, "train/valid_", params.batch_idx_train
)
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
loss, loss_info = compute_loss(
@ -687,16 +667,24 @@ def train_one_epoch(
if batch_idx % params.log_interval == 0:
try:
cur_lr = scheduler.get_last_lr()[0]
except:  # noqa
cur_lr = 0.0
cur_grad_scale = (
scaler._scale.item()
if (params.use_fp16 and not params.deepspeed)
else 1.0
)
logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if (params.use_fp16 and not params.deepspeed) else "")
+ (
f"grad_scale: {scaler._scale.item()}"
if (params.use_fp16 and not params.deepspeed)
else ""
)
)
if tb_writer is not None:
@ -715,7 +703,6 @@ def train_one_epoch(
params.batch_idx_train,
)
loss_value = tot_loss["loss"] / tot_loss["frames"]
params.train_loss = loss_value
if params.train_loss < params.best_train_loss:
@ -744,15 +731,18 @@ def run(rank, world_size, args):
logging.info(params)
logging.info("About to create model")
replace_whisper_encoder_forward()
model = whisper.load_model(params.model_name, "cpu")
del model.alignment_heads
num_param = sum([p.numel() for p in model.parameters()])
logging.info(f"Number of model parameters: {num_param}")
tokenizer = whisper.tokenizer.get_tokenizer(
model.is_multilingual,
num_languages=model.num_languages,
language="zh",
task="transcribe",
)
model_avg: Optional[nn.Module] = None
@ -791,7 +781,8 @@ def run(rank, world_size, args):
if params.deepspeed:
logging.info("Using DeepSpeed")
model, optimizer, _, scheduler = deepspeed.initialize(
args=params, model=model, model_parameters=model.parameters()
)
else:
logging.info("Using DDP")
setup_dist(use_ddp_launch=True)
@ -860,13 +851,17 @@ def run(rank, world_size, args):
break
if params.deepspeed:
model.save_checkpoint(
save_dir=params.exp_dir,
tag=f"epoch-{params.cur_epoch}",
client_state={},
)
if rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(
params.exp_dir, f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}")
params.exp_dir,
f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
tag=f"epoch-{params.cur_epoch}",
)
else:
save_checkpoint(
params=params,
@ -924,5 +919,6 @@ def main():
torch.set_num_interop_threads(1)
run(rank=rank, world_size=world_size, args=args)
if __name__ == "__main__":
main()

View File

@ -1,6 +1,8 @@
import torch
import torch.nn.functional as F
import whisper
def forward(self, x: torch.Tensor):
"""
x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
@ -10,7 +12,7 @@ def forward(self, x: torch.Tensor):
x = F.gelu(self.conv2(x))
x = x.permute(0, 2, 1)
x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)
for block in self.blocks:
x = block(x)
@ -18,6 +20,7 @@ def forward(self, x: torch.Tensor):
x = self.ln_post(x)
return x
def replace_whisper_encoder_forward():
"""
This function monkey patches the forward method of the whisper encoder.
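    It must be called before whisper.load_model(), as train.py and decode.py
    do. In effect it amounts to (a sketch, assuming the openai-whisper
    package layout):
        whisper.model.AudioEncoder.forward = forward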

View File

@ -22,16 +22,25 @@ It looks for manifests in the directory data/manifests.
The generated fbank features are saved in data/fbank.
"""
import argparse
import logging
import os
from pathlib import Path
import torch
from lhotse import (
CutSet,
Fbank,
FbankConfig,
LilcomChunkyWriter,
MonoCut,
WhisperFbank,
WhisperFbankConfig,
combine,
)
from lhotse.recipes.utils import read_manifests_if_cached
from icefall.utils import get_executor, str2bool
# Torch's multithreaded behavior needs to be disabled or
# it wastes a lot of CPU and slow things down.
@ -81,7 +90,9 @@ def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):
logging.info("Extracting features for Musan")
if whisper_fbank:
extractor = WhisperFbank(
WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
)
else:
extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
@ -103,6 +114,7 @@ def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):
)
musan_cuts.to_file(musan_cuts_path)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
@ -119,10 +131,12 @@ def get_args():
)
return parser.parse_args()
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
args = get_args()
compute_fbank_musan(
num_mel_bins=args.num_mel_bins, whisper_fbank=args.whisper_fbank
)
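# Equivalent direct call for Whisper-style MUSAN features (a sketch; 80 mel
# bins matches Whisper's frontend, mirroring prepare.sh stage 30 which passes
# --num-mel-bins ${whisper_mel_bins} --whisper-fbank true):
#   compute_fbank_musan(num_mel_bins=80, whisper_fbank=True)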