mirror of https://github.com/k2-fsa/icefall.git
change ScaledAdam to AdamW
This commit is contained in:
parent 8b832f168d
commit 07cefa82a7
@@ -1,5 +1,5 @@
#!/usr/bin/env bash

export PYTHONPATH=$PYTHONPATH:/mnt/samsung-t7/yuekai/asr/icefall
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
@@ -120,6 +120,15 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  fi
fi

if [ $stage -le 30 ] && [ $stop_stage -ge 30 ]; then
  log "Stage 30: Compute whisper fbank for aishell"
  if [ ! -f data/fbank/.aishell.done ]; then
    mkdir -p data/fbank
    ./local/compute_whisper_fbank_aishell.py --perturb-speed True
    touch data/fbank/.aishell.done
  fi
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  log "Stage 4: Compute fbank for musan"
  if [ ! -f data/fbank/.msuan.done ]; then
@@ -129,6 +138,15 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  fi
fi

if [ $stage -le 40 ] && [ $stop_stage -ge 40 ]; then
  log "Stage 40: Compute whisper fbank for musan"
  if [ ! -f data/fbank/.msuan.done ]; then
    mkdir -p data/fbank
    ./local/compute_whisper_fbank_musan.py
    touch data/fbank/.msuan.done
  fi
fi

lang_phone_dir=data/lang_phone
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
  log "Stage 5: Prepare phone based lang"
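For reference, ./local/compute_whisper_fbank_aishell.py is not shown in this diff; here is a minimal sketch of what such a script typically does, assuming lhotse's WhisperFbank extractor and already-prepared AIShell manifests (the paths and options below are illustrative, not the actual script):

    # Hypothetical sketch of a compute_whisper_fbank_aishell.py-style script.
    from lhotse import CutSet, LilcomChunkyWriter, WhisperFbank, WhisperFbankConfig

    cuts = CutSet.from_file("data/manifests/aishell_cuts_train.jsonl.gz")  # assumed path
    # --perturb-speed True: add 0.9x and 1.1x speed-perturbed copies before extraction
    cuts = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
    extractor = WhisperFbank(WhisperFbankConfig(num_filters=80))
    cuts = cuts.compute_and_store_features(
        extractor=extractor,
        storage_path="data/fbank/aishell_feats_train",
        storage_type=LilcomChunkyWriter,
    )
    cuts.to_file("data/fbank/aishell_cuts_train.jsonl.gz")

The .aishell.done and .msuan.done marker files in the stages above simply make the script idempotent: once a marker exists, re-running prepare.sh skips that extraction step.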
@@ -1,4 +1,6 @@
import torch
import torch.nn as nn
from fairseq2.nn.embedding import Embedding
from seamless_communication.models.inference import Translator
from seamless_communication.models.unity import (
    UnitTokenizer,
@@ -53,15 +55,28 @@ source_seqs = source_seqs.to(device=device, dtype=torch.float16)
dtype = torch.float16
model = load_unity_model(model_name_or_card, device=device, dtype=dtype)
model.eval()
text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
print(text_tokenizer.model.eos_idx, text_tokenizer.model.pad_idx)
text_tokenizer_encoder = text_tokenizer.create_encoder(lang=target_lang, mode="target")
text_tokenizer_decoder = text_tokenizer.create_decoder()
# print attributes of text_tokenizer_encoder

print(text_tokenizer_encoder("<eos>"))
print(text_tokenizer_decoder(torch.tensor([3, 45])))
model.text_decoder_frontend.embed = Embedding(num_embeddings=6257, embedding_dim=1024, pad_idx=0, scaled=True)
model.final_proj = nn.Linear(1024, 6257)
model.half()
print(model.text_decoder_frontend.embed, model.text_encoder_frontend.embed.weight.dtype, type(model.text_encoder_frontend.embed), type(model.text_encoder_frontend.embed.weight))
print(model.final_proj, model.final_proj.weight.dtype, type(model.final_proj), type(model.final_proj.weight))
# input()
exit(0)
text_tokenizer = load_unity_text_tokenizer(model_name_or_card)
# print(text_tokenizer.model.eos_idx, text_tokenizer.model.pad_idx)
# text_tokenizer_encoder = text_tokenizer.create_encoder(lang=target_lang, mode="target")
# text_tokenizer_decoder = text_tokenizer.create_decoder()
# print attributes of text_tokenizer_encoder
# print(text_tokenizer.vocab_info)
# print(text_tokenizer_encoder("其中广州深圳甚至出现了多个日光盘"))
# print(text_tokenizer_decoder(torch.tensor([3, 256200, 137139, 252603, 250476, 250590, 1, 84778, 148897, 249568, 249352, 249947, 249050, 250520, 254508])))

# store all vocab in a file
# with open("vocab.txt", "w") as f:
#     for i in range(256206):
#         f.write(f"{i}: " + text_tokenizer_decoder(torch.tensor([i]))[0].bytes().decode("utf-8") + "\n")
#     f.close()
# exit(0)
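The hunk above replaces seamless's 256206-entry multilingual text embedding and final projection with a 6257-entry target vocabulary. A generic torch sketch of the same idea, using plain nn.Embedding rather than fairseq2's scaled Embedding (the sizes come from the script above; everything else is illustrative):

    import torch.nn as nn

    def shrink_text_head(model, vocab_size: int = 6257, embed_dim: int = 1024):
        # Replace the decoder input embedding and the output projection so the
        # decoder only scores the reduced target vocabulary.
        model.text_decoder_frontend.embed = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        model.final_proj = nn.Linear(embed_dim, vocab_size)
        return model

Both new layers start from random weights, so the model would have to be fine-tuned on the reduced vocabulary before its text decoder produces anything sensible.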
@@ -112,6 +127,7 @@ s2t_generator = SequenceToTextGenerator(
)

text_output = s2t_generator.generate_ex(source_seqs, source_seq_lens)
sentence = text_output.sentences[0]
print(sentence, type(sentence))
sentence = sentence.bytes().decode("utf-8")
print(text_output.generator_output.results[0][0].seq.cpu().tolist())
# sentence = text_output.sentences[0]
# print(sentence, type(sentence))
# sentence = sentence.bytes().decode("utf-8")
@@ -225,6 +225,7 @@ def decode_one_batch(
    hyps = to_simple(hyps)

    hyps = [params.normalizer.normalize(hyp) for hyp in hyps]
    print(hyps, 233333333)

    key = "beam-search"
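to_simple and params.normalizer are defined elsewhere in the recipe. A plausible sketch of this post-processing, assuming to_simple converts traditional to simplified Chinese via zhconv and the normalizer is WeTextProcessing's Chinese Normalizer (both packages are added to requirements.txt in this commit; the exact wiring is an assumption):

    from zhconv import convert  # traditional -> simplified Chinese
    from tn.chinese.normalizer import Normalizer  # provided by WeTextProcessing

    normalizer = Normalizer()

    def to_simple(texts):
        return [convert(text, "zh-cn") for text in texts]

    hyps = to_simple(["漢語轉換"])
    hyps = [normalizer.normalize(hyp) for hyp in hyps]  # e.g. spells out digits and units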
@@ -374,26 +375,26 @@ def main():
    logging.info(f"device: {device}")

    model = whisper.load_model("medium")
-    # if params.epoch > 0:
-    #     if params.avg > 1:
-    #         start = params.epoch - params.avg
-    #         assert start >= 1, start
-    #         filename_start = f"{params.exp_dir}/epoch-{start}.pt"
-    #         filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
-    #         logging.info(
-    #             f"Calculating the averaged model over epoch range from "
-    #             f"{start} (excluded) to {params.epoch}"
-    #         )
-    #         model.to(device)
-    #         model.load_state_dict(
-    #             average_checkpoints_with_averaged_model(
-    #                 filename_start=filename_start,
-    #                 filename_end=filename_end,
-    #                 device=device,
-    #             )
-    #         )
-    #     else:
-    #         load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    if params.epoch > 0:
+        if params.avg > 1:
+            start = params.epoch - params.avg
+            assert start >= 1, start
+            filename_start = f"{params.exp_dir}/epoch-{start}.pt"
+            filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
+            logging.info(
+                f"Calculating the averaged model over epoch range from "
+                f"{start} (excluded) to {params.epoch}"
+            )
+            model.to(device)
+            model.load_state_dict(
+                average_checkpoints_with_averaged_model(
+                    filename_start=filename_start,
+                    filename_end=filename_end,
+                    device=device,
+                )
+            )
+        else:
+            load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    model.to(device)
    model.eval()
    num_param = sum([p.numel() for p in model.parameters()])
@@ -7,3 +7,4 @@ tensorboard
librosa
openai-whisper
zhconv
WeTextProcessing
@@ -430,9 +430,9 @@ def get_params() -> AttributeDict:
            "best_train_epoch": -1,
            "best_valid_epoch": -1,
            "batch_idx_train": 0,
-            "log_interval": 50,
+            "log_interval": 1,
            "reset_interval": 200,
-            "valid_interval": 3000,  # For the 100h subset, use 800
+            "valid_interval": 50,  # For the 100h subset, use 800
            # parameters for zipformer
            "feature_dim": 80,
            "subsampling_factor": 4,  # not passed in, this is fixed.
@@ -578,6 +578,7 @@ def save_checkpoint(

def compute_loss(
    params: AttributeDict,
+    tokenizer: whisper.tokenizer.Tokenizer,
    model: Union[nn.Module, DDP],
    batch: dict,
    is_training: bool,
@@ -631,8 +632,8 @@ def compute_loss(
    feature = feature.to(device)
    feature = feature.transpose(1, 2)  # (N, C, T)
    # pad feature from B,80,T to B,80,3000
-    # feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))

+    feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
+    print(feature.shape, 23333333)
    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)
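The padding above reflects a fixed Whisper invariant: the encoder always consumes 30 seconds of audio, i.e. 3000 mel frames at a 10 ms hop, so shorter utterances are right-padded with zeros on the time axis. A small sketch of the same step using the library constant instead of the literal 3000 (it assumes every batch is at most 30 s long, otherwise the pad width would go negative):

    import torch
    import whisper

    # whisper.audio.N_FRAMES == 3000 (30 s x 100 frames per second)
    def pad_mel(feature: torch.Tensor) -> torch.Tensor:
        # feature: (N, 80, T) with T <= 3000; right-pad the time axis with zeros
        return torch.nn.functional.pad(feature, (0, whisper.audio.N_FRAMES - feature.shape[-1]))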
@@ -640,27 +641,48 @@ def compute_loss(
    warm_step = params.warm_step

    texts = batch["supervisions"]["text"]

-    text_tokens_list = [list(params.tokenizer.sot_sequence_including_notimestamps) + params.tokenizer.encode(text) + [params.tokenizer.eot] for text in texts]
+    # remove spaces in texts
+    texts = [text.replace(" ", "") for text in texts]
+    # print(texts)
+    text_tokens_list = [list(tokenizer.sot_sequence_including_notimestamps) + tokenizer.encode(text) + [tokenizer.eot] for text in texts]
    # convert it to torch tensor
    text_tokens_list = [torch.LongTensor(text_tokens) for text_tokens in text_tokens_list]

+    # prev_outputs_tokens = _batch_tensors(
+    #     [tokens[:-1] for tokens in text_tokens_list], pad_value=tokenizer.eot
+    # )
+    # target_tokens = _batch_tensors(
+    #     [tokens[1:] for tokens in text_tokens_list], pad_value=tokenizer.eot
+    # )
    prev_outputs_tokens = _batch_tensors(
-        [tokens[:-1] for tokens in text_tokens_list], pad_value=params.tokenizer.eot
+        [tokens[:-1] for tokens in text_tokens_list], pad_value=50256
    )
    target_tokens = _batch_tensors(
-        [tokens[1:] for tokens in text_tokens_list], pad_value=params.tokenizer.eot
+        [tokens[1:] for tokens in text_tokens_list], pad_value=50256
    )
    target_lengths = torch.LongTensor(
        [tokens.shape[0] - 1 for tokens in text_tokens_list]
    )
-    decoder_criterion = LabelSmoothingLoss(ignore_index=params.tokenizer.eot, label_smoothing=0.1, reduction="sum")

+    # print(prev_outputs_tokens.shape, prev_outputs_tokens)
+    # print(target_tokens.shape, target_tokens)
+    # print(target_lengths.shape, target_lengths)
+    # print(text_tokens_list)
+    # print("==========================================")
+    decoder_criterion = LabelSmoothingLoss(ignore_index=50256, label_smoothing=0.1, reduction="sum")
+    ignore_prefix_size = 3
    with torch.set_grad_enabled(is_training):

        encoder_out = model.encoder(feature)
        text_logits = model.decoder(prev_outputs_tokens.to(device), encoder_out)
        loss = decoder_criterion(text_logits, target_tokens.to(device))
+        text_logits = text_logits[:, ignore_prefix_size:, :]
+        target_tokens = target_tokens[:, ignore_prefix_size:]
+        # print(text_logits.shape)
+        # print greedy results of text_logits
+        # print(text_logits.argmax(dim=-1))
+        # convert it to list of list then decode
+        # print([tokenizer.decode(tokens) for tokens in text_logits.argmax(dim=-1).tolist()])

    assert loss.requires_grad == is_training
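The token preparation above is standard teacher forcing: every target sequence is built as sot-sequence + text tokens + eot, the decoder input drops the last token, the training target drops the first, and both are padded with the end-of-text id. A self-contained sketch with dummy ids (note that the hardcoded 50256 is <|endoftext|> in the English-only GPT-2 vocabulary; the multilingual tokenizer created with language="zh" reports a different id, so tokenizer.eot would be the safer pad/ignore value):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    def make_decoder_io(token_lists, eot):
        seqs = [torch.LongTensor(t) for t in token_lists]
        # decoder input: everything but the last token; target: everything but the first
        prev = pad_sequence([s[:-1] for s in seqs], batch_first=True, padding_value=eot)
        tgt = pad_sequence([s[1:] for s in seqs], batch_first=True, padding_value=eot)
        return prev, tgt

    # ids 1..4 stand in for the 4-token prompt <sot>, <lang>, <task>, <notimestamps>;
    # after the shift, 3 prompt ids remain at the front of each target row, which is
    # what ignore_prefix_size = 3 is meant to drop from scoring
    prev, tgt = make_decoder_io([[1, 2, 3, 4, 11, 12, 9], [1, 2, 3, 4, 13, 9]], eot=9)

As written in the hunk above, the loss is computed before text_logits and target_tokens are sliced, so the ignore_prefix_size slice does not actually affect the loss; only the ignore_index on the padded eot positions does.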
@@ -677,6 +699,7 @@ def compute_loss(

def compute_validation_loss(
    params: AttributeDict,
+    tokenizer: whisper.tokenizer.Tokenizer,
    model: Union[nn.Module, DDP],
    valid_dl: torch.utils.data.DataLoader,
    world_size: int = 1,
@@ -689,6 +712,7 @@ def compute_validation_loss(
    for batch_idx, batch in enumerate(valid_dl):
        loss, loss_info = compute_loss(
            params=params,
+            tokenizer=tokenizer,
            model=model,
            batch=batch,
            is_training=False,
@@ -709,6 +733,7 @@ def compute_validation_loss(

def train_one_epoch(
    params: AttributeDict,
+    tokenizer: whisper.tokenizer.Tokenizer,
    model: Union[nn.Module, DDP],
    optimizer: torch.optim.Optimizer,
    scheduler: LRSchedulerType,
@@ -758,11 +783,29 @@ def train_one_epoch(
    for batch_idx, batch in enumerate(train_dl):
        params.batch_idx_train += 1
        batch_size = len(batch["supervisions"]["text"])

+        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        #     logging.info("Computing validation loss")
+        #     valid_info = compute_validation_loss(
+        #         params=params,
+        #         tokenizer=tokenizer,
+        #         model=model,
+        #         valid_dl=valid_dl,
+        #         world_size=world_size,
+        #     )
+        #     model.train()
+        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+        #     logging.info(
+        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        #     )
+        #     if tb_writer is not None:
+        #         valid_info.write_summary(
+        #             tb_writer, "train/valid_", params.batch_idx_train
+        #         )
        try:
            with torch.cuda.amp.autocast(enabled=params.use_fp16):
                loss, loss_info = compute_loss(
                    params=params,
+                    tokenizer=tokenizer,
                    model=model,
                    batch=batch,
                    is_training=True,
@@ -860,23 +903,24 @@ def train_one_epoch(
                    params.batch_idx_train,
                )

-        if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
-            logging.info("Computing validation loss")
-            valid_info = compute_validation_loss(
-                params=params,
-                model=model,
-                valid_dl=valid_dl,
-                world_size=world_size,
-            )
-            model.train()
-            logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
-            logging.info(
-                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-            )
-            if tb_writer is not None:
-                valid_info.write_summary(
-                    tb_writer, "train/valid_", params.batch_idx_train
-                )
+        # if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
+        #     logging.info("Computing validation loss")
+        #     valid_info = compute_validation_loss(
+        #         params=params,
+        #         tokenizer=tokenizer,
+        #         model=model,
+        #         valid_dl=valid_dl,
+        #         world_size=world_size,
+        #     )
+        #     model.train()
+        #     logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
+        #     logging.info(
+        #         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+        #     )
+        #     if tb_writer is not None:
+        #         valid_info.write_summary(
+        #             tb_writer, "train/valid_", params.batch_idx_train
+        #         )

    loss_value = tot_loss["loss"] / tot_loss["frames"]
    params.train_loss = loss_value
@@ -923,10 +967,10 @@ def run(rank, world_size, args):


    logging.info("About to create model")
-    # model = whisper.load_model("medium")
-    model = load_model("medium")
+    model = whisper.load_model("medium")
+    # model = load_model("medium")
    del model.alignment_heads
-    params.tokenizer = whisper.tokenizer.get_tokenizer(
+    tokenizer = whisper.tokenizer.get_tokenizer(
        model.is_multilingual, language="zh", task="transcribe"
    )
    logging.info(params)
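For reference, a small probe of the tokenizer object created above (a sketch; the printed values depend on the installed openai-whisper version):

    import whisper

    tokenizer = whisper.tokenizer.get_tokenizer(
        True, language="zh", task="transcribe"  # True = multilingual vocabulary
    )
    # the 4-token prompt: <|startoftranscript|>, <|zh|>, <|transcribe|>, <|notimestamps|>
    print(tokenizer.sot_sequence_including_notimestamps)
    # end-of-text id; the pad/ignore indices in compute_loss should match this
    print(tokenizer.eot)
    print(tokenizer.decode(tokenizer.encode("你好")))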
@@ -960,11 +1004,13 @@ def run(rank, world_size, args):
    # clipping_scale=2.0,
    # parameters_names=parameters_names,
    # )
-    optimizer = ScaledAdam(
-        model.parameters(),
-        lr=params.base_lr,
-        clipping_scale=2.0,
-    )
+    # optimizer = ScaledAdam(
+    #     model.parameters(),
+    #     lr=params.base_lr,
+    #     clipping_scale=2.0,
+    # )
+
+    optimizer = torch.optim.AdamW(model.parameters(), lr=params.base_lr)
    scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)

    if checkpoints and "optimizer" in checkpoints:
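This hunk is the change the commit title describes: ScaledAdam, which clips gradients internally via clipping_scale=2.0, is swapped for plain torch.optim.AdamW, which performs no clipping of its own. A sketch of a training step that restores clipping explicitly next to AdamW (the max_norm value is illustrative, chosen to echo the old clipping_scale):

    import torch
    import torch.nn as nn

    model = nn.Linear(80, 80)  # stand-in for the Whisper model
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

    loss = model(torch.randn(4, 80)).pow(2).mean()
    loss.backward()
    # ScaledAdam clipped internally; with AdamW, clipping must be done by hand if still wanted
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0)
    optimizer.step()
    optimizer.zero_grad()

The Eden schedule from the ScaledAdam setup is kept as the scheduler here; since it only rescales the learning rate, it appears to compose with AdamW unchanged.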
@@ -1085,6 +1131,7 @@ def run(rank, world_size, args):

        train_one_epoch(
            params=params,
+            tokenizer=tokenizer,
            model=model,
            model_avg=model_avg,
            optimizer=optimizer,
@@ -1148,44 +1195,46 @@ def display_and_save_batch(
    # logging.info(f"num tokens: {num_tokens}")


-def scan_pessimistic_batches_for_oom(
-    model: Union[nn.Module, DDP],
-    train_dl: torch.utils.data.DataLoader,
-    optimizer: torch.optim.Optimizer,
-    params: AttributeDict,
-):
-    from lhotse.dataset import find_pessimistic_batches
+# def scan_pessimistic_batches_for_oom(
+#     model: Union[nn.Module, DDP],
+#     tokenizer: whisper.tokenizer.Tokenizer,
+#     train_dl: torch.utils.data.DataLoader,
+#     optimizer: torch.optim.Optimizer,
+#     params: AttributeDict,
+# ):
+#     from lhotse.dataset import find_pessimistic_batches

-    logging.info(
-        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
-    )
-    batches, crit_values = find_pessimistic_batches(train_dl.sampler)
-    for criterion, cuts in batches.items():
-        batch = train_dl.dataset[cuts]
-        try:
-            with torch.cuda.amp.autocast(enabled=params.use_fp16):
-                loss, _ = compute_loss(
-                    params=params,
-                    model=model,
-                    batch=batch,
-                    is_training=True,
-                )
-            loss.backward()
-            optimizer.zero_grad()
-        except Exception as e:
-            if "CUDA out of memory" in str(e):
-                logging.error(
-                    "Your GPU ran out of memory with the current "
-                    "max_duration setting. We recommend decreasing "
-                    "max_duration and trying again.\n"
-                    f"Failing criterion: {criterion} "
-                    f"(={crit_values[criterion]}) ..."
-                )
-                display_and_save_batch(batch, params=params)
-                raise
-    logging.info(
-        f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
-    )
+#     logging.info(
+#         "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
+#     )
+#     batches, crit_values = find_pessimistic_batches(train_dl.sampler)
+#     for criterion, cuts in batches.items():
+#         batch = train_dl.dataset[cuts]
+#         try:
+#             with torch.cuda.amp.autocast(enabled=params.use_fp16):
+#                 loss, _ = compute_loss(
+#                     params=params,
+#                     tokenizer=tokenizer,
+#                     model=model,
+#                     batch=batch,
+#                     is_training=True,
+#                 )
+#             loss.backward()
+#             optimizer.zero_grad()
+#         except Exception as e:
+#             if "CUDA out of memory" in str(e):
+#                 logging.error(
+#                     "Your GPU ran out of memory with the current "
+#                     "max_duration setting. We recommend decreasing "
+#                     "max_duration and trying again.\n"
+#                     f"Failing criterion: {criterion} "
+#                     f"(={crit_values[criterion]}) ..."
+#                 )
+#                 display_and_save_batch(batch, params=params)
+#                 raise
+#     logging.info(
+#         f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
+#     )


def main():