From 8d9ab308af69d36c01dcb96bca867b0d5be4ffc2 Mon Sep 17 00:00:00 2001
From: root
Date: Mon, 22 Jan 2024 08:10:26 +0000
Subject: [PATCH] fix lint

---
 egs/aishell/ASR/README.md                     |   7 +
 egs/aishell/ASR/RESULTS.md                    |   8 +-
 .../ASR/local/compute_fbank_aishell.py        |  23 +-
 egs/aishell/ASR/prepare.sh                    |   2 +-
 egs/aishell/ASR/whisper/decode.py             | 245 +++++++++---------
 egs/aishell/ASR/whisper/ds_config_zero1.json  |   2 +-
 egs/aishell/ASR/whisper/requirements.txt      |   2 +-
 egs/aishell/ASR/whisper/train.py              | 170 ++++++------
 .../whisper_encoder_forward_monkey_patch.py   |   5 +-
 .../ASR/local/compute_fbank_musan.py          |  22 +-
 10 files changed, 257 insertions(+), 229 deletions(-)

diff --git a/egs/aishell/ASR/README.md b/egs/aishell/ASR/README.md
index 176f065e5..b54719162 100644
--- a/egs/aishell/ASR/README.md
+++ b/egs/aishell/ASR/README.md
@@ -24,3 +24,10 @@ The following table lists the differences among them.
 The decoder in `transducer_stateless` is modified from the paper
 [Rnn-Transducer with Stateless Prediction Network](https://ieeexplore.ieee.org/document/9054419/).
 We place an additional Conv1d layer right after the input embedding layer.
+
+# Whisper
+
+Recipe to fine-tune large pretrained models.
+| | Encoder | Decoder | Comment |
+|------------------------------------|-----------|--------------------|-----------------------------------------------------------------------------------|
+| `whisper` | Transformer | Transformer | support fine-tuning using deepspeed |
diff --git a/egs/aishell/ASR/RESULTS.md b/egs/aishell/ASR/RESULTS.md
index 00241dca7..3cdb07c11 100644
--- a/egs/aishell/ASR/RESULTS.md
+++ b/egs/aishell/ASR/RESULTS.md
@@ -77,7 +77,7 @@ It's reworked Zipformer with Pruned RNNT loss.
 Command for training is:
 ```bash
-./prepare.sh 
+./prepare.sh

 export CUDA_VISIBLE_DEVICES="0,1"
@@ -142,7 +142,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
   --feedforward-dim 512,768,768,768,768,768 \
   --encoder-dim 192,256,256,256,256,256 \
   --encoder-unmasked-dim 192,192,192,192,192,192 \
-  --max-duration 1200 
+  --max-duration 1200
 ```

 Command for decoding is:
@@ -192,7 +192,7 @@ export CUDA_VISIBLE_DEVICES="0,1"
   --feedforward-dim 512,768,1536,2048,1536,768 \
   --encoder-dim 192,256,512,768,512,256 \
   --encoder-unmasked-dim 192,192,256,320,256,192 \
-  --max-duration 800 
+  --max-duration 800
 ```

 Command for decoding is:
@@ -208,7 +208,7 @@ for m in greedy_search modified_beam_search fast_beam_search ; do
     --num-encoder-layers 2,2,4,5,4,2 \
     --feedforward-dim 512,768,1536,2048,1536,768 \
     --encoder-dim 192,256,512,768,512,256 \
-    --encoder-unmasked-dim 192,192,256,320,256,192 
+    --encoder-unmasked-dim 192,192,256,320,256,192
 done
 ```
diff --git a/egs/aishell/ASR/local/compute_fbank_aishell.py b/egs/aishell/ASR/local/compute_fbank_aishell.py
index 0ca619d98..1a8ce1e8f 100755
--- a/egs/aishell/ASR/local/compute_fbank_aishell.py
+++ b/egs/aishell/ASR/local/compute_fbank_aishell.py
@@ -29,7 +29,14 @@ import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    WhisperFbank,
+    WhisperFbankConfig,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

 from icefall.utils import get_executor, str2bool
@@ -42,7 +49,9 @@ torch.set_num_threads(1)
 torch.set_num_interop_threads(1)


-def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False):
+def compute_fbank_aishell(
+    num_mel_bins: int = 80, perturb_speed: bool = False, whisper_fbank: bool = False
+):
     src_dir = Path("data/manifests")
     output_dir = Path("data/fbank")
     num_jobs = min(15, os.cpu_count())
@@ -69,7 +78,9 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, w
         dataset_parts,
     )
     if whisper_fbank:
-        extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
     else:
         extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

@@ -84,7 +95,7 @@ def compute_fbank_aishell(num_mel_bins: int = 80, perturb_speed: bool = False, w
             supervisions=m["supervisions"],
         )
         if "train" in partition and perturb_speed:
-            logging.info(f"Doing speed perturb")
+            logging.info("Doing speed perturb")
             cut_set = (
                 cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
             )
@@ -129,5 +140,7 @@ if __name__ == "__main__":

     args = get_args()
     compute_fbank_aishell(
-        num_mel_bins=args.num_mel_bins, perturb_speed=args.perturb_speed, whisper_fbank=args.whisper_fbank
+        num_mel_bins=args.num_mel_bins,
+        perturb_speed=args.perturb_speed,
+        whisper_fbank=args.whisper_fbank,
     )
diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index aaeba39f8..f0578f4d6 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -387,4 +387,4 @@ if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
     ./local/compute_fbank_musan.py --num-mel-bins ${whisper_mel_bins} --whisper-fbank true
     touch data/fbank/.aishell.whisper.done
   fi
-fi
\ No newline at end of file
+fi
diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py
index bb8aaabd0..07e28a8d4 100755
--- a/egs/aishell/ASR/whisper/decode.py
+++ b/egs/aishell/ASR/whisper/decode.py
@@ -2,6 +2,7 @@
 # Copyright      2021  Xiaomi Corporation     (Author: Liyong Guo,
 #                                                      Fangjun Kuang,
 #                                                      Wei Kang)
+#                2024  Yuekai Zhang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -42,44 +43,37 @@ python3 ./whisper/decode.py \

 import argparse
 import logging
+import re
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Tuple

-import whisper
-from whisper.normalizers import BasicTextNormalizer
 import k2
 import torch
 import torch.nn as nn
+import whisper
 from asr_datamodule import AishellAsrDataModule
-#from model import load_model
+from tn.chinese.normalizer import Normalizer
+from whisper.normalizers import BasicTextNormalizer
 from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from icefall.checkpoint import load_checkpoint, average_checkpoints_with_averaged_model
-from icefall.decode import (
-    get_lattice,
-    nbest_decoding,
-    nbest_oracle,
-    one_best_decoding,
-    rescore_with_attention_decoder,
-)
+from zhconv import convert
+
+from icefall.checkpoint import average_checkpoints_with_averaged_model, load_checkpoint
 from icefall.env import get_env_info
-from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
-    get_texts,
     setup_logger,
     store_transcripts,
-    write_error_stats,
     str2bool,
+    write_error_stats,
 )
-from zhconv import convert
-from tn.chinese.normalizer import Normalizer
-import re


 def average_checkpoints(
     filenames: List[Path], device: torch.device = torch.device("cpu")
 ) -> dict:
     """Average a list of checkpoints.
+    The function is mainly used for averaging deepspeed-converted checkpoints,
+    which only include the model state_dict.
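+
+    A minimal usage sketch (paths are hypothetical):
+
+        filenames = [Path(f"whisper/exp/epoch-{i}.pt") for i in range(9, 11)]
+        model.load_state_dict(average_checkpoints(filenames))
+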
     Args:
       filenames:
@@ -94,9 +88,9 @@ def average_checkpoints(
     n = len(filenames)

     if "model" in torch.load(filenames[0], map_location=device):
-      avg = torch.load(filenames[0], map_location=device)["model"]
+        avg = torch.load(filenames[0], map_location=device)["model"]
     else:
-      avg = torch.load(filenames[0], map_location=device)
+        avg = torch.load(filenames[0], map_location=device)

     # Identify shared parameters. Two parameters are said to be shared
     # if they have the same data_ptr
@@ -112,9 +106,9 @@ def average_checkpoints(

     for i in range(1, n):
         if "model" in torch.load(filenames[i], map_location=device):
-          state_dict = torch.load(filenames[i], map_location=device)["model"]
+            state_dict = torch.load(filenames[i], map_location=device)["model"]
         else:
-          state_dict = torch.load(filenames[i], map_location=device)
+            state_dict = torch.load(filenames[i], map_location=device)
         for k in uniqued_names:
             avg[k] += state_dict[k]

@@ -126,33 +120,48 @@ def average_checkpoints(

     return avg


+
 def remove_punctuation(text: str or List[str]):
-    # https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
-    punctuation = '!,.;:?、!,。;:?《》 '
+    """Modified from https://github.com/yeyupiaoling/Whisper-Finetune/blob/master/utils/data_utils.py
+
+    Args:
+      text: It can be a string or a list of strings.
+    Returns:
+      Return a string or a list of strings without any punctuation.
+    """
+    punctuation = "!,.;:?、!,。;:?《》 "
     if isinstance(text, str):
-        text = re.sub(r'[{}]+'.format(punctuation), '', text).strip()
+        text = re.sub(r"[{}]+".format(punctuation), "", text).strip()
         return text
     elif isinstance(text, list):
         result_text = []
         for t in text:
-            t = re.sub(r'[{}]+'.format(punctuation), '', t).strip()
+            t = re.sub(r"[{}]+".format(punctuation), "", t).strip()
             result_text.append(t)
         return result_text
     else:
-        raise Exception(f'Not support type {type(text)}')
+        raise Exception(f"Unsupported type: {type(text)}")
+

 def to_simple(text: str or List[str]):
+    """Convert traditional Chinese to simplified Chinese.
+    Args:
+      text: It can be a string or a list of strings.
+    Returns:
+      Return a string or a list of strings converted to simplified Chinese.
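+
+    Example (illustrative):
+        >>> to_simple("這是繁體字")
+        '这是繁体字'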
+    """
     if isinstance(text, str):
-        text = convert(text, 'zh-cn')
+        text = convert(text, "zh-cn")
         return text
     elif isinstance(text, list):
         result_text = []
         for t in text:
-            t = convert(t, 'zh-cn')
+            t = convert(t, "zh-cn")
             result_text.append(t)
         return result_text
     else:
-        raise Exception(f'Not support type{type(text)}')
+        raise Exception(f"Unsupported type: {type(text)}")
+

 def get_parser():
     parser = argparse.ArgumentParser(
@@ -214,7 +223,7 @@ def get_parser():
         default=True,
         help="replace whisper encoder forward method to remove input length restriction",
     )
-
+
     return parser


@@ -226,6 +235,7 @@ def get_params() -> AttributeDict:
     )
     return params

+
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
@@ -234,42 +244,17 @@ def decode_one_batch(
     """Decode one batch and return the result in a dict. The dict has the
     following format:

-    - key: It indicates the setting used for decoding. For example,
-           if decoding method is 1best, the key is the string `no_rescore`.
-           If attention rescoring is used, the key is the string
-           `ngram_lm_scale_xxx_attention_scale_xxx`, where `xxx` is the
-           value of `lm_scale` and `attention_scale`. An example key is
-           `ngram_lm_scale_0.7_attention_scale_0.5`
-    - value: It contains the decoding result. `len(value)` equals to
-             batch size. `value[i]` is the decoding result for the i-th
-             utterance in the given batch.
+    - key: "beam-search"
+    - value: A list of strings. `value[i]` is the decoding result for the
+             i-th utterance in the given batch.
     Args:
-      params:
-        It's the return value of :func:`get_params`.
-
-        - params.method is "1best", it uses 1best decoding without LM rescoring.
-        - params.method is "nbest", it uses nbest decoding without LM rescoring.
-        - params.method is "attention-decoder", it uses attention rescoring.
-
-      model:
-        The neural model.
-      HLG:
-        The decoding graph. Used when params.method is NOT ctc-decoding.
-      H:
-        The ctc topo. Used only when params.method is ctc-decoding.
-      batch:
-        It is the return value from iterating
-        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
-        for the format of the `batch`.
-      lexicon:
-        It contains the token symbol table and the word symbol table.
-      sos_id:
-        The token ID of the SOS.
-      eos_id:
-        The token ID of the EOS.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      batch:
+        It is returned by :meth:`torch.utils.data.DataLoader.__iter__`.
     Returns:
-      Return the decoding result. See above description for the format of
-      the returned dict.
+      Return a dict, whose key may be "beam-search".
     """
     dtype = torch.float16
     device = torch.device("cuda")

@@ -280,22 +265,27 @@ def decode_one_batch(
     if not params.remove_whisper_encoder_input_length_restriction:
         T = 3000
         if feature.shape[2] < T:
-            feature = torch.cat([feature, torch.zeros(feature.shape[0], feature.shape[1], T - feature.shape[2]).to(device, dtype=dtype)], 2)
+            feature = torch.cat(
+                [
+                    feature,
+                    torch.zeros(
+                        feature.shape[0], feature.shape[1], T - feature.shape[2]
+                    ).to(device, dtype=dtype),
+                ],
+                2,
+            )

     supervisions = batch["supervisions"]
     feature_len = supervisions["num_frames"]
     feature_len = feature_len.to(device, dtype=dtype)
     results = model.decode(feature, params.decoding_options)
     hyps = [result.text for result in results]
-
+
     hyps = remove_punctuation(hyps)
     hyps = to_simple(hyps)
-
     hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
-    key = "beam-search"
-
-    return {key: hyps}
+
+    return {"beam-search": hyps}


 def decode_dataset(
@@ -306,28 +296,14 @@ def decode_dataset(
     """Decode dataset.

     Args:
-      dl:
-        PyTorch's dataloader containing the dataset to decode.
-      params:
-        It is returned by :func:`get_params`.
-      model:
-        The neural model.
-      HLG:
-        The decoding graph. Used when params.method is NOT ctc-decoding.
-      H:
-        The ctc topo. Used only when params.method is ctc-decoding.
-      lexicon:
-        It contains the token symbol table and the word symbol table.
-      sos_id:
-        The token ID for SOS.
-      eos_id:
-        The token ID for EOS.
+      dl:
+        The dataloader.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
     Returns:
-      Return a dict, whose key may be "no-rescore" if the decoding method is
-      1best or it may be "ngram_lm_scale_0.7_attention_scale_0.5" if attention
-      rescoring is used. Its value is a list of tuples. Each tuple contains two
-      elements: The first is the reference transcript, and the second is the
-      predicted result.
+      Return a dict, whose key may be "beam-search".
""" results = [] @@ -376,7 +352,9 @@ def save_results( enable_log = True test_set_wers = dict() for key, results in results_dict.items(): - recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + recog_path = ( + params.exp_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) results = sorted(results) store_transcripts(filename=recog_path, texts=results) if enable_log: @@ -384,7 +362,9 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. - errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + errs_filename = ( + params.exp_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) # we compute CER for aishell dataset. results_char = [] for res in results: @@ -423,13 +403,20 @@ def main(): params = get_params() params.update(vars(args)) params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" - setup_logger(f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}") + setup_logger( + f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}" + ) - options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=params.beam_size) + options = whisper.DecodingOptions( + task="transcribe", + language="zh", + without_timestamps=True, + beam_size=params.beam_size, + ) params.decoding_options = options params.cleaner = BasicTextNormalizer() params.normalizer = Normalizer() - + logging.info("Decoding started") logging.info(params) @@ -441,39 +428,47 @@ def main(): if params.remove_whisper_encoder_input_length_restriction: replace_whisper_encoder_forward() - model = whisper.load_model(params.model_name, 'cpu') + model = whisper.load_model(params.model_name, "cpu") if params.epoch > 0: - if params.avg > 1: - start = params.epoch - params.avg - assert start >= 1, start - checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu') - if 'model' not in checkpoint: - filenames = [f"{params.exp_dir}/epoch-{epoch}.pt" for epoch in range(start, params.epoch + 1)] - model.load_state_dict(average_checkpoints(filenames)) - else: - filename_start = f"{params.exp_dir}/epoch-{start}.pt" - filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" - logging.info( - f"Calculating the averaged model over epoch range from " - f"{start} (excluded) to {params.epoch}" + if params.avg > 1: + start = params.epoch - params.avg + assert start >= 1, start + checkpoint = torch.load( + f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu" ) - model.to(device) - model.load_state_dict( - average_checkpoints_with_averaged_model( - filename_start=filename_start, - filename_end=filename_end, - device=device, + if "model" not in checkpoint: + # deepspeed converted checkpoint only contains model state_dict + filenames = [ + f"{params.exp_dir}/epoch-{epoch}.pt" + for epoch in range(start, params.epoch + 1) + ] + model.load_state_dict(average_checkpoints(filenames)) + else: + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" ) - ) - # save checkpoints - filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt" - torch.save(model.state_dict(), filename) - else: - checkpoint = torch.load(f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location='cpu') - if 'model' not in checkpoint: - 
-            model.load_state_dict(checkpoint, strict=True)
-        else:
-            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+        if params.avg > 1:
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            checkpoint = torch.load(
+                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+            )
+            if "model" not in checkpoint:
+                # deepspeed converted checkpoint only contains model state_dict
+                filenames = [
+                    f"{params.exp_dir}/epoch-{epoch}.pt"
+                    for epoch in range(start, params.epoch + 1)
+                ]
+                model.load_state_dict(average_checkpoints(filenames))
+            else:
+                filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+                filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+                logging.info(
+                    f"Calculating the averaged model over epoch range from "
+                    f"{start} (excluded) to {params.epoch}"
+                )
+                model.to(device)
+                model.load_state_dict(
+                    average_checkpoints_with_averaged_model(
+                        filename_start=filename_start,
+                        filename_end=filename_end,
+                        device=device,
+                    )
+                )
+                # save checkpoints
+                filename = f"{params.exp_dir}/epoch-{params.epoch}-avg-{params.avg}.pt"
+                torch.save(model.state_dict(), filename)
+        else:
+            checkpoint = torch.load(
+                f"{params.exp_dir}/epoch-{params.epoch}.pt", map_location="cpu"
+            )
+            if "model" not in checkpoint:
+                model.load_state_dict(checkpoint, strict=True)
+            else:
+                load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
diff --git a/egs/aishell/ASR/whisper/ds_config_zero1.json b/egs/aishell/ASR/whisper/ds_config_zero1.json
index b95b1cee4..bf8cc0452 100644
--- a/egs/aishell/ASR/whisper/ds_config_zero1.json
+++ b/egs/aishell/ASR/whisper/ds_config_zero1.json
@@ -35,4 +35,4 @@
     "steps_per_print": 50,
     "train_micro_batch_size_per_gpu": 1,
     "wall_clock_breakdown": false
-}
\ No newline at end of file
+}
diff --git a/egs/aishell/ASR/whisper/requirements.txt b/egs/aishell/ASR/whisper/requirements.txt
index 319f9ff4a..0708f2344 100755
--- a/egs/aishell/ASR/whisper/requirements.txt
+++ b/egs/aishell/ASR/whisper/requirements.txt
@@ -7,4 +7,4 @@ librosa
 git+https://github.com/yuekaizhang/whisper.git
 zhconv
 WeTextProcessing
-deepspeed
\ No newline at end of file
+deepspeed
diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
index 4251536ad..8d5930437 100755
--- a/egs/aishell/ASR/whisper/train.py
+++ b/egs/aishell/ASR/whisper/train.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 # Copyright    2023  Xiaomi Corp.        (authors: Xiaoyu Yang)
+#              2024  Yuekai Zhang
 #
 # See ../../../../LICENSE for clarification regarding multiple authors
 #
@@ -41,44 +42,37 @@ import random
 import warnings
 from pathlib import Path
 from shutil import copyfile
-from typing import Any, Dict, Optional, Tuple, Union

-import deepspeed
-from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from typing import Any, Dict, List, Optional, Tuple, Union

+import deepspeed
 import k2
 import optim
 import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
-from typing import List
-
+import whisper
 from asr_datamodule import AishellAsrDataModule
-
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
+from label_smoothing import LabelSmoothingLoss
 from lhotse import CutSet, load_manifest
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
-
 from optim import Eden, ScaledAdam
 from torch import Tensor
 from torch.cuda.amp import GradScaler
-from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.functional import pad as pad_tensor
+from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
-
+from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward

 from icefall import diagnostics
-
 from icefall.checkpoint import load_checkpoint, remove_checkpoints
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
-from icefall.checkpoint import (
-    save_checkpoint_with_global_batch_idx,
-    update_averaged_model,
-)
-from icefall.dist import cleanup_dist, setup_dist, get_world_size, get_rank, get_local_rank
+from icefall.checkpoint import update_averaged_model
+from icefall.dist import cleanup_dist, get_rank, get_world_size, setup_dist
 from icefall.env import get_env_info
 from icefall.hooks import register_inf_check_hooks
-from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
     MetricsTracker,
@@ -87,10 +81,6 @@ from icefall.utils import (
     str2bool,
 )

-import whisper
-from whisper_encoder_forward_monkey_patch import replace_whisper_encoder_forward
-from label_smoothing import LabelSmoothingLoss
-
 LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler]


@@ -102,6 +92,7 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
         if hasattr(module, "batch_count"):
             module.batch_count = batch_count

+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -247,39 +238,17 @@ def get_params() -> AttributeDict:

     Explanation of options saved in `params`:

-    - best_train_loss: Best training loss so far. It is used to select
-                       the model that has the lowest training loss. It is
-                       updated during the training.
-
-    - best_valid_loss: Best validation loss so far. It is used to select
-                       the model that has the lowest validation loss. It is
-                       updated during the training.
-
-    - best_train_epoch: It is the epoch that has the best training loss.
-
-    - best_valid_epoch: It is the epoch that has the best validation loss.
-
-    - batch_idx_train: Used to writing statistics to tensorboard. It
-                       contains number of batches trained so far across
-                       epochs.
-
-    - log_interval:  Print training loss if batch_idx % log_interval` is 0
-
-    - reset_interval: Reset statistics if batch_idx % reset_interval is 0
-
-    - valid_interval:  Run validation if batch_idx % valid_interval is 0
-
-    - feature_dim: The model input dim. It has to match the one used
-                   in computing features.
-
-    - subsampling_factor:  The subsampling factor for the model.
-
-    - encoder_dim: Hidden dim for multi-head attention model.
-
-    - num_decoder_layers: Number of decoder layer of transformer decoder.
-
-    - warm_step: The warmup period that dictates the decay of the
-              scale on "simple" (un-pruned) loss.
+    - frame_shift_ms: The frame shift in milliseconds.
+    - allowed_excess_duration_ratio: The allowed excess duration ratio.
+    - best_train_loss: The best training loss so far.
+    - best_valid_loss: The best validation loss so far.
+    - best_train_epoch: The epoch where the best training loss is achieved.
+    - best_valid_epoch: The epoch where the best validation loss is achieved.
+    - batch_idx_train: The batch index of the current batch.
+    - log_interval: Log training stats every `log_interval` batches.
+    - reset_interval: Reset the stats every `reset_interval` batches.
+    - valid_interval: Run validation every `valid_interval` batches.
+    - env_info: The environment information.
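+
+    A quick check of the defaults (illustrative):
+
+        >>> params = get_params()
+        >>> params.valid_interval
+        5000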
     """
     params = AttributeDict(
         {
@@ -292,13 +261,14 @@ def get_params() -> AttributeDict:
             "batch_idx_train": 0,
             "log_interval": 50,
             "reset_interval": 200,
-            "valid_interval": 9999999,
+            "valid_interval": 5000,
             "env_info": get_env_info(),
         }
     )

     return params

+
 def load_checkpoint_if_available(
     params: AttributeDict,
     model: nn.Module,
@@ -414,6 +384,7 @@ def save_checkpoint(
         best_valid_filename = params.exp_dir / "best-valid-loss.pt"
         copyfile(src=filename, dst=best_valid_filename)

+
 def compute_loss(
     params: AttributeDict,
     tokenizer: whisper.tokenizer.Tokenizer,
@@ -422,22 +393,21 @@ def compute_loss(
     is_training: bool,
 ) -> Tuple[Tensor, MetricsTracker]:
     """
-    Compute RNN-T loss given the model and its inputs.
-
+    Compute the loss for the given batch.
     Args:
-      params:
-        Parameters for training. See :func:`get_params`.
-      model:
-        The model for training. It is an instance of Zipformer in our case.
-      batch:
-        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
-        for the content in it.
-      is_training:
-        True for training. False for validation. When it is True, this
-        function enables autograd during computation; when it is False, it
-        disables autograd.
-      warmup: a floating point value which increases throughout training;
-        values >= 1.0 are fully warmed up and have all modules present.
+      params:
+        It is returned by :func:`get_params`.
+      tokenizer:
+        The tokenizer used to encode the text.
+      model:
+        The model for training.
+      batch:
+        A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
+        for the content in it.
+      is_training:
+        Whether it is training.
+    Returns:
+      Return a tuple of two elements. The first element is the loss tensor,
+      and the second is a MetricsTracker with statistics for logging.
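+
+    A sketch of how the decoder targets are built below (token names are
+    illustrative):
+
+        tokens = [sot, lang_id, transcribe, notimestamps, t_1, ..., t_n, eot]
+        prev_outputs_tokens = tokens[:-1]  # decoder input
+        target_tokens = tokens[1:]         # left-shifted by one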
     """
    # For the uneven-sized batch, the total duration after padding would possibly
    # cause OOM. Hence, for each batch, which is sorted descendingly by length,
@@ -449,6 +419,7 @@ def compute_loss(
     if isinstance(model, DDP):
         # get underlying nn.Module
         model = model.module

+
     def _batch_tensors(tensors: List[Tensor], pad_value: Any) -> Tensor:
         padding_size = max(tensor.shape[0] for tensor in tensors)
         dims = len(tensors[0].shape)
@@ -479,9 +450,16 @@ def compute_loss(

     # remove spaces in texts
     texts = [text.replace(" ", "") for text in texts]
-    text_tokens_list = [list(tokenizer.sot_sequence_including_notimestamps) + tokenizer.encode(text) + [tokenizer.eot] for text in texts]
+    text_tokens_list = [
+        list(tokenizer.sot_sequence_including_notimestamps)
+        + tokenizer.encode(text)
+        + [tokenizer.eot]
+        for text in texts
+    ]
     # convert it to torch tensor
-    text_tokens_list = [torch.LongTensor(text_tokens) for text_tokens in text_tokens_list]
+    text_tokens_list = [
+        torch.LongTensor(text_tokens) for text_tokens in text_tokens_list
+    ]

     # 50256 is the index of <pad> for all whisper models
     prev_outputs_tokens = _batch_tensors(
@@ -494,9 +472,11 @@ def compute_loss(
         [tokens.shape[0] - 1 for tokens in text_tokens_list]
     )

-    decoder_criterion = LabelSmoothingLoss(ignore_index=50256, label_smoothing=0.1, reduction="sum")
+    decoder_criterion = LabelSmoothingLoss(
+        ignore_index=50256, label_smoothing=0.1, reduction="sum"
+    )

     # ignore the first 3 tokens, which are always <|lang_id|>, <|transcribe|>, <|notimestamps|>
     ignore_prefix_size = 3
     with torch.set_grad_enabled(is_training):
         encoder_out = model.encoder(feature)
@@ -623,7 +603,7 @@ def train_one_epoch(
                 valid_info.write_summary(
                     tb_writer, "train/valid_", params.batch_idx_train
                 )
-
+
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -687,16 +667,24 @@ def train_one_epoch(
         if batch_idx % params.log_interval == 0:
             try:
                 cur_lr = scheduler.get_last_lr()[0]
-            except:
+            except:  # noqa
                 cur_lr = 0.0
-            cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0
+            cur_grad_scale = (
+                scaler._scale.item()
+                if (params.use_fp16 and not params.deepspeed)
+                else 1.0
+            )

             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if (params.use_fp16 and not params.deepspeed) else "")
+                + (
+                    f"grad_scale: {scaler._scale.item()}"
+                    if (params.use_fp16 and not params.deepspeed)
+                    else ""
+                )
             )

             if tb_writer is not None:
@@ -715,7 +703,6 @@ def train_one_epoch(
                     params.batch_idx_train,
                 )

-
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
     if params.train_loss < params.best_train_loss:
@@ -744,15 +731,18 @@ def run(rank, world_size, args):
     logging.info(params)

     logging.info("About to create model")
-
+
     replace_whisper_encoder_forward()
-    model = whisper.load_model(params.model_name, 'cpu')
+    model = whisper.load_model(params.model_name, "cpu")
     del model.alignment_heads

     num_param = sum([p.numel() for p in model.parameters()])
     logging.info(f"Number of model parameters: {num_param}")

     tokenizer = whisper.tokenizer.get_tokenizer(
-        model.is_multilingual, num_languages=model.num_languages, language="zh", task="transcribe"
+        model.is_multilingual,
+        num_languages=model.num_languages,
+        language="zh",
+        task="transcribe",
     )

     model_avg: Optional[nn.Module] = None
@@ -791,7 +781,8 @@ def run(rank, world_size, args):
     if params.deepspeed:
         logging.info("Using DeepSpeed")
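+        # ZeRO stage 1 (the ds_config_zero1.json shipped with this recipe)
+        # shards optimizer states across ranks. Assumption: the config path
+        # arrives via args.deepspeed_config, as with the standard deepspeed
+        # argument parser.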
         model, optimizer, _, scheduler = deepspeed.initialize(
-            args=params, model=model, model_parameters=model.parameters())
+            args=params, model=model, model_parameters=model.parameters()
+        )
     else:
         logging.info("Using DDP")
         setup_dist(use_ddp_launch=True)
@@ -860,13 +851,17 @@ def run(rank, world_size, args):
             break

         if params.deepspeed:
-            model.save_checkpoint(save_dir=params.exp_dir,
-                                  tag=f"epoch-{params.cur_epoch}",
-                                  client_state={})
+            model.save_checkpoint(
+                save_dir=params.exp_dir,
+                tag=f"epoch-{params.cur_epoch}",
+                client_state={},
+            )
             if rank == 0:
                 convert_zero_checkpoint_to_fp32_state_dict(
-                    params.exp_dir, f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
-                    tag=f"epoch-{params.cur_epoch}")
+                    params.exp_dir,
+                    f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
+                    tag=f"epoch-{params.cur_epoch}",
+                )
         else:
             save_checkpoint(
                 params=params,
@@ -924,5 +919,6 @@ def main():
     torch.set_num_interop_threads(1)
     run(rank=rank, world_size=world_size, args=args)

+
 if __name__ == "__main__":
     main()
diff --git a/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py
index 0f2b94adf..5bfbdce3b 100644
--- a/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py
+++ b/egs/aishell/ASR/whisper/whisper_encoder_forward_monkey_patch.py
@@ -1,6 +1,8 @@
 import torch
+import torch.nn.functional as F
 import whisper

+
 def forward(self, x: torch.Tensor):
     """
     x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
@@ -10,7 +12,7 @@ def forward(self, x: torch.Tensor):
     x = F.gelu(self.conv2(x))
     x = x.permute(0, 2, 1)

-    x = (x + self.positional_embedding[:x.shape[1],:]).to(x.dtype)
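+    # Slicing the positional embedding to the actual input length, instead of
+    # assuming the stock encoder's fixed 1500-frame (30-second) context, is
+    # what removes the input length restriction.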
+    x = (x + self.positional_embedding[: x.shape[1], :]).to(x.dtype)

     for block in self.blocks:
         x = block(x)
@@ -18,6 +20,7 @@ def forward(self, x: torch.Tensor):
     x = self.ln_post(x)
     return x

+
 def replace_whisper_encoder_forward():
     """
     This function monkey patches the forward method of the whisper encoder.
diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py
index 9001aa214..a1b695243 100755
--- a/egs/librispeech/ASR/local/compute_fbank_musan.py
+++ b/egs/librispeech/ASR/local/compute_fbank_musan.py
@@ -22,16 +22,25 @@ It looks for manifests in the directory data/manifests.
 The generated fbank features are saved in data/fbank.
 """
-
+import argparse
 import logging
 import os
 from pathlib import Path

 import torch
-from lhotse import CutSet, Fbank, FbankConfig, WhisperFbank, WhisperFbankConfig, LilcomChunkyWriter, MonoCut, combine
+from lhotse import (
+    CutSet,
+    Fbank,
+    FbankConfig,
+    LilcomChunkyWriter,
+    MonoCut,
+    WhisperFbank,
+    WhisperFbankConfig,
+    combine,
+)
 from lhotse.recipes.utils import read_manifests_if_cached

-from icefall.utils import get_executor
+from icefall.utils import get_executor, str2bool

 # Torch's multithreaded behavior needs to be disabled or
 # it wastes a lot of CPU and slow things down.
@@ -81,7 +90,9 @@ def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):

     logging.info("Extracting features for Musan")
     if whisper_fbank:
-        extractor = WhisperFbank(WhisperFbankConfig(num_filters=num_mel_bins, device='cuda'))
+        extractor = WhisperFbank(
+            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
+        )
     else:
         extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

@@ -103,6 +114,7 @@ def compute_fbank_musan(num_mel_bins: int = 80, whisper_fbank: bool = False):
     )
     musan_cuts.to_file(musan_cuts_path)

+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
@@ -119,10 +131,12 @@ def get_args():
     )
     return parser.parse_args()

+
 if __name__ == "__main__":
     formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

     logging.basicConfig(format=formatter, level=logging.INFO)
+
     args = get_args()
     compute_fbank_musan(
         num_mel_bins=args.num_mel_bins, whisper_fbank=args.whisper_fbank
     )