Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-08-12 19:42:19 +00:00.
Commit: 0d97e689be ("Version I am running...")
Parent: e6eefeba88
@@ -21,7 +21,7 @@ class MaskedLmConformer(nn.Module):
         d_model: int = 256,
         nhead: int = 4,
         dim_feedforward: int = 2048,
-        num_encoder_layers: int = 12,
+        num_encoder_layers: int = 6,
         num_decoder_layers: int = 6,
         dropout: float = 0.1,
         cnn_module_kernel: int = 31,
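For context, the only functional change in this hunk halves the encoder depth: the default num_encoder_layers drops from 12 to 6. Below is a minimal instantiation sketch; only the keyword defaults shown in the hunk come from the source, while the num_classes argument and the import path are assumptions, since the hunk shows just a slice of the constructor signature.

# Hypothetical usage sketch. `num_classes` and the import path are
# assumptions; the remaining keyword values are the defaults in the hunk.
from conformer import MaskedLmConformer  # import path assumed

model = MaskedLmConformer(
    num_classes=5000,        # assumed, not shown in the hunk
    d_model=256,
    nhead=4,
    dim_feedforward=2048,
    num_encoder_layers=6,    # was 12 before this commit
    num_decoder_layers=6,
    dropout=0.1,
    cnn_module_kernel=31,
)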
@@ -317,7 +317,7 @@ def compute_validation_loss(
             break
         batch = tuple(x.to(device) for x in batch)

         # `batch` is actually a tuple.. we'll unpack it later.
         loss = compute_loss(model, batch, is_training=False)
         num_frames = batch[4].sum()
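The loop above relies on the batch being a plain tuple of tensors: each element is moved to the device with a generator expression, and the frame count is read from index 4. A self-contained sketch of that pattern follows; the tuple-of-tensors layout and the index-4 convention come from the source, while the dummy shapes are purely illustrative.

import torch

def batch_to_device(batch, device):
    # `batch` is a plain tuple of tensors, so .to(device) is applied
    # element-wise and a new tuple is returned.
    return tuple(x.to(device) for x in batch)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Five dummy tensors standing in for the real fields; per the source,
# batch[4] holds the per-utterance frame counts.
batch = tuple(torch.ones(8, 100) for _ in range(4)) + (torch.full((8,), 100),)
batch = batch_to_device(batch, device)
num_frames = batch[4].sum()  # scalar tensor: total frames in the batch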
@@ -390,17 +390,23 @@ def train_one_epoch(
         params.batch_idx_train += 1
         batch = tuple(x.to(device) for x in batch)

-        loss = compute_loss(
-            model=model,
-            batch=batch,
-            is_training=True,
-        )
+        try:
+            loss = compute_loss(
+                model=model,
+                batch=batch,
+                is_training=True,
+            )

-        optimizer.zero_grad()
-        loss.backward()
-        # We are not normalizing by the num-frames, but Adam/Madam are insensitive to the total
-        # gradient scale so this should not matter.
-        # clip_grad_norm_(model.parameters(), 5.0, 2.0)
-        optimizer.step()
+            optimizer.zero_grad()
+            loss.backward()
+            # We are not normalizing by the num-frames, but Adam/Madam are insensitive to the total
+            # gradient scale so this should not matter.
+            # clip_grad_norm_(model.parameters(), 5.0, 2.0)
+            optimizer.step()
+        except RuntimeError as e:
+            print(f"Error on batch of shape (N,T) = {batch[0].shape}")
+            raise e

         loss_cpu = loss.detach().cpu().item()
         num_frames_cpu = batch[4].sum().cpu().item()
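The substantive change in this hunk wraps the forward pass, backward pass, and optimizer step in a try/except so that when a batch fails (typically a CUDA out-of-memory on an unusually long batch), its (N, T) shape is printed before the error propagates. A stripped-down sketch of the same pattern, factored into a function so it is self-contained; model, optimizer, and compute_loss stand in for the script's real objects:

def train_step(model, optimizer, compute_loss, batch):
    """One training step with shape reporting on failure (pattern from the diff)."""
    try:
        loss = compute_loss(model=model, batch=batch, is_training=True)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    except RuntimeError as e:
        # batch[0] is the (N, T) tensor; printing its shape usually reveals
        # whether an over-long batch caused the failure.
        print(f"Error on batch of shape (N,T) = {batch[0].shape}")
        raise e
    return loss

Re-raising keeps the original traceback while still logging the shape. The commented-out clip_grad_norm_ line is unrelated to this change: per the source comment, the loss is not normalized by the number of frames because Adam/Madam-style optimizers are insensitive to the overall gradient scale.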