fix lint

This commit is contained in:
  parent b623c3be15
  commit 8d9ab308af
@@ -24,3 +24,10 @@ The following table lists the differences among them.
 The decoder in `transducer_stateless` is modified from the paper
 [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
 We place an additional Conv1d layer right after the input embedding layer.
+
+# Whisper
+
+Recipe to fine-tune large pretrained models
+| | Encoder | Decoder | Comment |
+|------------------------------------|-----------|--------------------|-----------------------------------------------------------------------------------|
+| `whisper` | Transformer | Transformer | support fine-tuning using deepspeed |
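For readers skimming the diff: the stateless decoder described in the README lines above replaces the usual recurrent prediction network with an embedding over a small, fixed left context, mixed by the Conv1d mentioned there. A minimal sketch of the idea (hypothetical class and shapes, not the recipe's actual module):

```python
import torch
import torch.nn as nn


class StatelessDecoder(nn.Module):
    """Sketch of a stateless transducer decoder: an embedding over the
    last few emitted symbols, mixed by a Conv1d placed right after the
    embedding layer (as the README describes)."""

    def __init__(self, vocab_size: int, embed_dim: int = 512, context_size: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Conv1d over the symbol axis replaces the recurrent prediction network.
        self.conv = nn.Conv1d(embed_dim, embed_dim, kernel_size=context_size)

    def forward(self, y: torch.Tensor) -> torch.Tensor:
        # y: (batch, num_symbols) token IDs of previously emitted symbols
        emb = self.embedding(y).permute(0, 2, 1)  # (batch, embed_dim, num_symbols)
        out = self.conv(emb).permute(0, 2, 1)  # (batch, num_symbols - 1, embed_dim)
        return out


decoder = StatelessDecoder(vocab_size=500)
tokens = torch.randint(0, 500, (3, 10))
print(decoder(tokens).shape)  # torch.Size([3, 9, 512])
```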
@@ -77,7 +77,7 @@ It's reworked Zipformer with Pruned RNNT loss.

 Command for training is:
 ```bash
 ./prepare.sh

 export CUDA_VISIBLE_DEVICES="0,1"

@@ -142,7 +142,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
   --feedforward-dim 512,768,768,768,768,768 \
   --encoder-dim 192,256,256,256,256,256 \
   --encoder-unmasked-dim 192,192,192,192,192,192 \
   --max-duration 1200
 ```

 Command for decoding is:
@@ -192,7 +192,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
   --feedforward-dim 512,768,1536,2048,1536,768 \
   --encoder-dim 192,256,512,768,512,256 \
   --encoder-unmasked-dim 192,192,256,320,256,192 \
   --max-duration 800
 ```

 Command for decoding is:
@@ -208,7 +208,7 @@ for m in greedy_search modified_beam_search fast_beam_search ; do
     --num-encoder-layers 2,2,4,5,4,2 \
    --feedforward-dim 512,768,1536,2048,1536,768 \
    --encoder-dim 192,256,512,768,512,256 \
    --encoder-unmasked-dim 192,192,256,320,256,192
 done
 ```

@@ -29,7 +29,14 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor, str2bool
@@ -42,7 +49,9 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)


-def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
+def compute_fbank_aishell(
+    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
@@ -69,7 +78,9 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, w
         dataset_parts,
     )
     if whisper_fbank:
-        extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
     else:
         extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

@@ -84,7 +95,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, w
             supervisions=m["supervisions"],
         )
         if "train" in partition and perturb_speed:
-            logging.info(f"Doing speed perturb")
+            logging.info("Doing speed perturb")
             cut_set = (
                 cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
             )
@@ -129,5 +140,7 @@ if __name__ == "__main__":

     args = get_args()
     compute_fbank_aishell(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
    )
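The `WhisperFbank`/`Fbank` branch reformatted above (it recurs in `compute_fbank_musan.py` further down) just selects the feature extractor; a standalone sketch of the same choice, assuming lhotse is installed (the `device` default is illustrative):

```python
from lhotse import Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig


def make_extractor(num_mel_bins: int = 80, whisper_fbank: bool = False, device: str = "cpu"):
    """Mirror of the branch in the diff: Whisper-style log-mel filterbanks
    for Whisper fine-tuning, Kaldi-compatible fbank otherwise."""
    if whisper_fbank:
        return WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device=device))
    return Fbank(FbankConfig(num_mel_bins=num_mel_bins))


extractor = make_extractor(num_mel_bins=80, whisper_fbank=False)
print(type(extractor).__name__)  # Fbank
```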
@@ -387,4 +387,4 @@ if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
     ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
     touch data/fbank/.aishell.whisper.done
   fi
 fi
@@ -2,6 +2,7 @@
 # Copyright  2021  Xiaomi Corporation  (Author: Liyong Guo,
 #                                               Fangjun Kuang,
 #                                               Wei Kang)
+#            2024  Yuekai Zhang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -42,44 +43,37 @@ python3 ./whisper/decode.py \

 import argparse
 import logging
+import re
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

-import whisper
-from whisper.normalizers import BasicTextNormalizer
 import k2
 import torch
 import torch.nn as nn
+import whisper
 from asr_datamodule import AishellAsrDataModule
-#from model import load_model
+from tn.chinese.normalizer import Normalizer
+from whisper.normalizers import BasicTextNormalizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
-from icefall.decode import (
-    get_lattice,
-    nbest_decoding,
-    nbest_oracle,
-    one_best_decoding,
-    rescore_with_attention_decoder,
-)
+from zhconv import convert
+
+from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
 from icefall.env import get_env_info
-from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    get_texts,
     setup_logger,
     store_transcripts,
-    write_error_stats,
     str2bool,
+    write_error_stats,
 )
-from zhconv import convert
-from tn.chinese.normalizer import Normalizer
-import re


 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
 ) -> dict:
     """Average a list of checkpoints.
+    The function is mainly used for deepspeed converted checkpoint averaging, which only includes the model state_dict.

     Args:
       filenames:
@@ -94,9 +88,9 @@ def average_checkpoints(
     n = len(filenames)

     if "model" in torch.load(filenames[0], map_location=device):
         avg = torch.load(filenames[0], map_location=device)["model"]
     else:
         avg = torch.load(filenames[0], map_location=device)

     # Identify shared parameters. Two parameters are said to be shared
     # if they have the same data_ptr
@@ -112,9 +106,9 @@ def average_checkpoints(

     for i in range(1, n):
         if "model" in torch.load(filenames[i], map_location=device):
             state_dict = torch.load(filenames[i], map_location=device)["model"]
         else:
             state_dict = torch.load(filenames[i], map_location=device)
         for k in uniqued_names:
             avg[k] += state_dict[k]

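The two hunks above load checkpoints that are either bare state_dicts (deepspeed-converted) or icefall checkpoints carrying a `"model"` key, then sum parameters across epochs and divide once at the end. Condensed into a self-contained sketch (file list hypothetical; the shared-parameter bookkeeping of the real `average_checkpoints` is omitted):

```python
from pathlib import Path
from typing import List

import torch


def average_state_dicts(filenames: List[Path]) -> dict:
    """Average state_dicts from files that are either a bare state_dict
    (deepspeed fp32 conversion) or a dict with a "model" key (icefall)."""

    def load(f: Path) -> dict:
        ckpt = torch.load(f, map_location="cpu")
        return ckpt["model"] if "model" in ckpt else ckpt

    avg = load(filenames[0])
    for f in filenames[1:]:
        sd = load(f)
        for k in avg:
            avg[k] += sd[k]
    for k in avg:
        # Integer buffers (e.g. counters) cannot take true division.
        if avg[k].is_floating_point():
            avg[k] /= len(filenames)
        else:
            avg[k] //= len(filenames)
    return avg
```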
@@ -126,33 +120,48 @@ def average_checkpoints(

     return avg


 def remove_punctuation(text: str or List[str]):
-    # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
-    punctuation = '!,.;:?、!,。;:?《》 '
+    """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
+
+    Args:
+      text: It can be a string or a list of strings.
+    Returns:
+      Return a string or a list of strings without any punctuation.
+    """
+    punctuation = "!,.;:?、!,。;:?《》 "
     if isinstance(text, str):
-        text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
+        text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
         return text
     elif isinstance(text, list):
         result_text = []
         for t in text:
-            t = re.sub(r'[{}]+'.format(punctuation), '', t).strip()
+            t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
             result_text.append(t)
         return result_text
     else:
-        raise Exception(f'Not support type {type(text)}')
+        raise Exception(f"Not support type {type(text)}")


 def to_simple(text: str or List[str]):
+    """Convert traditional Chinese to simplified Chinese.
+
+    Args:
+      text: It can be a string or a list of strings.
+    Returns:
+      Return a string or a list of strings converted to simplified Chinese.
+    """
     if isinstance(text, str):
-        text = convert(text, 'zh-cn')
+        text = convert(text, "zh-cn")
         return text
     elif isinstance(text, list):
         result_text = []
         for t in text:
-            t = convert(t, 'zh-cn')
+            t = convert(t, "zh-cn")
             result_text.append(t)
         return result_text
     else:
-        raise Exception(f'Not support type{type(text)}')
+        raise Exception(f"Not support type{type(text)}")


 def get_parser():
     parser = argparse.ArgumentParser(
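To see what the two cleaned-up helpers do end to end, here is a tiny usage example (the sample string is made up; requires `zhconv`, which is already in the recipe's requirements):

```python
import re

from zhconv import convert

punctuation = "!,.;:?、!,。;:?《》 "
text = "這是,一個測試!"
# Strip ASCII and CJK punctuation, as remove_punctuation() does.
text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
# Traditional -> simplified Chinese, as to_simple() does.
text = convert(text, "zh-cn")
print(text)  # 这是一个测试
```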
@@ -214,7 +223,7 @@ def get_parser():
         default=True,
         help="replace whisper encoder forward method to remove input length restriction",
     )

     return parser


@@ -226,6 +235,7 @@ def get_params() -> AttributeDict:
     )
     return params

+
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -234,42 +244,17 @@ def decode_one_batch(
     """Decode one batch and return the result in a dict. The dict has the
     following format:

-    - key: It indicates the setting used for decoding. For example,
-           if decoding method is 1best, the key is the string `no_rescore`.
-           If attention rescoring is used, the key is the string
-           `ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the
-           value of `lm_scale` and `attention_scale`. An example key is
-           `ngram_lm_scale_0.7_attention_scale_0.5`
-    - value: It contains the decoding result. `len(value)` equals to
-             batch size. `value[i]` is the decoding result for the i-th
-             utterance in the given batch.
+    - key: "beam-search"
+    - value: A list of lists. Each sublist is a list of token IDs.
     Args:
       params:
-        It's the return value of :func:`get_params`.
-        - params.method is "1best", it uses 1best decoding without LM rescoring.
-        - params.method is "nbest", it uses nbest decoding without LM rescoring.
-        - params.method is "attention-decoder", it uses attention rescoring.
-
+        It is returned by :func:`get_params`.
       model:
         The neural model.
-      HLG:
-        The decoding graph. Used when params.method is NOT ctc-decoding.
-      H:
-        The ctc topo. Used only when params.method is ctc-decoding.
       batch:
-        It is the return value from iterating
-        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
-        for the format of the `batch`.
-      lexicon:
-        It contains the token symbol table and the word symbol table.
-      sos_id:
-        The token ID of the SOS.
-      eos_id:
-        The token ID of the EOS.
+        It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
     Returns:
-      Return the decoding result. See above description for the format of
-      the returned dict.
+      Return a dict, whose key may be "beam-search".
     """
     dtype = torch.float16
     device = torch.device("cuda")
@@ -280,22 +265,27 @@ def decode_one_batch(
     if not params.remove_whisper_encoder_input_length_restriction:
         T = 3000
         if feature.shape[2] < T:
-            feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
+            feature = torch.cat(
+                [
+                    feature,
+                    torch.zeros(
+                        feature.shape[0], feature.shape[1], T - feature.shape[2]
+                    ).to(device, dtype=dtype),
+                ],
+                2,
+            )

     supervisions = batch["supervisions"]
     feature_len = supervisions["num_frames"]
     feature_len = feature_len.to(device, dtype=dtype)
     results = model.decode(feature, params.decoding_options)
     hyps = [result.text for result in results]

     hyps = remove_punctuation(hyps)
     hyps = to_simple(hyps)

     hyps = [params.normalizer.normalize(hyp) for hyp in hyps]

-    key = "beam-search"
-
-    return {key: hyps}
+    return {"beam-search": hyps}


 def decode_dataset(
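The reformatted `torch.cat` above right-pads the log-mel features with zeros up to Whisper's fixed 3000-frame (30 s) input length when the encoder's length restriction is kept. The same operation in isolation, with assumed shapes:

```python
import torch

T = 3000  # 30 s at a 10 ms frame shift: Whisper's fixed input length
feature = torch.randn(4, 80, 1234)  # (batch, n_mels, frames)
if feature.shape[2] < T:
    # Zero-pad along the frame axis until the batch is exactly T frames wide.
    pad = torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2])
    feature = torch.cat([feature, pad], dim=2)
print(feature.shape)  # torch.Size([4, 80, 3000])
```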
@@ -306,28 +296,14 @@ def decode_dataset(
     """Decode dataset.

     Args:
       dl:
-        PyTorch's dataloader containing the dataset to decode.
+        The dataloader.
       params:
         It is returned by :func:`get_params`.
       model:
         The neural model.
-      HLG:
-        The decoding graph. Used when params.method is NOT ctc-decoding.
-      H:
-        The ctc topo. Used only when params.method is ctc-decoding.
-      lexicon:
-        It contains the token symbol table and the word symbol table.
-      sos_id:
-        The token ID for SOS.
-      eos_id:
-        The token ID for EOS.
     Returns:
-      Return a dict, whose key may be "no-rescore" if the decoding method is
-      1best or it may be "ngram_lm_scale_0.7_attention_scale_0.5" if attention
-      rescoring is used. Its value is a list of tuples. Each tuple contains two
-      elements: The first is the reference transcript, and the second is the
-      predicted result.
+      Return a dict, whose key may be "beam-search".
     """
     results = []

@@ -376,7 +352,9 @@ def save_results(
     enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        recog_path = (
+            params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
         results = sorted(results)
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
@@ -384,7 +362,9 @@ def save_results(

         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        errs_filename = (
+            params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt"
+        )
         # we compute CER for aishell dataset.
         results_char = []
         for res in results:
@@ -423,13 +403,20 @@ def main():
     params = get_params()
     params.update(vars(args))
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    setup_logger(f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}")
+    setup_logger(
+        f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}"
+    )

-    options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=params.beam_size)
+    options = whisper.DecodingOptions(
+        task="transcribe",
+        language="zh",
+        without_timestamps=True,
+        beam_size=params.beam_size,
+    )
     params.decoding_options = options
     params.cleaner = BasicTextNormalizer()
     params.normalizer = Normalizer()

     logging.info("Decoding started")
     logging.info(params)

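For context, the `whisper.DecodingOptions` built above pins the task and language so the decoder spends no tokens on language detection. A minimal standalone use of the same options (model name and mel input are illustrative):

```python
import whisper

options = whisper.DecodingOptions(
    task="transcribe",
    language="zh",
    without_timestamps=True,
    beam_size=10,
)
# With a loaded model and an (80, 3000) log-mel tensor `mel`, decoding is:
# model = whisper.load_model("base", "cpu")
# result = whisper.decode(model, mel, options)
# print(result.text)
```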
@@ -441,39 +428,47 @@ def main():

     if params.remove_whisper_encoder_input_length_restriction:
         replace_whisper_encoder_forward()
-    model = whisper.load_model(params.model_name, 'cpu')
+    model = whisper.load_model(params.model_name, "cpu")
     if params.epoch > 0:
         if params.avg > 1:
             start = params.epoch - params.avg
             assert start >= 1, start
-            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
-            if 'model' not in checkpoint:
-                filenames = [f"{params.exp_dir}/epoch-{epoch}.pt" for epoch in range(start, params.epoch + 1)]
+            checkpoint = torch.load(
+                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+            )
+            if "model" not in checkpoint:
+                # deepspeed converted checkpoint only contains model state_dict
+                filenames = [
+                    f"{params.exp_dir}/epoch-{epoch}.pt"
+                    for epoch in range(start, params.epoch + 1)
+                ]
                 model.load_state_dict(average_checkpoints(filenames))
             else:
                 filename_start = f"{params.exp_dir}/epoch-{start}.pt"
                 filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
                 logging.info(
                     f"Calculating the averaged model over epoch range from "
                     f"{start} (excluded) to {params.epoch}"
                 )
                 model.to(device)
                 model.load_state_dict(
                     average_checkpoints_with_averaged_model(
                         filename_start=filename_start,
                         filename_end=filename_end,
                         device=device,
                     )
                 )
             # save checkpoints
             filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
             torch.save(model.state_dict(), filename)
         else:
-            checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu')
-            if 'model' not in checkpoint:
+            checkpoint = torch.load(
+                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+            )
+            if "model" not in checkpoint:
                 model.load_state_dict(checkpoint, strict=True)
             else:
                 load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
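The restructured branch above keys off whether the saved file is a bare state_dict (the output of deepspeed's fp32 conversion) or a regular icefall checkpoint wrapping it under `"model"`. Reduced to its essence (the path is hypothetical):

```python
import torch


def load_whisper_state_dict(ckpt_path: str) -> dict:
    """Return the state_dict regardless of which checkpoint flavor we got."""
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    if "model" not in checkpoint:
        # deepspeed-converted checkpoint: the file *is* the state_dict
        return checkpoint
    # regular icefall checkpoint: the state_dict lives under "model"
    return checkpoint["model"]
```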
@@ -35,4 +35,4 @@
   "steps_per_print": 50,
   "train_micro_batch_size_per_gpu": 1,
   "wall_clock_breakdown": false
 }
@@ -7,4 +7,4 @@ librosa
 git+https://github.com/yuekaizhang/whisper.git
 zhconv
 WeTextProcessing
 deepspeed
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # Copyright  2023  Xiaomi Corp.  (authors: Xiaoyu Yang)
+#            2024  Yuekai Zhang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -41,44 +42,37 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

-import deepspeed
-from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

+import deepspeed
 import k2
 import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from typing import List
+import whisper

 from asr_datamodule import AishellAsrDataModule
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from label_smoothing import LabelSmoothingLoss
 from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed

 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.functional import pad as pad_tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

 from icefall import diagnostics
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.checkpoint import (
-    save_checkpoint_with_global_batch_idx,
-    update_averaged_model,
-)
-from icefall.dist import cleanup_dist, setup_dist, get_world_size, get_rank, get_local_rank
+from icefall.checkpoint import update_averaged_model
+from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
 from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
-from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
@@ -87,10 +81,6 @@ from icefall.utils import (
     str2bool,
 )

-import whisper
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from label_smoothing import LabelSmoothingLoss
-
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]

@@ -102,6 +92,7 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
         if hasattr(module, "batch_count"):
             module.batch_count = batch_count

+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -247,39 +238,17 @@ def get_params() -> AttributeDict:

     Explanation of options saved in `params`:

-    - best_train_loss: Best training loss so far. It is used to select
-                       the model that has the lowest training loss. It is
-                       updated during the training.
-
-    - best_valid_loss: Best validation loss so far. It is used to select
-                       the model that has the lowest validation loss. It is
-                       updated during the training.
-
-    - best_train_epoch: It is the epoch that has the best training loss.
-
-    - best_valid_epoch: It is the epoch that has the best validation loss.
-
-    - batch_idx_train: Used to writing statistics to tensorboard. It
-                       contains number of batches trained so far across
-                       epochs.
-
-    - log_interval: Print training loss if batch_idx % log_interval is 0
-
-    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
-
-    - valid_interval: Run validation if batch_idx % valid_interval is 0
-
-    - feature_dim: The model input dim. It has to match the one used
-                   in computing features.
-
-    - subsampling_factor: The subsampling factor for the model.
-
-    - encoder_dim: Hidden dim for multi-head attention model.
-
-    - num_decoder_layers: Number of decoder layer of transformer decoder.
-
-    - warm_step: The warmup period that dictates the decay of the
-                 scale on "simple" (un-pruned) loss.
+    - frame_shift_ms: The frame shift in milliseconds.
+    - allowed_excess_duration_ratio: The allowed excess duration ratio.
+    - best_train_loss: The best training loss so far.
+    - best_valid_loss: The best validation loss so far.
+    - best_train_epoch: The epoch where the best training loss is achieved.
+    - best_valid_epoch: The epoch where the best validation loss is achieved.
+    - batch_idx_train: The batch index of the current batch.
+    - log_interval: Log training stats every `log_interval` batches.
+    - reset_interval: Reset the stats every `reset_interval` batches.
+    - valid_interval: Run validation every `valid_interval` batches.
+    - env_info: The environment information.
     """
     params = AttributeDict(
         {
@@ -292,13 +261,14 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 9999999,
+            "valid_interval": 5000,
             "env_info": get_env_info(),
         }
     )

     return params

+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -414,6 +384,7 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)

+
 def compute_loss(
     params: AttributeDict,
     tokenizer: whisper.tokenizer.Tokenizer,
@@ -422,22 +393,21 @@ def compute_loss(
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute RNN-T loss given the model and its inputs.
+    Compute the loss for the given batch.

     Args:
       params:
-        Parameters for training. See :func:`get_params`.
-      model:
-        The model for training. It is an instance of Zipformer in our case.
-      batch:
-        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
-        for the content in it.
-      is_training:
-        True for training. False for validation. When it is True, this
-        function enables autograd during computation; when it is False, it
-        disables autograd.
-      warmup: a floating point value which increases throughout training;
-        values >= 1.0 are fully warmed up and have all modules present.
+        It is returned by :func:`get_params`.
+      tokenizer:
+        The tokenizer used to encode the text.
+      model:
+        The model for training.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        Whether it is training.
+    Returns:
+      Return a tuple of two elements. The first element is the loss tensor.
     """
     # For the uneven-sized batch, the total duration after padding would possibly
     # cause OOM. Hence, for each batch, which is sorted descendingly by length,
@@ -449,6 +419,7 @@ def compute_loss(
     if isinstance(model, DDP):
        # get underlying nn.Module
        model = model.module
+
     def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
         padding_size = max(tensor.shape[0] for tensor in tensors)
         dims = len(tensors[0].shape)
@@ -479,9 +450,16 @@ def compute_loss(
     # remove spaces in texts
     texts = [text.replace(" ", "") for text in texts]

-    text_tokens_list = [list(tokenizer.sot_sequence_including_notimestamps) + tokenizer.encode(text) + [tokenizer.eot] for text in texts]
+    text_tokens_list = [
+        list(tokenizer.sot_sequence_including_notimestamps)
+        + tokenizer.encode(text)
+        + [tokenizer.eot]
+        for text in texts
+    ]
     # convert it to torch tensor
-    text_tokens_list = [torch.LongTensor(text_tokens) for text_tokens in text_tokens_list]
+    text_tokens_list = [
+        torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
+    ]

     # 50256 is the index of <pad> for all whisper models
     prev_outputs_tokens = _batch_tensors(
@@ -494,9 +472,11 @@ def compute_loss(
         [tokens.shape[0] - 1 for tokens in text_tokens_list]
     )

-    decoder_criterion = LabelSmoothingLoss(ignore_index=50256, label_smoothing=0.1, reduction="sum")
+    decoder_criterion = LabelSmoothingLoss(
+        ignore_index=50256, label_smoothing=0.1, reduction="sum"
+    )

-    # ignore the first 3 tokens, which are always <sos>, <lang_id>, <transcibe>
+    # ignore the first 3 tokens, which are always <|lang_id|>, <|transcribe|>, <|notimestamps|>
     ignore_prefix_size = 3
     with torch.set_grad_enabled(is_training):
         encoder_out = model.encoder(feature)
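The reformatted comprehension above builds, per utterance, the Whisper training target `<|startoftranscript|><|zh|><|transcribe|><|notimestamps|> ... text ... <|endoftext|>`. A small probe of that token layout (requires the `openai-whisper` package; the sample text is made up):

```python
import torch
import whisper

tokenizer = whisper.tokenizer.get_tokenizer(
    multilingual=True, language="zh", task="transcribe"
)
tokens = (
    list(tokenizer.sot_sequence_including_notimestamps)  # sot, lang, task, no-timestamps
    + tokenizer.encode("你好")
    + [tokenizer.eot]
)
print(tokens)
# Teacher forcing: the decoder input is the sequence shifted right,
# and the prediction target is the sequence shifted left.
prev_tokens = torch.LongTensor(tokens[:-1])
target = torch.LongTensor(tokens[1:])
```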
@@ -623,7 +603,7 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )

         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -687,16 +667,24 @@ def train_one_epoch(
         if batch_idx % params.log_interval == 0:
             try:
                 cur_lr = scheduler.get_last_lr()[0]
-            except:
+            except:  # noqa
                 cur_lr = 0.0
-            cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0
+            cur_grad_scale = (
+                scaler._scale.item()
+                if (params.use_fp16 and not params.deepspeed)
+                else 1.0
+            )

             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if (params.use_fp16 and not params.deepspeed) else "")
+                + (
+                    f"grad_scale: {scaler._scale.item()}"
+                    if (params.use_fp16 and not params.deepspeed)
+                    else ""
+                )
             )

             if tb_writer is not None:
@@ -715,7 +703,6 @@ def train_one_epoch(
                     params.batch_idx_train,
                 )

-
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
@@ -744,15 +731,18 @@ def run(rank, world_size, args):
     logging.info(params)

     logging.info("About to create model")

     replace_whisper_encoder_forward()
-    model = whisper.load_model(params.model_name, 'cpu')
+    model = whisper.load_model(params.model_name, "cpu")
     del model.alignment_heads
     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

     tokenizer = whisper.tokenizer.get_tokenizer(
-        model.is_multilingual, num_languages=model.num_languages, language="zh", task="transcribe"
+        model.is_multilingual,
+        num_languages=model.num_languages,
+        language="zh",
+        task="transcribe",
     )

     model_avg: Optional[nn.Module] = None
|
|||||||
if params.deepspeed:
|
if params.deepspeed:
|
||||||
logging.info("Using DeepSpeed")
|
logging.info("Using DeepSpeed")
|
||||||
model, optimizer, _, scheduler = deepspeed.initialize(
|
model, optimizer, _, scheduler = deepspeed.initialize(
|
||||||
args=params, model=model, model_parameters=model.parameters())
|
args=params, model=model, model_parameters=model.parameters()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logging.info("Using DDP")
|
logging.info("Using DDP")
|
||||||
setup_dist(use_ddp_launch=True)
|
setup_dist(use_ddp_launch=True)
|
||||||
@@ -860,13 +851,17 @@ def run(rank, world_size, args):
             break

         if params.deepspeed:
-            model.save_checkpoint(save_dir=params.exp_dir,
-                                  tag=f"epoch-{params.cur_epoch}",
-                                  client_state={})
+            model.save_checkpoint(
+                save_dir=params.exp_dir,
+                tag=f"epoch-{params.cur_epoch}",
+                client_state={},
+            )
             if rank == 0:
                 convert_zero_checkpoint_to_fp32_state_dict(
-                    params.exp_dir, f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
-                    tag=f"epoch-{params.cur_epoch}")
+                    params.exp_dir,
+                    f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
+                    tag=f"epoch-{params.cur_epoch}",
+                )
         else:
             save_checkpoint(
                 params=params,
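`convert_zero_checkpoint_to_fp32_state_dict` collapses the per-rank ZeRO shards written by `model.save_checkpoint` into one plain fp32 state_dict file, which is why the decode script above can `torch.load` it directly. Schematically (paths and tag are hypothetical, mirroring the call in the diff):

```python
from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict

# exp_dir/epoch-5/ holds the sharded ZeRO checkpoint written by
# model.save_checkpoint(save_dir="exp_dir", tag="epoch-5")
convert_zero_checkpoint_to_fp32_state_dict(
    "exp_dir",  # checkpoint root directory
    "exp_dir/epoch-5.pt",  # output file holding the consolidated state_dict
    tag="epoch-5",
)
```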
@@ -924,5 +919,6 @@ def main():
     torch.set_num_interop_threads(1)
     run(rank=rank, world_size=world_size, args=args)

+
 if __name__ == "__main__":
     main()
@@ -1,6 +1,8 @@
 import torch
+import torch.nn.functional as F
 import whisper

+
 def forward(self, x: torch.Tensor):
     """
     x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
@@ -10,7 +12,7 @@ def forward(self, x: torch.Tensor):
     x = F.gelu(self.conv2(x))
     x = x.permute(0, 2, 1)

-    x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype)
+    x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)

     for block in self.blocks:
         x = block(x)
@@ -18,6 +20,7 @@ def forward(self, x: torch.Tensor):
     x = self.ln_post(x)
     return x

+
 def replace_whisper_encoder_forward():
     """
     This function monkey patches the forward method of the whisper encoder.
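The patched `forward` above slices the positional embedding to the actual input length instead of requiring Whisper's fixed 1500-frame (post-convolution) width, and `replace_whisper_encoder_forward` installs it with a single class-attribute assignment. The pattern in outline (`patched_forward` is a stub standing in for the `forward` defined in this file):

```python
import whisper


def patched_forward(self, x):
    # Same body as the `forward` above: slice positional_embedding to
    # x.shape[1] instead of asserting a fixed-length input.
    ...


def replace_whisper_encoder_forward():
    # One assignment: every AudioEncoder instance now uses the patched method.
    whisper.model.AudioEncoder.forward = patched_forward
```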
@@ -22,16 +22,25 @@ It looks for manifests in the directory data/manifests.

 The generated fbank features are saved in data/fbank.
 """
+import argparse
 import logging
 import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter, MonoCut, combine
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    MonoCut,
+    WhisperFbank,
+    WhisperFbankConfig,
+    combine,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

-from icefall.utils import get_executor
+from icefall.utils import get_executor, str2bool

 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -81,7 +90,9 @@ def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):
     logging.info("Extracting features for Musan")

     if whisper_fbank:
-        extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
     else:
         extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

@@ -103,6 +114,7 @@ def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):
     )
     musan_cuts.to_file(musan_cuts_path)

+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -119,10 +131,12 @@ def get_args():
     )
     return parser.parse_args()


 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
+
     logging.basicConfig(format=formatter, level=logging.INFO)
+    args = get_args()
     compute_fbank_musan(
         num_mel_bins=args.num_mel_bins, whisper_fbank=args.whisper_fbank
     )