diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index 1870818eb..2f46be5f1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -992,14 +992,19 @@ class ScheduledFloat(torch.nn.Module):
     first x or after the last x, we just use the first or last y value.
 
     Example:
-       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0))
+       self.dropout = ScheduledFloat((0.0, 0.2), (4000.0, 0.0), default=0.0)
+
+    `default` is used when self.batch_count is not set, or when not in training mode,
+    or in torch.jit scripting mode.
     """
     def __init__(self,
-                 *args):
+                 *args,
+                 default: float = 0.0):
         super().__init__()
         # self.batch_count and self.name will be written to in the training loop.
-        self.batch_count = 0
-        self.name = ''
+        self.batch_count = None
+        self.name = None
+        self.default = default
         assert len(args) >= 1
         for (x,y) in args:
             assert x >= 0
@@ -1012,19 +1017,21 @@ class ScheduledFloat(torch.nn.Module):
                                                              self.schedule)
 
     def __float__(self):
-        print_prob = 0.0001
+        print_prob = 0.0002
         def maybe_print(ans):
             if random.random() < print_prob:
                 logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
         batch_count = self.batch_count
+        if batch_count is None or not self.training or torch.jit.is_scripting():
+            return float(self.default)
         if batch_count <= self.schedule[0][0]:
             ans = self.schedule[0][1]
             maybe_print(ans)
-            return ans
+            return float(ans)
         elif batch_count >= self.schedule[-1][0]:
             ans = self.schedule[-1][1]
             maybe_print(ans)
-            return ans
+            return float(ans)
         else:
             cur_x, cur_y = self.schedule[0]
             for i in range(1, len(self.schedule)):
@@ -1032,7 +1039,7 @@ class ScheduledFloat(torch.nn.Module):
                 if batch_count >= cur_x and batch_count <= next_x:
                     ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
                     maybe_print(ans)
-                    return ans
+                    return float(ans)
                 cur_x, cur_y = next_x, next_y
             assert False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index f42022498..313fdf56e 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -365,9 +365,9 @@ class ZipformerEncoderLayer(nn.Module):
         # layer_skip_prob will be overwritten to change warmup begin and end times.
         # treating batch_index == 0.0 specially is just to get scan_pessimistic_batches_for_oom()
         # to work correctly.
-        layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.0), (1.0, 0.5), (2000.0, 0.05)),
-        dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.0), (1.0, 0.2), (2000.0, 0.0)),
-        bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25)),
+        layer_skip_prob: FloatLike = ScheduledFloat((0.0, 0.5), (2000.0, 0.05), default=0),
+        dynamic_skip_prob: FloatLike = ScheduledFloat((0.0, 0.2), (2000.0, 0.0), default=0),
+        bypass_clamp_min: FloatLike = ScheduledFloat((0.0, 0.75), (20000.0, 0.25), default=0),
         bypass_clamp_max: FloatLike = 1.0,
     ) -> None:
         super(ZipformerEncoderLayer, self).__init__()
@@ -568,10 +568,9 @@ class ZipformerEncoder(nn.Module):
 
         for i in range(num_layers):
             cur_end = cur_begin + delta
             # treating batch_index=0.0 specially is just to get scan_pessimistic_batches_for_oom()
-            self.layers[i].layer_skip_prob = ScheduledFloat((0.0, 0.0),
-                                                            (1.0, initial_layerdrop_prob),
-                                                            (cur_begin, initial_layerdrop_prob),
-                                                            (cur_end, final_layerdrop_prob))
+            self.layers[i].layer_skip_prob = ScheduledFloat((cur_begin, initial_layerdrop_prob),
+                                                            (cur_end, final_layerdrop_prob),
+                                                            default=0.0)
             cur_begin = cur_end
 
@@ -1028,7 +1027,7 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                  pos_head_dim: int,
                  dropout: float = 0.0,
                  pos_emb_skip: FloatLike = ScheduledFloat((0.0, 0.5),
-                                                          (4000.0, 0.025))
+                                                          (4000.0, 0.025), default=0.0)
                  ) -> None:
         super().__init__()
         self.embed_dim = embed_dim
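
To make the new semantics concrete, here is a minimal standalone sketch of the piecewise-linear lookup that ScheduledFloat.__float__ performs after this patch. The helper name scheduled_value and the example values are illustrative only, and the torch.jit.is_scripting() branch is folded into a plain training flag; this sketches the behavior rather than reproducing the patch's code:

# Standalone sketch of ScheduledFloat's post-patch lookup; illustrative only.
from typing import List, Optional, Tuple


def scheduled_value(schedule: List[Tuple[float, float]],
                    batch_count: Optional[float],
                    default: float = 0.0,
                    training: bool = True) -> float:
    # Post-patch behavior: fall back to `default` when batch_count is unset
    # or when not in training mode (the patch also checks torch.jit.is_scripting()).
    if batch_count is None or not training:
        return float(default)
    # Clamp to the first/last y value outside the schedule's x range.
    if batch_count <= schedule[0][0]:
        return float(schedule[0][1])
    if batch_count >= schedule[-1][0]:
        return float(schedule[-1][1])
    # Linear interpolation between adjacent (x, y) points.
    for (cur_x, cur_y), (next_x, next_y) in zip(schedule, schedule[1:]):
        if cur_x <= batch_count <= next_x:
            return float(cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x))
    raise AssertionError("unreachable: schedule x values are sorted")


# E.g. the new layer_skip_prob schedule ((0.0, 0.5), (2000.0, 0.05), default=0):
sched = [(0.0, 0.5), (2000.0, 0.05)]
assert scheduled_value(sched, 0.0) == 0.5                      # at the first x
assert abs(scheduled_value(sched, 1000.0) - 0.275) < 1e-12     # midpoint, interpolated
assert scheduled_value(sched, 5000.0) == 0.05                  # past the last x
assert scheduled_value(sched, None, default=0) == 0.0          # batch_count unset
assert scheduled_value(sched, 1000.0, training=False) == 0.0   # eval mode

Wrapping every return in float() mirrors the patch's change from returning schedule entries directly, and guarantees callers always get a Python float even when a schedule point or default is given as an int (e.g. default=0 in zipformer.py above).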