mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-03 22:24:19 +00:00
remove padding to 30s, compute validation loss once
This commit is contained in:
parent
07cefa82a7
commit
98d11abedb
@ -430,9 +430,9 @@ def get_params() -> AttributeDict:
|
||||
"best_train_epoch": -1,
|
||||
"best_valid_epoch": -1,
|
||||
"batch_idx_train": 0,
|
||||
"log_interval": 1,
|
||||
"log_interval": 50,
|
||||
"reset_interval": 200,
|
||||
"valid_interval": 50, # For the 100h subset, use 800
|
||||
"valid_interval": 99999999999, # For the 100h subset, use 800
|
||||
# parameters for zipformer
|
||||
"feature_dim": 80,
|
||||
"subsampling_factor": 4, # not passed in, this is fixed.
|
||||
@ -632,8 +632,8 @@ def compute_loss(
|
||||
feature = feature.to(device)
|
||||
feature = feature.transpose(1, 2) # (N, C, T)
|
||||
# pad feature from B,80,T to B,80,3000
|
||||
feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
|
||||
print(feature.shape, 23333333)
|
||||
#feature = torch.nn.functional.pad(feature, (0, 3000 - feature.shape[-1]))
|
||||
#print(feature.shape, 23333333)
|
||||
supervisions = batch["supervisions"]
|
||||
feature_lens = supervisions["num_frames"].to(device)
|
||||
|
||||
@ -783,24 +783,24 @@ def train_one_epoch(
|
||||
for batch_idx, batch in enumerate(train_dl):
|
||||
params.batch_idx_train += 1
|
||||
batch_size = len(batch["supervisions"]["text"])
|
||||
# if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
||||
# logging.info("Computing validation loss")
|
||||
# valid_info = compute_validation_loss(
|
||||
# params=params,
|
||||
# tokenizer=tokenizer,
|
||||
# model=model,
|
||||
# valid_dl=valid_dl,
|
||||
# world_size=world_size,
|
||||
# )
|
||||
# model.train()
|
||||
# logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
# logging.info(
|
||||
# f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
# )
|
||||
# if tb_writer is not None:
|
||||
# valid_info.write_summary(
|
||||
# tb_writer, "train/valid_", params.batch_idx_train
|
||||
# )
|
||||
if batch_idx % params.valid_interval == 0 and not params.print_diagnostics:
|
||||
logging.info("Computing validation loss")
|
||||
valid_info = compute_validation_loss(
|
||||
params=params,
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
valid_dl=valid_dl,
|
||||
world_size=world_size,
|
||||
)
|
||||
model.train()
|
||||
logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}")
|
||||
logging.info(
|
||||
f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
|
||||
)
|
||||
if tb_writer is not None:
|
||||
valid_info.write_summary(
|
||||
tb_writer, "train/valid_", params.batch_idx_train
|
||||
)
|
||||
try:
|
||||
with torch.cuda.amp.autocast(enabled=params.use_fp16):
|
||||
loss, loss_info = compute_loss(
|
||||
@ -967,8 +967,10 @@ def run(rank, world_size, args):
|
||||
|
||||
|
||||
logging.info("About to create model")
|
||||
model = whisper.load_model("medium")
|
||||
#model = load_model("medium")
|
||||
#model = whisper.load_model("medium")
|
||||
# TODO download model only on rank 0
|
||||
# TODO may change compute validation loss using multiple cards
|
||||
model = load_model("medium")
|
||||
del model.alignment_heads
|
||||
tokenizer = whisper.tokenizer.get_tokenizer(
|
||||
model.is_multilingual, language="zh", task="transcribe"
|
||||
|
Loading…
x
Reference in New Issue
Block a user