Mirror of https://github.com/k2-fsa/icefall.git
Fix bug where fewer layers were dropped than should be; remove unnecessary print statement.
This commit is contained in:
parent 09c9b02f6f
commit 857b3735e7
@@ -18,6 +18,7 @@
 import copy
 import math
 import warnings
+import itertools
 from typing import List, Optional, Tuple, Union
 import logging
 import torch
@@ -473,11 +474,10 @@ class ConformerEncoder(nn.Module):
 
         layers = list(range(num_layers))
         independent_rng.shuffle(layers)
-        # go through the shuffled layers twice, in case, the first time round,
-        # we did not drop out the target number of layers.
-        layers = layers + layers
-        for layer in layers:
-            if independent_rng.random() < get_layerdrop_prob(layer):
+        # go through the shuffled layers until we get the required number of samples.
+        for layer in itertools.cycle(layers):
+            if independent_rng.random() < layerdrop_probs[layer]:
                 ans.add(layer)
                 if len(ans) == num_to_drop:
                     break
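For context, a minimal self-contained sketch of the corrected sampling logic. The names num_layers, layerdrop_probs, num_to_drop, and the use of itertools.cycle follow the diff; pick_layers_to_drop and the plain random.Random RNG are illustrative stand-ins for the ConformerEncoder internals, not the actual icefall code.

import itertools
import random

def pick_layers_to_drop(num_layers, layerdrop_probs, num_to_drop, seed=0):
    # Use an independent RNG so layer selection does not disturb other random state.
    rng = random.Random(seed)

    layers = list(range(num_layers))
    rng.shuffle(layers)

    ans = set()
    if num_to_drop <= 0:
        return ans
    # Cycle over the shuffled layers until the target count is reached.
    # A single (or doubled) pass can fall short when the per-layer drop
    # probabilities are small, which is the bug this commit fixes.
    # Assumes at least one layerdrop probability is positive.
    for layer in itertools.cycle(layers):
        if rng.random() < layerdrop_probs[layer]:
            ans.add(layer)
            if len(ans) == num_to_drop:
                break
    return ans

# Example: drop exactly 2 of 6 layers, each with drop probability 0.1.
print(pick_layers_to_drop(6, [0.1] * 6, 2))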
@@ -323,7 +323,6 @@ class ScaledAdam(BatchedOptimizer):
         first_state["model_norms"][step % clipping_update_period] = tot_norm
 
         if step % clipping_update_period == 0:
-            print(f"step = {step}")
             # Print some stats.
             # We don't reach here if step == 0 because we would have returned
             # above.
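For illustration only, a sketch of the periodic-stats pattern that the deleted print duplicated. The buffer name model_norms and the clipping_update_period check mirror the diff; record_model_norm, the period value, and the logging scaffolding are hypothetical stand-ins rather than the actual ScaledAdam implementation.

import logging

clipping_update_period = 100  # illustrative period; the real value comes from ScaledAdam's arguments

def record_model_norm(first_state, step, tot_norm):
    # Keep recent norms in a circular buffer of length clipping_update_period,
    # as the first context line of the diff does.
    first_state["model_norms"][step % clipping_update_period] = tot_norm

    # Summarise once per period. The removed print(f"step = {step}") was a
    # leftover debug line whose information is already part of this summary.
    if step % clipping_update_period == 0 and step > 0:
        norms = sorted(first_state["model_norms"])
        logging.info(
            "step=%d, median model norm=%.3g, max model norm=%.3g",
            step, norms[len(norms) // 2], norms[-1],
        )

# Example usage: a zero-initialised buffer, then a few hundred updates.
logging.basicConfig(level=logging.INFO)
state = {"model_norms": [0.0] * clipping_update_period}
for s in range(1, 201):
    record_model_norm(state, s, tot_norm=1.0 + 0.01 * s)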