diff --git a/egs/aishell/ASR/prepare.sh b/egs/aishell/ASR/prepare.sh
index 9f73a2073..78c635690 100755
--- a/egs/aishell/ASR/prepare.sh
+++ b/egs/aishell/ASR/prepare.sh
@@ -1,5 +1,5 @@
 #!/usr/bin/env bash
-
+export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall
 # fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
 export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
 
@@ -120,6 +120,15 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   fi
 fi
 
+if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
+  log "Stage 30: Compute whisper fbank for aishell"
+  if [ ! -f data/fbank/.aishell.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_whisper_fbank_aishell.py --perturb-speed True
+    touch data/fbank/.aishell.done
+  fi
+fi
+
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "Stage 4: Compute fbank for musan"
   if [ ! -f data/fbank/.msuan.done ]; then
     mkdir -p data/fbank
@@ -129,6 +138,15 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   fi
 fi
 
+if [ $stage -le 40 ] && [ $stop_stage -ge 40 ]; then
+  log "Stage 40: Compute whisper fbank for musan"
+  if [ ! -f data/fbank/.msuan.done ]; then
+    mkdir -p data/fbank
+    ./local/compute_whisper_fbank_musan.py
+    touch data/fbank/.msuan.done
+  fi
+fi
+
 lang_phone_dir=data/lang_phone
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare phone based lang"
diff --git a/egs/aishell/ASR/seamlessm4t/model.py b/egs/aishell/ASR/seamlessm4t/model.py
index 18e5996bb..efe18d5ff 100644
--- a/egs/aishell/ASR/seamlessm4t/model.py
+++ b/egs/aishell/ASR/seamlessm4t/model.py
@@ -1,4 +1,6 @@
 import torch
+import torch.nn as nn
+from fairseq2.nn.embedding import Embedding
 from seamless_communication.models.inference import Translator
 from seamless_communication.models.unity import (
     UnitTokenizer,
@@ -53,15 +55,28 @@ source_seqs = source_seqs.to(device=device, dtype=torch.float16)
 dtype = torch.float16
 model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
 model.eval()
-text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
-print(text_tokenizer.model.eos_idx, text_tokenizer.model.pad_idx)
-text_tokenizer_encoder = text_tokenizer.create_encoder(lang=target_lang, mode="target")
-text_tokenizer_decoder = text_tokenizer.create_decoder()
-# print attritbut of text_tokenizer_encoder
-
-print(text_tokenizer_encoder(""))
-print(text_tokenizer_decoder(torch.tensor([3,45])))
+model.text_decoder_frontend.embed = Embedding(num_embeddings=6257, embedding_dim=1024, pad_idx=0, scaled=True)
+model.final_proj = nn.Linear(1024, 6257)
+model.half()
+print(model.text_decoder_frontend.embed, model.text_encoder_frontend.embed.weight.dtype, type(model.text_encoder_frontend.embed), type(model.text_encoder_frontend.embed.weight))
+print(model.final_proj, model.final_proj.weight.dtype, type(model.final_proj), type(model.final_proj.weight))
+#input()
 exit(0)
+text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
+#print(text_tokenizer.model.eos_idx, text_tokenizer.model.pad_idx)
+#text_tokenizer_encoder = text_tokenizer.create_encoder(lang=target_lang, mode="target")
+#text_tokenizer_decoder = text_tokenizer.create_decoder()
+# print attribute of text_tokenizer_encoder
+#print(text_tokenizer.vocab_info)
+#print(text_tokenizer_encoder("其中广州深圳甚至出现了多个日光盘"))
+#print(text_tokenizer_decoder(torch.tensor([3,256200,137139,252603,250476,250590,1,84778,148897,249568,249352,249947,249050,250520,254508])))
+
+# store all vocab in a file
+# with open("vocab.txt", "w") as f:
+#     for i in range(256206):
+#         f.write(f"{i}: " + text_tokenizer_decoder(torch.tensor([i]))[0].bytes().decode("utf-8") + "\n")
+# f.close()
+# exit(0)
@@ -112,6 +127,7 @@ s2t_generator = SequenceToTextGenerator(
 )
 
 text_output = s2t_generator.generate_ex(source_seqs, source_seq_lens)
-sentence = text_output.sentences[0]
-print(sentence, type(sentence))
-sentence = sentence.bytes().decode("utf-8")
+print(text_output.generator_output.results[0][0].seq.cpu().tolist())
+# sentence = text_output.sentences[0]
+# print(sentence, type(sentence))
+# sentence = sentence.bytes().decode("utf-8")
diff --git a/egs/aishell/ASR/whisper/decode.py b/egs/aishell/ASR/whisper/decode.py
index 44c6ea081..28ac83562 100644
--- a/egs/aishell/ASR/whisper/decode.py
+++ b/egs/aishell/ASR/whisper/decode.py
@@ -225,6 +225,7 @@ def decode_one_batch(
 
     hyps = to_simple(hyps)
     hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
+    print(hyps, 233333333)
 
     key = "beam-search"
 
@@ -374,26 +375,26 @@ def main():
     logging.info(f"device: {device}")
 
     model = whisper.load_model("medium")
-    # if params.epoch > 0:
-    #     if params.avg > 1:
-    #         start = params.epoch - params.avg
-    #         assert start >= 1, start
-    #         filename_start = f"{params.exp_dir}/epoch-{start}.pt"
-    #         filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
-    #         logging.info(
-    #             f"Calculating the averaged model over epoch range from "
-    #             f"{start} (excluded) to {params.epoch}"
-    #         )
-    #         model.to(device)
-    #         model.load_state_dict(
-    #             average_checkpoints_with_averaged_model(
-    #                 filename_start=filename_start,
-    #                 filename_end=filename_end,
-    #                 device=device,
-    #             )
-    #         )
-    #     else:
-    #         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    if params.epoch > 0:
+        if params.avg > 1:
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
     model.to(device)
     model.eval()
     num_param = sum([p.numel() for p in model.parameters()])
diff --git a/egs/aishell/ASR/whisper/requirements.txt b/egs/aishell/ASR/whisper/requirements.txt
index 6edea0afc..654851b73 100644
--- a/egs/aishell/ASR/whisper/requirements.txt
+++ b/egs/aishell/ASR/whisper/requirements.txt
@@ -7,3 +7,4 @@ tensorboard
 librosa
 openai-whisper
 zhconv
+WeTextProcessing
diff --git a/egs/aishell/ASR/whisper/train.py b/egs/aishell/ASR/whisper/train.py
index 95af2f056..77035318d 100644
--- a/egs/aishell/ASR/whisper/train.py
+++ b/egs/aishell/ASR/whisper/train.py
@@ -430,9 +430,9 @@ def get_params() -> AttributeDict:
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
-            "log_interval": 50,
+            "log_interval": 1,
             "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
+            "valid_interval": 50,  # For the 100h subset, use 800
             # parameters for zipformer
             "feature_dim": 80,
             "subsampling_factor": 4,  # not passed in, this is fixed.
@@ -578,6 +578,7 @@ def save_checkpoint(
 
 def compute_loss(
     params: AttributeDict,
+    tokenizer: whisper.tokenizer.Tokenizer,
     model: Union[nn.Module, DDP],
     batch: dict,
     is_training: bool,
@@ -631,8 +632,8 @@ def compute_loss(
     feature = feature.to(device)
     feature = feature.transpose(1, 2)  # (N, C, T)
     # pad feature from B,80,T to B,80,3000
-    # feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
-
+    feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
+    print(feature.shape, 23333333)
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
@@ -640,27 +641,48 @@ def compute_loss(
     warm_step = params.warm_step
 
     texts = batch["supervisions"]["text"]
-
-    text_tokens_list = [list(params.tokenizer.sot_sequence_including_notimestamps) + params.tokenizer.encode(text) + [params.tokenizer.eot] for text in texts]
+    # remove spaces in texts
+    texts = [text.replace(" ", "") for text in texts]
+    #print(texts)
+    text_tokens_list = [list(tokenizer.sot_sequence_including_notimestamps) + tokenizer.encode(text) + [tokenizer.eot] for text in texts]
     # convert it to torch tensor
     text_tokens_list = [torch.LongTensor(text_tokens) for text_tokens in text_tokens_list]
 
+    # prev_outputs_tokens = _batch_tensors(
+    #     [tokens[:-1] for tokens in text_tokens_list], pad_value=tokenizer.eot
+    # )
+    # target_tokens = _batch_tensors(
+    #     [tokens[1:] for tokens in text_tokens_list], pad_value=tokenizer.eot
+    # )
     prev_outputs_tokens = _batch_tensors(
-        [tokens[:-1] for tokens in text_tokens_list], pad_value=params.tokenizer.eot
+        [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
     )
     target_tokens = _batch_tensors(
-        [tokens[1:] for tokens in text_tokens_list], pad_value=params.tokenizer.eot
+        [tokens[1:] for tokens in text_tokens_list], pad_value=50256
     )
     target_lengths = torch.LongTensor(
         [tokens.shape[0] - 1 for tokens in text_tokens_list]
    )
 
-    decoder_criterion = LabelSmoothingLoss(ignore_index=params.tokenizer.eot, label_smoothing=0.1, reduction="sum")
+    #print(prev_outputs_tokens.shape, prev_outputs_tokens)
+    #print(target_tokens.shape, target_tokens)
+    #print(target_lengths.shape, target_lengths)
+    #print(text_tokens_list)
+    #print("==========================================")
+    decoder_criterion = LabelSmoothingLoss(ignore_index=50256, label_smoothing=0.1, reduction="sum")
+    ignore_prefix_size = 3
 
     with torch.set_grad_enabled(is_training):
         encoder_out = model.encoder(feature)
         text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
         loss = decoder_criterion(text_logits, target_tokens.to(device))
+        text_logits = text_logits[:, ignore_prefix_size:, :]
+        target_tokens = target_tokens[:, ignore_prefix_size:]
+        #print(text_logits.shape)
+        # print greedy results of text_logits
+        #print(text_logits.argmax(dim=-1))
+        # convert it to list of list then decode
+        #print([tokenizer.decode(tokens) for tokens in text_logits.argmax(dim=-1).tolist()])
 
     assert loss.requires_grad == is_training
 
@@ -677,6 +699,7 @@ def compute_loss(
 
 def compute_validation_loss(
     params: AttributeDict,
+    tokenizer: whisper.tokenizer.Tokenizer,
     model: Union[nn.Module, DDP],
     valid_dl: torch.utils.data.DataLoader,
     world_size: int = 1,
@@ -689,6 +712,7 @@ def compute_validation_loss(
     for batch_idx, batch in enumerate(valid_dl):
         loss, loss_info = compute_loss(
             params=params,
+            tokenizer=tokenizer,
             model=model,
             batch=batch,
             is_training=False,
@@ -709,6 +733,7 @@ def compute_validation_loss(
 
 def train_one_epoch(
     params: AttributeDict,
+    tokenizer: whisper.tokenizer.Tokenizer,
     model: Union[nn.Module, DDP],
     optimizer: torch.optim.Optimizer,
     scheduler: LRSchedulerType,
@@ -758,11 +783,29 @@ def train_one_epoch(
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
-
+        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        #     logging.info("Computing validation loss")
+        #     valid_info = compute_validation_loss(
+        #         params=params,
+        #         tokenizer=tokenizer,
+        #         model=model,
+        #         valid_dl=valid_dl,
+        #         world_size=world_size,
+        #     )
+        #     model.train()
+        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+        #     logging.info(
+        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        #     )
+        #     if tb_writer is not None:
+        #         valid_info.write_summary(
+        #             tb_writer, "train/valid_", params.batch_idx_train
+        #         )
         try:
             with torch.cuda.amp.autocast(enabled=params.use_fp16):
                 loss, loss_info = compute_loss(
                     params=params,
+                    tokenizer=tokenizer,
                     model=model,
                     batch=batch,
                     is_training=True,
@@ -860,23 +903,24 @@ def train_one_epoch(
                     params.batch_idx_train,
                 )
 
-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-            logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                model=model,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-            )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        #     logging.info("Computing validation loss")
+        #     valid_info = compute_validation_loss(
+        #         params=params,
+        #         tokenizer=tokenizer,
+        #         model=model,
+        #         valid_dl=valid_dl,
+        #         world_size=world_size,
+        #     )
+        #     model.train()
+        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+        #     logging.info(
+        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        #     )
+        #     if tb_writer is not None:
+        #         valid_info.write_summary(
+        #             tb_writer, "train/valid_", params.batch_idx_train
+        #         )
 
     loss_value = tot_loss["loss"] / tot_loss["frames"]
     params.train_loss = loss_value
@@ -923,10 +967,10 @@ def run(rank, world_size, args):
 
     logging.info("About to create model")
 
-    #model = whisper.load_model("medium")
-    model = load_model("medium")
+    model = whisper.load_model("medium")
+    #model = load_model("medium")
     del model.alignment_heads
 
-    params.tokenizer = whisper.tokenizer.get_tokenizer(
+    tokenizer = whisper.tokenizer.get_tokenizer(
         model.is_multilingual, language="zh", task="transcribe"
     )
     logging.info(params)
@@ -960,11 +1004,13 @@ def run(rank, world_size, args):
     #     clipping_scale=2.0,
     #     parameters_names=parameters_names,
     # )
-    optimizer = ScaledAdam(
-        model.parameters(),
-        lr=params.base_lr,
-        clipping_scale=2.0,
-    )
+    # optimizer = ScaledAdam(
+    #     model.parameters(),
+    #     lr=params.base_lr,
+    #     clipping_scale=2.0,
+    # )
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
     scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
 
     if checkpoints and "optimizer" in checkpoints:
@@ -1085,6 +1131,7 @@ def run(rank, world_size, args):
 
         train_one_epoch(
             params=params,
+            tokenizer=tokenizer,
             model=model,
             model_avg=model_avg,
             optimizer=optimizer,
@@ -1148,44 +1195,46 @@ def display_and_save_batch(
     # logging.info(f"num tokens: {num_tokens}")
 
 
-def scan_pessimistic_batches_for_oom(
-    model: Union[nn.Module, DDP],
-    train_dl: torch.utils.data.DataLoader,
-    optimizer: torch.optim.Optimizer,
-    params: AttributeDict,
-):
-    from lhotse.dataset import find_pessimistic_batches
+# def scan_pessimistic_batches_for_oom(
+#     model: Union[nn.Module, DDP],
+#     tokenizer: whisper.tokenizer.Tokenizer,
+#     train_dl: torch.utils.data.DataLoader,
+#     optimizer: torch.optim.Optimizer,
+#     params: AttributeDict,
+# ):
+#     from lhotse.dataset import find_pessimistic_batches
 
-    logging.info(
-        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
-    )
-    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
-    for criterion, cuts in batches.items():
-        batch = train_dl.dataset[cuts]
-        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
-                loss, _ = compute_loss(
-                    params=params,
-                    model=model,
-                    batch=batch,
-                    is_training=True,
-                )
-            loss.backward()
-            optimizer.zero_grad()
-        except Exception as e:
-            if "CUDA out of memory" in str(e):
-                logging.error(
-                    "Your GPU ran out of memory with the current "
-                    "max_duration setting. We recommend decreasing "
-                    "max_duration and trying again.\n"
-                    f"Failing criterion: {criterion} "
-                    f"(={crit_values[criterion]}) ..."
-                )
-                display_and_save_batch(batch, params=params)
-            raise
-        logging.info(
-            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-        )
+#     logging.info(
+#         "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+#     )
+#     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+#     for criterion, cuts in batches.items():
+#         batch = train_dl.dataset[cuts]
+#         try:
+#             with torch.cuda.amp.autocast(enabled=params.use_fp16):
+#                 loss, _ = compute_loss(
+#                     params=params,
+#                     tokenizer=tokenizer,
+#                     model=model,
+#                     batch=batch,
+#                     is_training=True,
+#                 )
+#             loss.backward()
+#             optimizer.zero_grad()
+#         except Exception as e:
+#             if "CUDA out of memory" in str(e):
+#                 logging.error(
+#                     "Your GPU ran out of memory with the current "
+#                     "max_duration setting. We recommend decreasing "
+#                     "max_duration and trying again.\n"
+#                     f"Failing criterion: {criterion} "
+#                     f"(={crit_values[criterion]}) ..."
+#                 )
+#                 display_and_save_batch(batch, params=params)
+#             raise
+#         logging.info(
+#             f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+#         )
 
 
 def main():