diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
index fa9d47c35..1870818eb 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py
@@ -502,7 +502,6 @@ class ActivationBalancer(torch.nn.Module):
         min_prob: float = 0.1,
     ):
         super(ActivationBalancer, self).__init__()
-
         # CAUTION: this code expects self.batch_count to be overwritten in the main training
         # loop.
         self.batch_count = 0
@@ -998,8 +997,9 @@ class ScheduledFloat(torch.nn.Module):
 
     def __init__(self, *args):
         super().__init__()
-        # self.batch_count will be written to in the training loop.
+        # self.batch_count and self.name will be written to in the training loop.
         self.batch_count = 0
+        self.name = ''
         assert len(args) >= 1
         for (x,y) in args:
             assert x >= 0
@@ -1012,17 +1012,27 @@ class ScheduledFloat(torch.nn.Module):
                 self.schedule)
 
     def __float__(self):
+        print_prob = 0.0001
+        def maybe_print(ans):
+            if random.random() < print_prob:
+                logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
         batch_count = self.batch_count
         if batch_count <= self.schedule[0][0]:
-            return self.schedule[0][1]
+            ans = self.schedule[0][1]
+            maybe_print(ans)
+            return ans
         elif batch_count >= self.schedule[-1][0]:
-            return self.schedule[-1][1]
+            ans = self.schedule[-1][1]
+            maybe_print(ans)
+            return ans
         else:
             cur_x, cur_y = self.schedule[0]
             for i in range(1, len(self.schedule)):
                 next_x, next_y = self.schedule[i]
                 if batch_count >= cur_x and batch_count <= next_x:
-                    return cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
+                    ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
+                    maybe_print(ans)
+                    return ans
                 cur_x, cur_y = next_x, next_y
             assert False
 
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 47be7f4f3..138cba3db 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -95,9 +95,11 @@ def set_batch_count(
     if isinstance(model, DDP):
         # get underlying nn.Module
         model = model.module
-    for module in model.modules():
+    for name, module in model.named_modules():
         if hasattr(module, 'batch_count'):
             module.batch_count = batch_count
+        if hasattr(module, 'name'):
+            module.name = name
 
 
 def add_model_arguments(parser: argparse.ArgumentParser):
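
For context, the sketch below shows how the two halves of this change are intended to meet: a ScheduledFloat-style module whose __float__ interpolates piecewise-linearly between (batch_count, value) points, and a set_batch_count() that walks named_modules() so each such submodule also learns its own name (used only by the occasional log line). Everything here is a standalone illustration, not code from the patch: ToySchedule, ToyModel, dropout_rate, and the (0, 1.0), (20000, 0.1) schedule are hypothetical stand-ins.

# Standalone illustration only -- class/attribute names below are hypothetical.
import torch


class ToySchedule(torch.nn.Module):
    """Piecewise-linear function of batch_count, mirroring ScheduledFloat.__float__."""

    def __init__(self, *args):
        super().__init__()
        # Both attributes are expected to be overwritten from the training
        # loop via set_batch_count(), exactly as in the patch.
        self.batch_count = 0
        self.name = ''
        self.schedule = list(args)  # list of (batch_count, value) pairs

    def __float__(self) -> float:
        x = self.batch_count
        if x <= self.schedule[0][0]:
            return float(self.schedule[0][1])
        if x >= self.schedule[-1][0]:
            return float(self.schedule[-1][1])
        # Linear interpolation between the two surrounding schedule points.
        for (x0, y0), (x1, y1) in zip(self.schedule, self.schedule[1:]):
            if x0 <= x <= x1:
                return y0 + (y1 - y0) * (x - x0) / (x1 - x0)
        raise AssertionError("schedule points must be sorted by batch_count")


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Value decays linearly from 1.0 at batch 0 to 0.1 at batch 20000.
        self.dropout_rate = ToySchedule((0, 1.0), (20000, 0.1))


def set_batch_count(model: torch.nn.Module, batch_count: float) -> None:
    # Same idea as the train.py change: named_modules() yields (name, module)
    # pairs, so each schedule also learns which submodule it belongs to.
    for name, module in model.named_modules():
        if hasattr(module, 'batch_count'):
            module.batch_count = batch_count
        if hasattr(module, 'name'):
            module.name = name


model = ToyModel()
set_batch_count(model, 10000)
print(model.dropout_rate.name)    # dropout_rate
print(float(model.dropout_rate))  # ~0.55, halfway between 1.0 and 0.1

Wiring the name in via named_modules() is what makes the sampled logging.info call in __float__ useful: without it, the occasional log line would report a value with no indication of which schedule produced it.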