fix lint
Commit 8d9ab308af (parent b623c3be15)
@@ -24,3 +24,10 @@ The following table lists the differences among them.
 The decoder in `transducer_stateless` is modified from the paper
 [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
 We place an additional Conv1d layer right after the input embedding layer.
+
+# Whisper
+
+Recipe to finetune large pretrained models
+| | Encoder | Decoder | Comment |
+|------------------------------------|-----------|--------------------|-----------------------------------------------------------------------------------|
+| `whisper` | Transformer | Transformer | support fine-tuning using deepspeed |
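For context on the hunk above: a stateless prediction network of the kind the cited paper describes keeps no recurrent state; it embeds the last few labels and runs a small Conv1d over them. A minimal, hedged sketch (module and dimension names are illustrative, not the recipe's actual code):

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Embedding + Conv1d prediction network, per the paper cited above."""

    def __init__(self, vocab_size: int, embed_dim: int, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # The additional Conv1d right after the input embedding layer.
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (batch, U) label indices
        emb = self.embedding(y).permute(0, 2, 1)  # (batch, embed_dim, U)
        # Left-pad so the convolution is causal and the output length stays U.
        emb = nn.functional.pad(emb, (self.conv.kernel_size[0] - 1, 0))
        return self.conv(emb).permute(0, 2, 1)  # (batch, U, embed_dim)
```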
@@ -77,7 +77,7 @@ It's reworked Zipformer with Pruned RNNT loss.
 
 Command for training is:
 ```bash
-./prepare.sh
+./prepare.sh
 
 export CUDA_VISIBLE_DEVICES="0,1"
 
@@ -142,7 +142,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
   --feedforward-dim 512,768,768,768,768,768 \
   --encoder-dim 192,256,256,256,256,256 \
   --encoder-unmasked-dim 192,192,192,192,192,192 \
-  --max-duration 1200
+  --max-duration 1200
 ```
 
 Command for decoding is:
@@ -192,7 +192,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
   --feedforward-dim 512,768,1536,2048,1536,768 \
   --encoder-dim 192,256,512,768,512,256 \
   --encoder-unmasked-dim 192,192,256,320,256,192 \
-  --max-duration 800
+  --max-duration 800
 ```
 
 Command for decoding is:
@@ -208,7 +208,7 @@ for m in greedy_search modified_beam_search fast_beam_search ; do
     --num-encoder-layers 2,2,4,5,4,2 \
     --feedforward-dim 512,768,1536,2048,1536,768 \
     --encoder-dim 192,256,512,768,512,256 \
-    --encoder-unmasked-dim 192,192,256,320,256,192
+    --encoder-unmasked-dim 192,192,256,320,256,192
 done
 ```
 
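The comma-separated values in flags such as `--encoder-dim 192,256,512,768,512,256` give one setting per Zipformer encoder stack. A hedged sketch of how such a flag is typically split into per-stack integers (the helper name is illustrative, not the recipe's exact parser):

```python
def per_stack_dims(flag_value: str) -> list:
    """Split a comma-separated flag like "192,256,512,768,512,256"
    into one dimension per encoder stack."""
    return [int(v) for v in flag_value.split(",")]


assert per_stack_dims("192,256,512,768,512,256") == [192, 256, 512, 768, 512, 256]
```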
@@ -29,7 +29,14 @@ import os
 from pathlib import Path
 
 import torch
-from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached
 
 from icefall.utils import get_executor, str2bool
@@ -42,7 +49,9 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)
 
 
-def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
+def compute_fbank_aishell(
+    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
@@ -69,7 +78,9 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, w
         dataset_parts,
     )
     if whisper_fbank:
-        extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
     else:
         extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
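The branch above only picks a feature extractor; the cuts are then processed with lhotse as usual. A minimal, hedged sketch of applying such an extractor to a `CutSet` (the manifest path and job count are illustrative):

```python
from lhotse import CutSet, Fbank, FbankConfig, LilcomChunkyWriter

# Illustrative manifest path; the script derives real paths from data/manifests.
cuts = CutSet.from_file("data/manifests/aishell_cuts_train.jsonl.gz")
extractor = Fbank(FbankConfig(num_mel_bins=80))
cuts = cuts.compute_and_store_features(
    extractor=extractor,
    storage_path="data/fbank/aishell_feats_train",
    num_jobs=15,
    storage_type=LilcomChunkyWriter,
)
```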
@@ -84,7 +95,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, w
             supervisions=m["supervisions"],
         )
         if "train" in partition and perturb_speed:
-            logging.info(f"Doing speed perturb")
+            logging.info("Doing speed perturb")
             cut_set = (
                 cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
            )
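Speed perturbation at factors 0.9 and 1.1, as in the hunk above, triples the training cuts; lhotse applies the resampling lazily rather than storing new audio. A quick, hedged illustration of the size effect:

```python
# Assuming `cut_set` is a lhotse CutSet:
augmented = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
# len(augmented) == 3 * len(cut_set); durations scale by 1/0.9 and 1/1.1.
```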
@@ -129,5 +140,7 @@ if __name__ == "__main__":
 
     args = get_args()
     compute_fbank_aishell(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
     )
@@ -387,4 +387,4 @@ if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
     ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
     touch data/fbank/.aishell.whisper.done
   fi
-fi
+fi
@@ -2,6 +2,7 @@
 # Copyright    2021  Xiaomi Corporation (Author: Liyong Guo,
 #                                                Fangjun Kuang,
 #                                                Wei Kang)
+#              2024  Yuekai Zhang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -42,44 +43,37 @@ python3 ./whisper/decode.py \
 
 import argparse
 import logging
+import re
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
-import whisper
-from whisper.normalizers import BasicTextNormalizer
 import k2
 import torch
 import torch.nn as nn
+import whisper
 from asr_datamodule import AishellAsrDataModule
-#from model import load_model
+from tn.chinese.normalizer import Normalizer
+from whisper.normalizers import BasicTextNormalizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
-from icefall.decode import (
-    get_lattice,
-    nbest_decoding,
-    nbest_oracle,
-    one_best_decoding,
-    rescore_with_attention_decoder,
-)
+from zhconv import convert
+
+from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     get_texts,
     setup_logger,
     store_transcripts,
-    write_error_stats,
     str2bool,
+    write_error_stats,
 )
-from zhconv import convert
-from tn.chinese.normalizer import Normalizer
-import re
 
 
 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
 ) -> dict:
     """Average a list of checkpoints.
     The function is mainly used for deepspeed converted checkpoint averaging, which only include model state_dict.
 
     Args:
       filenames:
@@ -94,9 +88,9 @@ def average_checkpoints(
     n = len(filenames)
 
     if "model" in torch.load(filenames[0], map_location=device):
-        avg = torch.load(filenames[0], map_location=device)["model"]
+        avg = torch.load(filenames[0], map_location=device)["model"]
     else:
-        avg = torch.load(filenames[0], map_location=device)
+        avg = torch.load(filenames[0], map_location=device)
 
     # Identify shared parameters. Two parameters are said to be shared
     # if they have the same data_ptr
@@ -112,9 +106,9 @@ def average_checkpoints(
 
     for i in range(1, n):
         if "model" in torch.load(filenames[i], map_location=device):
-            state_dict = torch.load(filenames[i], map_location=device)["model"]
+            state_dict = torch.load(filenames[i], map_location=device)["model"]
         else:
-            state_dict = torch.load(filenames[i], map_location=device)
+            state_dict = torch.load(filenames[i], map_location=device)
         for k in uniqued_names:
             avg[k] += state_dict[k]
 
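Since the deepspeed-converted checkpoints contain only a model state_dict, the `average_checkpoints` helper above can be driven directly with a list of epoch files. A hedged usage sketch (paths and the epoch range are illustrative):

```python
from pathlib import Path

# Average epochs 8..10 of a fine-tuned model (illustrative paths).
filenames = [Path(f"whisper/exp/epoch-{e}.pt") for e in range(8, 11)]
avg_state_dict = average_checkpoints(filenames)
model.load_state_dict(avg_state_dict)  # `model` is the whisper model being decoded
```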
@@ -126,33 +120,48 @@ def average_checkpoints(
 
     return avg
 
 
 def remove_punctuation(text: str or List[str]):
-    # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
-    punctuation = '!,.;:?、!,。;:?《》 '
+    """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
+
+    Args:
+      text: It can be a string or a list of strings.
+    Returns:
+      Return a string or a list of strings without any punctuation.
+    """
+    punctuation = "!,.;:?、!,。;:?《》 "
     if isinstance(text, str):
-        text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
+        text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
         return text
     elif isinstance(text, list):
         result_text = []
         for t in text:
-            t = re.sub(r'[{}]+'.format(punctuation), '', t).strip()
+            t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
             result_text.append(t)
         return result_text
     else:
-        raise Exception(f'Not support type {type(text)}')
+        raise Exception(f"Not support type {type(text)}")
 
 
 def to_simple(text: str or List[str]):
+    """Convert traditional Chinese to simplified Chinese.
+    Args:
+      text: It can be a string or a list of strings.
+    Returns:
+      Return a string or a list of strings converted to simplified Chinese.
+    """
     if isinstance(text, str):
-        text = convert(text, 'zh-cn')
+        text = convert(text, "zh-cn")
         return text
     elif isinstance(text, list):
         result_text = []
         for t in text:
-            t = convert(t, 'zh-cn')
+            t = convert(t, "zh-cn")
             result_text.append(t)
         return result_text
     else:
-        raise Exception(f'Not support type{type(text)}')
+        raise Exception(f"Not support type{type(text)}")
 
 
 def get_parser():
     parser = argparse.ArgumentParser(
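A quick usage example of the two helpers above on a decoding hypothesis (the input string is illustrative):

```python
hyp = "你好!這是一個測試。"
hyp = remove_punctuation(hyp)  # -> "你好這是一個測試"
hyp = to_simple(hyp)           # -> "你好这是一个测试"
```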
@@ -214,7 +223,7 @@ def get_parser():
         default=True,
         help="replace whisper encoder forward method to remove input length restriction",
     )
 
-
+
     return parser
 
 
@@ -226,6 +235,7 @@ def get_params() -> AttributeDict:
     )
     return params
 
+
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -234,42 +244,17 @@ def decode_one_batch(
     """Decode one batch and return the result in a dict. The dict has the
     following format:
 
-    - key: It indicates the setting used for decoding. For example,
-           if decoding method is 1best, the key is the string `no_rescore`.
-           If attention rescoring is used, the key is the string
-           `ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the
-           value of `lm_scale` and `attention_scale`. An example key is
-           `ngram_lm_scale_0.7_attention_scale_0.5`
-    - value: It contains the decoding result. `len(value)` equals to
-             batch size. `value[i]` is the decoding result for the i-th
-             utterance in the given batch.
+    - key: "beam-search"
+    - value: A list of lists. Each sublist is a list of token IDs.
     Args:
-      params:
-        It's the return value of :func:`get_params`.
-
-        - params.method is "1best", it uses 1best decoding without LM rescoring.
-        - params.method is "nbest", it uses nbest decoding without LM rescoring.
-        - params.method is "attention-decoder", it uses attention rescoring.
-
-      model:
-        The neural model.
-      HLG:
-        The decoding graph. Used when params.method is NOT ctc-decoding.
-      H:
-        The ctc topo. Used only when params.method is ctc-decoding.
-      batch:
-        It is the return value from iterating
-        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
-        for the format of the `batch`.
-      lexicon:
-        It contains the token symbol table and the word symbol table.
-      sos_id:
-        The token ID of the SOS.
-      eos_id:
-        The token ID of the EOS.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      batch:
+        It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
     Returns:
-      Return the decoding result. See above description for the format of
-      the returned dict.
+      Return a dict, whose key may be "beam-search".
     """
     dtype = torch.float16
     device = torch.device("cuda")
@@ -280,22 +265,27 @@ def decode_one_batch(
     if not params.remove_whisper_encoder_input_length_restriction:
         T = 3000
         if feature.shape[2] < T:
-            feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
+            feature = torch.cat(
+                [
+                    feature,
+                    torch.zeros(
+                        feature.shape[0], feature.shape[1], T - feature.shape[2]
+                    ).to(device, dtype=dtype),
+                ],
+                2,
+            )
 
     supervisions = batch["supervisions"]
     feature_len = supervisions["num_frames"]
     feature_len = feature_len.to(device, dtype=dtype)
     results = model.decode(feature, params.decoding_options)
     hyps = [result.text for result in results]
 
-
     hyps = remove_punctuation(hyps)
     hyps = to_simple(hyps)
-
     hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
 
-    key = "beam-search"
-
-    return {key: hyps}
+    return {"beam-search": hyps}
 
 
 def decode_dataset(
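The `T = 3000` above comes from Whisper's fixed 30-second receptive field: 30 s of audio at a 10 ms mel hop gives 3000 frames, so shorter inputs are zero-padded along the time axis when the stock encoder is used. A hedged, self-contained illustration of the same padding:

```python
import torch

T = 3000  # 30 s x 100 mel frames/s, Whisper's fixed input length
feature = torch.randn(4, 80, 1234)  # (batch, n_mels, frames), illustrative shape
if feature.shape[2] < T:
    pad = torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2])
    feature = torch.cat([feature, pad], dim=2)
assert feature.shape == (4, 80, 3000)
```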
@@ -306,28 +296,14 @@ def decode_dataset(
     """Decode dataset.
 
     Args:
-      dl:
-        PyTorch's dataloader containing the dataset to decode.
-      params:
-        It is returned by :func:`get_params`.
-      model:
-        The neural model.
-      HLG:
-        The decoding graph. Used when params.method is NOT ctc-decoding.
-      H:
-        The ctc topo. Used only when params.method is ctc-decoding.
-      lexicon:
-        It contains the token symbol table and the word symbol table.
-      sos_id:
-        The token ID for SOS.
-      eos_id:
-        The token ID for EOS.
+      dl:
+        The dataloader.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
     Returns:
-      Return a dict, whose key may be "no-rescore" if the decoding method is
-      1best or it may be "ngram_lm_scale_0.7_attention_scale_0.5" if attention
-      rescoring is used. Its value is a list of tuples. Each tuple contains two
-      elements: The first is the reference transcript, and the second is the
-      predicted result.
+      Return a dict, whose key may be "beam-search".
     """
     results = []
 
@@ -376,7 +352,9 @@ def save_results(
     enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        recog_path = (
+            params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -384,7 +362,9 @@ def save_results(
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        errs_filename = (
+            params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
@@ -423,13 +403,20 @@ def main():
     params = get_params()
     params.update(vars(args))
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    setup_logger(f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}")
+    setup_logger(
+        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+    )
 
-    options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=params.beam_size)
+    options = whisper.DecodingOptions(
+        task="transcribe",
+        language="zh",
+        without_timestamps=True,
+        beam_size=params.beam_size,
+    )
     params.decoding_options = options
     params.cleaner = BasicTextNormalizer()
     params.normalizer = Normalizer()
 
-
     logging.info("Decoding started")
     logging.info(params)
 
@@ -441,39 +428,47 @@ def main():
 
     if params.remove_whisper_encoder_input_length_restriction:
         replace_whisper_encoder_forward()
-    model = whisper.load_model(params.model_name, 'cpu')
-    if params.epoch > 0:
-        if params.avg > 1:
-            start = params.epoch - params.avg
-            assert start >= 1, start
-            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
-            if 'model' not in checkpoint:
-                filenames = [f"{params.exp_dir}/epoch-{epoch}.pt" for epoch in range(start, params.epoch + 1)]
-                model.load_state_dict(average_checkpoints(filenames))
-            else:
-                filename_start = f"{params.exp_dir}/epoch-{start}.pt"
-                filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
-                logging.info(
-                    f"Calculating the averaged model over epoch range from "
-                    f"{start} (excluded) to {params.epoch}"
-                )
-                model.to(device)
-                model.load_state_dict(
-                    average_checkpoints_with_averaged_model(
-                        filename_start=filename_start,
-                        filename_end=filename_end,
-                        device=device,
-                    )
-                )
-            # save checkpoints
-            filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
-            torch.save(model.state_dict(), filename)
-        else:
-            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
-            if 'model' not in checkpoint:
-                model.load_state_dict(checkpoint, strict=True)
-            else:
-                load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model = whisper.load_model(params.model_name, "cpu")
+    if params.epoch > 0:
+        if params.avg > 1:
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            checkpoint = torch.load(
+                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+            )
+            if "model" not in checkpoint:
+                # deepspeed converted checkpoint only contains model state_dict
+                filenames = [
+                    f"{params.exp_dir}/epoch-{epoch}.pt"
+                    for epoch in range(start, params.epoch + 1)
+                ]
+                model.load_state_dict(average_checkpoints(filenames))
+            else:
+                filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+                filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+                logging.info(
+                    f"Calculating the averaged model over epoch range from "
+                    f"{start} (excluded) to {params.epoch}"
+                )
+                model.to(device)
+                model.load_state_dict(
+                    average_checkpoints_with_averaged_model(
+                        filename_start=filename_start,
+                        filename_end=filename_end,
+                        device=device,
+                    )
+                )
+            # save checkpoints
+            filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+            torch.save(model.state_dict(), filename)
+        else:
+            checkpoint = torch.load(
+                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+            )
+            if "model" not in checkpoint:
+                model.load_state_dict(checkpoint, strict=True)
+            else:
+                load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
@@ -35,4 +35,4 @@
     "steps_per_print": 50,
     "train_micro_batch_size_per_gpu": 1,
     "wall_clock_breakdown": false
-}
+}
@@ -7,4 +7,4 @@ librosa
 git+https://github.com/yuekaizhang/whisper.git
 zhconv
 WeTextProcessing
-deepspeed
+deepspeed
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # Copyright    2023  Xiaomi Corp.        (authors: Xiaoyu Yang)
+#              2024  Yuekai Zhang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -41,44 +42,37 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
-import deepspeed
-from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from typing import Any, Dict, List, Optional, Tuple, Union
 
+import deepspeed
 import k2
 import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from typing import List
-
+import whisper
 from asr_datamodule import AishellAsrDataModule
-
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from label_smoothing import LabelSmoothingLoss
 from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.functional import pad as pad_tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
 
 from icefall import diagnostics
-
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.checkpoint import (
-    save_checkpoint_with_global_batch_idx,
-    update_averaged_model,
-)
-from icefall.dist import cleanup_dist, setup_dist, get_world_size, get_rank, get_local_rank
+from icefall.checkpoint import update_averaged_model
+from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
 from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
@@ -87,10 +81,6 @@ from icefall.utils import (
     str2bool,
 )
 
-import whisper
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from label_smoothing import LabelSmoothingLoss
-
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]
 
 
@@ -102,6 +92,7 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
         if hasattr(module, "batch_count"):
             module.batch_count = batch_count
 
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -247,39 +238,17 @@ def get_params() -> AttributeDict:
 
     Explanation of options saved in `params`:
 
-    - best_train_loss: Best training loss so far. It is used to select
-                       the model that has the lowest training loss. It is
-                       updated during the training.
-
-    - best_valid_loss: Best validation loss so far. It is used to select
-                       the model that has the lowest validation loss. It is
-                       updated during the training.
-
-    - best_train_epoch: It is the epoch that has the best training loss.
-
-    - best_valid_epoch: It is the epoch that has the best validation loss.
-
-    - batch_idx_train: Used to writing statistics to tensorboard. It
-                       contains number of batches trained so far across
-                       epochs.
-
-    - log_interval: Print training loss if batch_idx % log_interval` is 0
-
-    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
-
-    - valid_interval: Run validation if batch_idx % valid_interval is 0
-
-    - feature_dim: The model input dim. It has to match the one used
-                   in computing features.
-
-    - subsampling_factor: The subsampling factor for the model.
-
-    - encoder_dim: Hidden dim for multi-head attention model.
-
-    - num_decoder_layers: Number of decoder layer of transformer decoder.
-
-    - warm_step: The warmup period that dictates the decay of the
-                 scale on "simple" (un-pruned) loss.
+    - frame_shift_ms: The frame shift in milliseconds.
+    - allowed_excess_duration_ratio: The allowed excess duration ratio.
+    - best_train_loss: The best training loss so far.
+    - best_valid_loss: The best validation loss so far.
+    - best_train_epoch: The epoch where the best training loss is achieved.
+    - best_valid_epoch: The epoch where the best validation loss is achieved.
+    - batch_idx_train: The batch index of the current batch.
+    - log_interval: Log training stats every `log_interval` batches.
+    - reset_interval: Reset the stats every `reset_interval` batches.
+    - valid_interval: Run validation every `valid_interval` batches.
+    - env_info: The environment information.
     """
     params = AttributeDict(
         {
@@ -292,13 +261,14 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 9999999,
+            "valid_interval": 5000,
             "env_info": get_env_info(),
         }
     )
 
     return params
 
+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -414,6 +384,7 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)
 
+
 def compute_loss(
     params: AttributeDict,
     tokenizer: whisper.tokenizer.Tokenizer,
@@ -422,22 +393,21 @@ def compute_loss(
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute RNN-T loss given the model and its inputs.
-
+    Compute the loss for the given batch.
     Args:
-      params:
-        Parameters for training. See :func:`get_params`.
-      model:
-        The model for training. It is an instance of Zipformer in our case.
-      batch:
-        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
-        for the content in it.
-      is_training:
-        True for training. False for validation. When it is True, this
-        function enables autograd during computation; when it is False, it
-        disables autograd.
-      warmup: a floating point value which increases throughout training;
-        values >= 1.0 are fully warmed up and have all modules present.
+      params:
+        It is returned by :func:`get_params`.
+      tokenizer:
+        The tokenizer used to encode the text.
+      model:
+        The model for training.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        Whether it is training.
     Returns:
+      Return a tuple of two elements. The first element is the loss tensor.
     """
     # For the uneven-sized batch, the total duration after padding would possibly
     # cause OOM. Hence, for each batch, which is sorted descendingly by length,
@@ -449,6 +419,7 @@ def compute_loss(
     if isinstance(model, DDP):
         # get underlying nn.Module
         model = model.module
 
+
     def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
         padding_size = max(tensor.shape[0] for tensor in tensors)
         dims = len(tensors[0].shape)
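The `_batch_tensors` helper is only partially visible in this hunk. For orientation, a hedged sketch of what such a right-padding helper typically looks like (an illustration consistent with the visible lines, not the file's exact body):

```python
from typing import Any, List

import torch
from torch import Tensor
from torch.nn.functional import pad as pad_tensor


def batch_tensors_sketch(tensors: List[Tensor], pad_value: Any) -> Tensor:
    # Pad every tensor along dim 0 to the longest one, then stack.
    padding_size = max(tensor.shape[0] for tensor in tensors)
    dims = len(tensors[0].shape)
    padded = [
        # F.pad pads from the last dim backwards; (0, 0) * (dims - 1) leaves
        # the trailing dims alone and the final pair pads dim 0 on the right.
        pad_tensor(t, (0, 0) * (dims - 1) + (0, padding_size - t.shape[0]), value=pad_value)
        for t in tensors
    ]
    return torch.stack(padded)


# e.g. batch_tensors_sketch([torch.ones(2), torch.ones(3)], pad_value=50256).shape == (2, 3)
```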
@@ -479,9 +450,16 @@ def compute_loss(
     # remove spaces in texts
     texts = [text.replace(" ", "") for text in texts]
 
-    text_tokens_list = [list(tokenizer.sot_sequence_including_notimestamps) + tokenizer.encode(text) + [tokenizer.eot] for text in texts]
+    text_tokens_list = [
+        list(tokenizer.sot_sequence_including_notimestamps)
+        + tokenizer.encode(text)
+        + [tokenizer.eot]
+        for text in texts
+    ]
     # convert it to torch tensor
-    text_tokens_list = [torch.LongTensor(text_tokens) for text_tokens in text_tokens_list]
+    text_tokens_list = [
+        torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
+    ]
 
     # 50256 is the index of <pad> for all whisper models
     prev_outputs_tokens = _batch_tensors(
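Each target sequence above is framed with Whisper's special tokens. A hedged illustration of what the framing produces (exact token IDs are model-dependent, shown symbolically):

```python
# Assuming `tokenizer` comes from whisper.tokenizer.get_tokenizer(...):
# sot_sequence_including_notimestamps is roughly
#   (<|startoftranscript|>, <|zh|>, <|transcribe|>, <|notimestamps|>)
tokens = (
    list(tokenizer.sot_sequence_including_notimestamps)
    + tokenizer.encode("你好")
    + [tokenizer.eot]
)
# -> [sot, lang, task, notimestamps, ...text ids..., eot]
```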
@@ -494,9 +472,11 @@ def compute_loss(
         [tokens.shape[0] - 1 for tokens in text_tokens_list]
     )
 
-    decoder_criterion = LabelSmoothingLoss(ignore_index=50256, label_smoothing=0.1, reduction="sum")
+    decoder_criterion = LabelSmoothingLoss(
+        ignore_index=50256, label_smoothing=0.1, reduction="sum"
+    )
 
-    # ignore the first 3 tokens, which are always <sos>, <lang_id>, <transcibe>
+    # ignore the first 3 tokens, which are always <|lang_id|>, <|transcibe|>, <|notimestampes|>
     ignore_prefix_size = 3
     with torch.set_grad_enabled(is_training):
         encoder_out = model.encoder(feature)
@@ -623,7 +603,7 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
-
+
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -687,16 +667,24 @@ def train_one_epoch(
         if batch_idx % params.log_interval == 0:
             try:
                 cur_lr = scheduler.get_last_lr()[0]
-            except:
+            except:  # noqa
                 cur_lr = 0.0
-            cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0
+            cur_grad_scale = (
+                scaler._scale.item()
+                if (params.use_fp16 and not params.deepspeed)
+                else 1.0
+            )
 
             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if (params.use_fp16 and not params.deepspeed) else "")
+                + (
+                    f"grad_scale: {scaler._scale.item()}"
+                    if (params.use_fp16 and not params.deepspeed)
+                    else ""
+                )
             )
 
             if tb_writer is not None:
@@ -715,7 +703,6 @@ def train_one_epoch(
             params.batch_idx_train,
         )
 
-
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
@@ -744,15 +731,18 @@ def run(rank, world_size, args):
     logging.info(params)
 
     logging.info("About to create model")
-
+
     replace_whisper_encoder_forward()
-    model = whisper.load_model(params.model_name, 'cpu')
+    model = whisper.load_model(params.model_name, "cpu")
     del model.alignment_heads
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")
 
     tokenizer = whisper.tokenizer.get_tokenizer(
-        model.is_multilingual, num_languages=model.num_languages, language="zh", task="transcribe"
+        model.is_multilingual,
+        num_languages=model.num_languages,
+        language="zh",
+        task="transcribe",
     )
 
     model_avg: Optional[nn.Module] = None
@@ -791,7 +781,8 @@ def run(rank, world_size, args):
     if params.deepspeed:
         logging.info("Using DeepSpeed")
         model, optimizer, _, scheduler = deepspeed.initialize(
-            args=params, model=model, model_parameters=model.parameters())
+            args=params, model=model, model_parameters=model.parameters()
+        )
     else:
         logging.info("Using DDP")
         setup_dist(use_ddp_launch=True)
@@ -860,13 +851,17 @@ def run(rank, world_size, args):
             break
 
         if params.deepspeed:
-            model.save_checkpoint(save_dir=params.exp_dir,
-                                  tag=f"epoch-{params.cur_epoch}",
-                                  client_state={})
+            model.save_checkpoint(
+                save_dir=params.exp_dir,
+                tag=f"epoch-{params.cur_epoch}",
+                client_state={},
+            )
             if rank == 0:
                 convert_zero_checkpoint_to_fp32_state_dict(
-                    params.exp_dir, f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
-                    tag=f"epoch-{params.cur_epoch}")
+                    params.exp_dir,
+                    f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
+                    tag=f"epoch-{params.cur_epoch}",
+                )
         else:
             save_checkpoint(
                 params=params,
@@ -924,5 +919,6 @@ def main():
     torch.set_num_interop_threads(1)
     run(rank=rank, world_size=world_size, args=args)
 
+
 if __name__ == "__main__":
     main()
@@ -1,6 +1,8 @@
 import torch
 import torch.nn.functional as F
+import whisper
 
+
 def forward(self, x: torch.Tensor):
     """
     x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
@@ -10,7 +12,7 @@ def forward(self, x: torch.Tensor):
     x = F.gelu(self.conv2(x))
     x = x.permute(0, 2, 1)
 
-    x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype)
+    x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)
 
     for block in self.blocks:
         x = block(x)
@@ -18,6 +20,7 @@ def forward(self, x: torch.Tensor):
     x = self.ln_post(x)
     return x
 
+
 def replace_whisper_encoder_forward():
     """
     This function monkey patches the forward method of the whisper encoder.
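The patched `forward` above slices the positional embedding to the actual input length instead of assuming 3000 frames, which is what removes the 30-second input restriction. A hedged sketch of how such a monkey patch is typically applied (the assignment target is assumed from the function's docstring, not confirmed by this hunk):

```python
import whisper


def replace_whisper_encoder_forward_sketch():
    # Rebind the encoder's forward so inputs shorter than 30 s are accepted;
    # `forward` is the patched function defined in this file.
    whisper.model.AudioEncoder.forward = forward
```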
|
@ -22,16 +22,25 @@ It looks for manifests in the directory data/manifests.
|
||||
|
||||
The generated fbank features are saved in data/fbank.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter, MonoCut, combine
|
||||
from lhotse import (
|
||||
CutSet,
|
||||
Fbank,
|
||||
FbankConfig,
|
||||
LilcomChunkyWriter,
|
||||
MonoCut,
|
||||
WhisperFbank,
|
||||
WhisperFbankConfig,
|
||||
combine,
|
||||
)
|
||||
from lhotse.recipes.utils import read_manifests_if_cached
|
||||
|
||||
from icefall.utils import get_executor
|
||||
from icefall.utils import get_executor, str2bool
|
||||
|
||||
# Torch's multithreaded behavior needs to be disabled or
|
||||
# it wastes a lot of CPU and slow things down.
|
||||
@@ -81,7 +90,9 @@ def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):
     logging.info("Extracting features for Musan")
 
     if whisper_fbank:
-        extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
     else:
         extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))
 
@@ -103,6 +114,7 @@ def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):
     )
     musan_cuts.to_file(musan_cuts_path)
 
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -119,10 +131,12 @@ def get_args():
     )
     return parser.parse_args()
 
+
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
 
     logging.basicConfig(format=formatter, level=logging.INFO)
     args = get_args()
     compute_fbank_musan(
         num_mel_bins=args.num_mel_bins, whisper_fbank=args.whisper_fbank
     )