mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
This should just be a cosmetic change, regularizing how we get the warmup times from the layers.
This commit is contained in:
parent
7d8e460a53
commit
ae6478c687
@ -436,16 +436,14 @@ class ConformerEncoder(nn.Module):
|
||||
def get_layers_to_drop(self, rnd_seed: int, warmup_count: float):
|
||||
|
||||
num_layers = len(self.layers)
|
||||
warmup_begin = self.warmup_begin
|
||||
warmup_end = self.warmup_end
|
||||
|
||||
def get_layerdrop_prob(layer: int) -> float:
|
||||
layer_warmup_delta = (warmup_end - warmup_begin) / num_layers
|
||||
layer_warmup_begin = warmup_begin + layer * layer_warmup_delta
|
||||
layer_warmup_begin = self.layers[layer].warmup_begin
|
||||
layer_warmup_end = self.layers[layer].warmup_end
|
||||
|
||||
initial_layerdrop_prob = 0.5
|
||||
final_layerdrop_prob = 0.05
|
||||
|
||||
layer_warmup_end = layer_warmup_begin + layer_warmup_delta
|
||||
if warmup_count < layer_warmup_begin:
|
||||
return initial_layerdrop_prob
|
||||
elif warmup_count > layer_warmup_end:
|
||||
@ -483,7 +481,7 @@ class ConformerEncoder(nn.Module):
|
||||
if len(ans) == num_to_drop:
|
||||
break
|
||||
if shared_rng.random() < 0.005 or __name__ == "__main__":
|
||||
logging.info(f"warmup_begin={warmup_begin:.1f}, warmup_end={warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
|
||||
logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
|
||||
return ans
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user