mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
This should just be a cosmetic change, regularizing how we get the warmup times from the layers.
This commit is contained in:
parent
7d8e460a53
commit
ae6478c687
@ -436,16 +436,14 @@ class ConformerEncoder(nn.Module):
|
|||||||
def get_layers_to_drop(self, rnd_seed: int, warmup_count: float):
|
def get_layers_to_drop(self, rnd_seed: int, warmup_count: float):
|
||||||
|
|
||||||
num_layers = len(self.layers)
|
num_layers = len(self.layers)
|
||||||
warmup_begin = self.warmup_begin
|
|
||||||
warmup_end = self.warmup_end
|
|
||||||
|
|
||||||
def get_layerdrop_prob(layer: int) -> float:
|
def get_layerdrop_prob(layer: int) -> float:
|
||||||
layer_warmup_delta = (warmup_end - warmup_begin) / num_layers
|
layer_warmup_begin = self.layers[layer].warmup_begin
|
||||||
layer_warmup_begin = warmup_begin + layer * layer_warmup_delta
|
layer_warmup_end = self.layers[layer].warmup_end
|
||||||
|
|
||||||
initial_layerdrop_prob = 0.5
|
initial_layerdrop_prob = 0.5
|
||||||
final_layerdrop_prob = 0.05
|
final_layerdrop_prob = 0.05
|
||||||
|
|
||||||
layer_warmup_end = layer_warmup_begin + layer_warmup_delta
|
|
||||||
if warmup_count < layer_warmup_begin:
|
if warmup_count < layer_warmup_begin:
|
||||||
return initial_layerdrop_prob
|
return initial_layerdrop_prob
|
||||||
elif warmup_count > layer_warmup_end:
|
elif warmup_count > layer_warmup_end:
|
||||||
@ -483,7 +481,7 @@ class ConformerEncoder(nn.Module):
|
|||||||
if len(ans) == num_to_drop:
|
if len(ans) == num_to_drop:
|
||||||
break
|
break
|
||||||
if shared_rng.random() < 0.005 or __name__ == "__main__":
|
if shared_rng.random() < 0.005 or __name__ == "__main__":
|
||||||
logging.info(f"warmup_begin={warmup_begin:.1f}, warmup_end={warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
|
logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, warmup_count={warmup_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}")
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user