diff --git a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
index 6abe6c084..293a1a569 100644
--- a/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
+++ b/egs/aishell/ASR/tdnn_lstm_ctc/asr_datamodule.py
@@ -182,7 +182,7 @@ class AishellAsrDataModule:
         )
 
     def train_dataloaders(
-        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
+        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None, rank: Optional[int] = None, world_size: Optional[int] = None
     ) -> DataLoader:
         """
         Args:
@@ -276,6 +276,8 @@ class AishellAsrDataModule:
                 shuffle=self.args.shuffle,
                 num_buckets=self.args.num_buckets,
                 drop_last=self.args.drop_last,
+                world_size=world_size,
+                rank=rank,
             )
         else:
             logging.info("Using SimpleCutSampler.")
@@ -300,7 +302,7 @@ class AishellAsrDataModule:
 
         return train_dl
 
-    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
+    def valid_dataloaders(self, cuts_valid: CutSet, rank: Optional[int] = None, world_size: Optional[int] = None) -> DataLoader:
         transforms = []
         if self.args.concatenate_cuts:
             transforms = [
@@ -325,6 +327,8 @@ class AishellAsrDataModule:
             cuts_valid,
             max_duration=self.args.max_duration,
             shuffle=False,
+            rank=rank,
+            world_size=world_size,
         )
         logging.info("About to create dev dataloader")
         valid_dl = DataLoader(
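Note: lhotse samplers can infer rank and world_size from the torch.distributed
runtime, but this change passes them explicitly, so the per-rank sharding does
not depend on which launcher (DDP or DeepSpeed) initialized the process group.
A minimal sketch of what the sampler receives; the manifest path and the
numbers are illustrative, not part of this patch:

    from lhotse import CutSet
    from lhotse.dataset import DynamicBucketingSampler

    cuts = CutSet.from_jsonl_lazy("data/fbank/aishell_cuts_train.jsonl.gz")
    sampler = DynamicBucketingSampler(
        cuts,
        max_duration=200.0,  # seconds of audio per batch
        shuffle=True,
        num_buckets=30,
        drop_last=True,
        world_size=4,  # total number of training processes
        rank=0,        # index of the current process, 0..world_size-1
    )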
diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py
index 28ac83562..f9d3d4c06 100644
--- a/egs/aishell/ASR/whisper/decode.py
+++ b/egs/aishell/ASR/whisper/decode.py
@@ -109,20 +109,17 @@ def get_parser():
         default="beam-search",
         help="""Decoding method.
         Supported values are:
-          - (0) ctc-decoding. Use CTC decoding. It maps the tokens ids to
-            tokens using token symbol tabel directly.
-          - (1) 1best. Extract the best path from the decoding lattice as the
-            decoding result.
-          - (2) nbest. Extract n paths from the decoding lattice; the path
-            with the highest score is the decoding result.
-          - (3) attention-decoder. Extract n paths from the lattice,
-            the path with the highest score is the decoding result.
-          - (4) nbest-oracle. Its WER is the lower bound of any n-best
-            rescoring method can achieve. Useful for debugging n-best
-            rescoring method.
+          - beam-search
         """,
     )
 
+    parser.add_argument(
+        "--beam-size",
+        type=int,
+        default=1,
+        help="beam size for beam search decoding",
+    )
+
     parser.add_argument(
         "--exp-dir",
         type=str,
@@ -357,10 +354,9 @@ def main():
     params = get_params()
     params.update(vars(args))
     params.suffix = f"epoch-{params.epoch}-avg-{params.avg}"
-    setup_logger(f"{params.exp_dir}/log-{params.method}/log-decode-{params.suffix}")
+    setup_logger(f"{params.exp_dir}/log-{params.method}-beam{params.beam_size}/log-decode-{params.suffix}")
 
-    #options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=10)
-    options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=None)
+    options = whisper.DecodingOptions(task="transcribe", language="zh", without_timestamps=True, beam_size=params.beam_size)
     params.decoding_options = options
     params.cleaner = BasicTextNormalizer()
     params.normalizer = Normalizer()
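Note: the new --beam-size flag is forwarded to OpenAI Whisper's decoder via
whisper.DecodingOptions; beam_size=None (the previously hard-coded value)
means greedy decoding. A standalone sketch of the same call path; the audio
file name and the beam size are illustrative:

    import whisper

    model = whisper.load_model("large-v2")
    audio = whisper.pad_or_trim(whisper.load_audio("test.wav"))
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    options = whisper.DecodingOptions(
        task="transcribe",
        language="zh",
        without_timestamps=True,
        beam_size=4,  # corresponds to --beam-size
    )
    result = whisper.decode(model, mel, options)
    print(result.text)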
diff --git a/egs/aishell/ASR/whisper/ds_config_zero1.json b/egs/aishell/ASR/whisper/ds_config_zero1.json
new file mode 100644
index 000000000..59318968e
--- /dev/null
+++ b/egs/aishell/ASR/whisper/ds_config_zero1.json
@@ -0,0 +1,32 @@
+{
+  "fp16": {
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 1000,
+    "initial_scale_power": 16,
+    "hysteresis": 2,
+    "min_loss_scale": 1
+  },
+  "zero_optimization": {
+    "stage": 1,
+    "allgather_partitions": true,
+    "allgather_bucket_size": 2e8,
+    "overlap_comm": true,
+    "reduce_scatter": true,
+    "reduce_bucket_size": 2e8,
+    "contiguous_gradients": true
+  },
+  "scheduler": {
+    "type": "WarmupLR",
+    "params": {
+      "warmup_min_lr": 5e-6,
+      "warmup_max_lr": 1e-5,
+      "warmup_num_steps": 100
+    }
+  },
+  "gradient_accumulation_steps": 1,
+  "gradient_clipping": 5,
+  "steps_per_print": 50,
+  "train_micro_batch_size_per_gpu": 1,
+  "wall_clock_breakdown": false
+}
\ No newline at end of file
diff --git a/egs/aishell/ASR/whisper/model.py b/egs/aishell/ASR/whisper/model.py
index 953b80ff4..89d76383a 100644
--- a/egs/aishell/ASR/whisper/model.py
+++ b/egs/aishell/ASR/whisper/model.py
@@ -372,7 +372,7 @@ def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
 
 def load_model(
     name: str,
-    device: Optional[Union[str, torch.device]] = None,
+    device: Optional[Union[str, torch.device]] = "cpu",
     download_root: str = None,
     in_memory: bool = False,
 ) -> Whisper:
@@ -397,8 +397,8 @@ def load_model(
         The Whisper ASR model instance
     """
 
-    if device is None:
-        device = "cuda" if torch.cuda.is_available() else "cpu"
+    # if device is None:
+    #     device = "cuda" if torch.cuda.is_available() else "cpu"
    if download_root is None:
         default = os.path.join(os.path.expanduser("~"), ".cache")
         download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
diff --git a/egs/aishell/ASR/whisper/requirements.txt b/egs/aishell/ASR/whisper/requirements.txt
index 654851b73..b71e4a4ad 100644
--- a/egs/aishell/ASR/whisper/requirements.txt
+++ b/egs/aishell/ASR/whisper/requirements.txt
@@ -8,3 +8,4 @@ librosa
 openai-whisper
 zhconv
 WeTextProcessing
+deepspeed
diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
index edff5edfc..5ae261335 100644
--- a/egs/aishell/ASR/whisper/train.py
+++ b/egs/aishell/ASR/whisper/train.py
@@ -42,6 +42,8 @@ import warnings
 from pathlib import Path
 from shutil import copyfile
 from typing import Any, Dict, Optional, Tuple, Union
+import deepspeed
+from deepspeed.utils.zero_to_fp32 import convert_zero_checkpoint_to_fp32_state_dict
 
 import k2
 import optim
@@ -102,15 +104,6 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:
         if hasattr(module, "batch_count"):
             module.batch_count = batch_count
 
-
-def add_deepspeed_arguments(parser: argparse.ArgumentParser):
-    parser.add_argument(
-        "--deepspeed-config",
-        type=str,
-        default=None,
-        help="Path to deepspeed json config file.",
-    )
-
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -251,7 +244,7 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
-    add_deepspeed_arguments(parser)
+    parser = deepspeed.add_config_arguments(parser)
 
     return parser
 
@@ -495,7 +488,6 @@ def compute_loss(
     feature = feature.transpose(1, 2)  # (N, C, T)
     # pad feature from B,80,T to B,80,3000
     #feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
-    #print(feature.shape, 23333333)
 
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
@@ -629,24 +621,25 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-            logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                tokenizer=tokenizer,
-                model=model,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-            )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        #     logging.info("Computing validation loss")
+        #     valid_info = compute_validation_loss(
+        #         params=params,
+        #         tokenizer=tokenizer,
+        #         model=model,
+        #         valid_dl=valid_dl,
+        #         world_size=world_size,
+        #     )
+        #     model.train()
+        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+        #     logging.info(
+        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        #     )
+        #     if tb_writer is not None:
+        #         valid_info.write_summary(
+        #             tb_writer, "train/valid_", params.batch_idx_train
+        #         )
+
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
@@ -661,13 +654,20 @@ def train_one_epoch(
 
             # NOTE: We use reduction==sum and loss is computed over utterances
             # in the batch and there is no normalization to it so far.
-            scaler.scale(loss).backward()
-            set_batch_count(model, params.batch_idx_train)
-            scheduler.step_batch(params.batch_idx_train)
+            if params.deepspeed:
+                # DeepSpeed replaces the explicit scaler/optimizer/scheduler
+                # calls: model.backward(loss) applies loss scaling and runs
+                # backward, and model.step() updates parameters and the lr.
+                model.backward(loss)
+                model.step()
+            else:
+                scaler.scale(loss).backward()
+                set_batch_count(model, params.batch_idx_train)
+                scheduler.step_batch(params.batch_idx_train)
 
-            scaler.step(optimizer)
-            scaler.update()
-            optimizer.zero_grad()
+                scaler.step(optimizer)
+                scaler.update()
+                optimizer.zero_grad()
         except:  # noqa
             display_and_save_batch(batch, params=params)
             raise
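Note: the change above follows DeepSpeed's standard training-step contract:
the engine returned by deepspeed.initialize() exposes backward() and step(),
and loss scaling, gradient clipping and the WarmupLR schedule all come from
ds_config_zero1.json rather than from GradScaler/Eden. A minimal sketch of
the pattern; the loss function and tensor names are illustrative:

    import torch

    # model_engine stands for the object deepspeed.initialize() returns in
    # place of the bare nn.Module (see the run() changes further down).
    def train_step(model_engine, batch: torch.Tensor, target: torch.Tensor):
        loss = torch.nn.functional.mse_loss(model_engine(batch), target)
        model_engine.backward(loss)  # scales the loss and runs backward
        model_engine.step()          # clip + optimizer step + lr step + zero_grad
        return loss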
@@ -679,6 +679,7 @@ def train_one_epoch(
             rank == 0
             and params.batch_idx_train > 0
             and params.batch_idx_train % params.average_period == 0
+            and not params.deepspeed
         ):
             update_averaged_model(
                 params=params,
@@ -686,29 +687,28 @@ def train_one_epoch(
                 model=model,
                 model_avg=model_avg,
             )
-        if (
-            params.batch_idx_train > 0
-            and params.batch_idx_train % params.save_every_n == 0
-        ):
-            save_checkpoint_with_global_batch_idx(
-                out_dir=params.exp_dir,
-                global_batch_idx=params.batch_idx_train,
-                model=model,
-                model_avg=model_avg,
-                params=params,
-                optimizer=optimizer,
-                scheduler=scheduler,
-                sampler=train_dl.sampler,
-                scaler=scaler,
-                rank=rank,
-            )
-            remove_checkpoints(
-                out_dir=params.exp_dir,
-                topk=params.keep_last_k,
-                rank=rank,
-            )
-
-        if batch_idx % 100 == 0 and params.use_fp16:
+        # if (
+        #     params.batch_idx_train > 0
+        #     and params.batch_idx_train % params.save_every_n == 0
+        # ):
+        #     save_checkpoint_with_global_batch_idx(
+        #         out_dir=params.exp_dir,
+        #         global_batch_idx=params.batch_idx_train,
+        #         model=model,
+        #         model_avg=model_avg,
+        #         params=params,
+        #         optimizer=optimizer,
+        #         scheduler=scheduler,
+        #         sampler=train_dl.sampler,
+        #         scaler=scaler,
+        #         rank=rank,
+        #     )
+        #     remove_checkpoints(
+        #         out_dir=params.exp_dir,
+        #         topk=params.keep_last_k,
+        #         rank=rank,
+        #     )
+        if batch_idx % 100 == 0 and params.use_fp16 and not params.deepspeed:
             # If the grad scale was less than 1, try increasing it. The _growth_interval
             # of the grad scaler is configurable, but we can't configure it to have different
             # behavior depending on the current grad scale.
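Note: grad-scale bookkeeping is skipped under --deepspeed because loss scaling
is owned by the engine (the "fp16" block of ds_config_zero1.json), so the
torch GradScaler on this code path is never stepped and its _scale is
meaningless. If the engine's live loss scale were ever wanted in these logs,
something like the helper below is the usual route; the cur_scale attribute
is an assumption and may differ across DeepSpeed versions:

    def engine_grad_scale(model) -> float:
        # model is the DeepSpeed engine when params.deepspeed is set;
        # cur_scale is assumed here, check your DeepSpeed version's fp16 API.
        return float(model.optimizer.cur_scale)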
@@ -723,14 +723,14 @@ def train_one_epoch(
             )
 
         if batch_idx % params.log_interval == 0:
             cur_lr = scheduler.get_last_lr()[0]
-            cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
+            cur_grad_scale = scaler._scale.item() if (params.use_fp16 and not params.deepspeed) else 1.0
 
             logging.info(
                 f"Epoch {params.cur_epoch}, "
                 f"batch {batch_idx}, loss[{loss_info}], "
                 f"tot_loss[{tot_loss}], batch size: {batch_size}, "
                 f"lr: {cur_lr:.2e}, "
-                + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+                + (f"grad_scale: {scaler._scale.item()}" if (params.use_fp16 and not params.deepspeed) else "")
             )
 
             if tb_writer is not None:
@@ -774,37 +774,21 @@ def run(rank, world_size, args):
 
     fix_random_seed(params.seed)
 
-    setup_dist(use_ddp_launch=True)
-    setup_logger(f"{params.exp_dir}/log/log-train")
-    logging.info("Training started")
-
-    if args.tensorboard and rank == 0:
-        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
-    else:
-        tb_writer = None
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", rank)
-    logging.info(f"Device: {device}")
-
-
-
+    logging.info(params)
     logging.info("About to create model")
-    #model = whisper.load_model("medium")
     # TODO download model only on rank 0
     # TODO may change compute validation loss using multiple cards
-    model = load_model("medium")
+    # model = load_model("medium")
+    model = load_model("large-v2")
     del model.alignment_heads
 
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
     tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual, language="zh", task="transcribe"
     )
 
-    logging.info(params)
-
-    num_param = sum([p.numel() for p in model.parameters()])
-    logging.info(f"Number of model parameters: {num_param}")
 
     assert params.save_every_n >= params.average_period
     model_avg: Optional[nn.Module] = None
@@ -817,10 +801,12 @@ def run(rank, world_size, args):
             params=params, model=model, model_avg=model_avg
         )
 
+    if torch.cuda.is_available():
+        device = torch.device("cuda", rank)
+    else:
+        device = torch.device("cpu")
+    logging.info(f"Device: {device}")
     model.to(device)
-    if world_size > 1:
-        logging.info("Using DDP")
-        model = DDP(model, device_ids=[rank], find_unused_parameters=True)
 
     optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
@@ -837,6 +823,17 @@ def run(rank, world_size, args):
         logging.info("Loading scheduler state dict")
         scheduler.load_state_dict(checkpoints["scheduler"])
 
+    if world_size > 1:
+        if params.deepspeed:
+            logging.info("Using DeepSpeed")
+            model, optimizer, _, _ = deepspeed.initialize(
+                args=params, model=model, optimizer=optimizer,
+                model_parameters=model.parameters())
+        else:
+            logging.info("Using DDP")
+            setup_dist(use_ddp_launch=True)
+            model = DDP(model, device_ids=[rank], find_unused_parameters=True)
+
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
             2**22
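Note: deepspeed.initialize() sets up the process group itself, which is why
setup_dist() moved into the DDP-only branch and why the script is expected to
be started by the deepspeed launcher (or another launcher that exports
RANK/WORLD_SIZE). A self-contained sketch of the initialization; the tiny
model, learning rate and argv are illustrative:

    import argparse

    import deepspeed
    import torch

    parser = argparse.ArgumentParser()
    # adds --deepspeed and --deepspeed_config, i.e. the flags read as
    # params.deepspeed / params.deepspeed_config in this recipe
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args(
        ["--deepspeed", "--deepspeed_config", "ds_config_zero1.json"]
    )

    net = torch.nn.Linear(80, 10)  # stand-in for the Whisper model
    opt = torch.optim.AdamW(net.parameters(), lr=1e-5)
    # initializes torch.distributed if needed, so run this under a launcher
    model_engine, optimizer, _, _ = deepspeed.initialize(
        args=args, model=net, optimizer=opt, model_parameters=net.parameters()
    )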
@@ -846,51 +843,8 @@ def run(rank, world_size, args):
     if params.inf_check:
         register_inf_check_hooks(model)
 
-    def remove_short_and_long_utt(c: Cut):
-        # Keep only utterances with duration between 1 second and 20 seconds
-        #
-        # Caution: There is a reason to select 20.0 here. Please see
-        #          ../local/display_manifest_statistics.py
-        #
-        # You should use ../local/display_manifest_statistics.py to get
-        # an utterance duration distribution for your dataset to select
-        # the threshold
-        if c.duration < 1.0 or c.duration > 12.0:
-            logging.warning(
-                f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
-            )
-            return False
-
-        # In pruned RNN-T, we require that T >= S
-        # where T is the number of feature frames after subsampling
-        # and S is the number of tokens in the utterance
-
-        # In ./zipformer.py, the conv module uses the following expression
-        # for subsampling
-        # T = ((c.num_frames - 7) // 2 + 1) // 2
-        # tokens = sp.encode(c.supervisions[0].text, out_type=str)
-
-        # if T < len(tokens):
-        #     logging.warning(
-        #         f"Exclude cut with ID {c.id} from training. "
-        #         f"Number of frames (before subsampling): {c.num_frames}. "
-        #         f"Number of frames (after subsampling): {T}. "
-        #         f"Text: {c.supervisions[0].text}. "
-        #         f"Tokens: {tokens}. "
-        #         f"Number of tokens: {len(tokens)}"
-        #     )
-        #     return False
-
-        return True
-
-
     aishell = AishellAsrDataModule(args)
 
-
-
-
     if params.start_batch > 0 and checkpoints and "sampler" in checkpoints:
         # We only load the sampler's state dict when it loads a checkpoint
         # saved in the middle of an epoch
@@ -899,22 +853,19 @@ def run(rank, world_size, args):
         sampler_state_dict = None
 
-    train_dl = aishell.train_dataloaders(aishell.train_cuts())
-    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts())
-    # if not params.print_diagnostics:
-    #     scan_pessimistic_batches_for_oom(
-    #         model=model,
-    #         train_dl=train_dl,
-    #         optimizer=optimizer,
-    #         graph_compiler=graph_compiler,
-    #         params=params,
-    #     )
+    train_dl = aishell.train_dataloaders(aishell.train_cuts(), rank=rank, world_size=world_size)
+    valid_dl = aishell.valid_dataloaders(aishell.valid_cuts(), rank=rank, world_size=world_size)
 
     scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
     if checkpoints and "grad_scaler" in checkpoints:
         logging.info("Loading grad scaler state dict")
         scaler.load_state_dict(checkpoints["grad_scaler"])
 
+    if args.tensorboard and rank == 0:
+        tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard")
+    else:
+        tb_writer = None
+
     logging.info(f"start training from epoch {params.start_epoch}")
     for epoch in range(params.start_epoch, params.num_epochs + 1):
         scheduler.step_epoch(epoch - 1)
@@ -945,20 +896,28 @@ def run(rank, world_size, args):
             diagnostic.print_diagnostics()
             break
 
-        save_checkpoint(
-            params=params,
-            model=model,
-            model_avg=model_avg,
-            optimizer=optimizer,
-            scheduler=scheduler,
-            sampler=train_dl.sampler,
-            scaler=scaler,
-            rank=rank,
-        )
+        if params.deepspeed:
+            model.save_checkpoint(save_dir=params.exp_dir,
+                                  tag=f"epoch-{params.cur_epoch}",
+                                  client_state={})
+            convert_zero_checkpoint_to_fp32_state_dict(
+                params.exp_dir, f"{params.exp_dir}/epoch-{params.cur_epoch}.pt",
+                tag=f"epoch-{params.cur_epoch}")
+        else:
+            save_checkpoint(
+                params=params,
+                model=model,
+                model_avg=model_avg,
+                optimizer=optimizer,
+                scheduler=scheduler,
+                sampler=train_dl.sampler,
+                scaler=scaler,
+                rank=rank,
+            )
 
     logging.info("Done!")
 
-    if world_size > 1:
+    if world_size > 1 and not params.deepspeed:
         torch.distributed.barrier()
         cleanup_dist()
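Note: model.save_checkpoint() writes ZeRO-partitioned shards (one set per
rank) under {exp_dir}/epoch-N/, and convert_zero_checkpoint_to_fp32_state_dict()
merges them into a single fp32 state dict that loads like a plain PyTorch
checkpoint. A sketch of consuming the converted file; the paths and epoch
number are illustrative:

    import torch
    from model import load_model  # the patched loader from this recipe

    model = load_model("large-v2")  # now lands on CPU by default
    state_dict = torch.load("whisper/exp/epoch-10.pt", map_location="cpu")
    model.load_state_dict(state_dict)
    model.to("cuda")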
@@ -988,48 +947,6 @@ def display_and_save_batch(
     logging.info(f"features shape: {features.shape}")
 
 
-# def scan_pessimistic_batches_for_oom(
-#     model: Union[nn.Module, DDP],
-#     tokenizer: whisper.tokenizer.Tokenizer,
-#     train_dl: torch.utils.data.DataLoader,
-#     optimizer: torch.optim.Optimizer,
-#     params: AttributeDict,
-# ):
-#     from lhotse.dataset import find_pessimistic_batches
-
-#     logging.info(
-#         "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
-#     )
-#     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
-#     for criterion, cuts in batches.items():
-#         batch = train_dl.dataset[cuts]
-#         try:
-#             with torch.cuda.amp.autocast(enabled=params.use_fp16):
-#                 loss, _ = compute_loss(
-#                     params=params,
-#                     tokenizer=tokenizer,
-#                     model=model,
-#                     batch=batch,
-#                     is_training=True,
-#                 )
-#             loss.backward()
-#             optimizer.zero_grad()
-#         except Exception as e:
-#             if "CUDA out of memory" in str(e):
-#                 logging.error(
-#                     "Your GPU ran out of memory with the current "
-#                     "max_duration setting. We recommend decreasing "
-#                     "max_duration and trying again.\n"
-#                     f"Failing criterion: {criterion} "
-#                     f"(={crit_values[criterion]}) ..."
-#                 )
-#                 display_and_save_batch(batch, params=params)
-#             raise
-#         logging.info(
-#             f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-#         )
-
-
 def main():
     parser = get_parser()
     AishellAsrDataModule.add_arguments(parser)
@@ -1038,13 +955,10 @@ def main():
     args = parser.parse_args()
 
     world_size = get_world_size()
     rank = get_rank()
-    assert world_size >= 1
+    torch.set_num_threads(1)
+    torch.set_num_interop_threads(1)
 
     run(rank=rank, world_size=world_size, args=args)
 
-
-torch.set_num_threads(1)
-torch.set_num_interop_threads(1)
-
 if __name__ == "__main__":
     main()