Add printing capability

This commit is contained in:
Daniel Povey 2022-11-14 14:14:01 +08:00
parent 4c8575878a
commit b32dec1119
2 changed files with 18 additions and 6 deletions

View File

@ -502,7 +502,6 @@ class ActivationBalancer(torch.nn.Module):
min_prob: float = 0.1, min_prob: float = 0.1,
): ):
super(ActivationBalancer, self).__init__() super(ActivationBalancer, self).__init__()
# CAUTION: this code expects self.batch_count to be overwritten in the main training # CAUTION: this code expects self.batch_count to be overwritten in the main training
# loop. # loop.
self.batch_count = 0 self.batch_count = 0
@ -998,8 +997,9 @@ class ScheduledFloat(torch.nn.Module):
def __init__(self, def __init__(self,
*args): *args):
super().__init__() super().__init__()
# self.batch_count will be written to in the training loop. # self.batch_count and self.name will be written to in the training loop.
self.batch_count = 0 self.batch_count = 0
self.name = ''
assert len(args) >= 1 assert len(args) >= 1
for (x,y) in args: for (x,y) in args:
assert x >= 0 assert x >= 0
@ -1012,17 +1012,27 @@ class ScheduledFloat(torch.nn.Module):
self.schedule) self.schedule)
def __float__(self): def __float__(self):
print_prob = 0.0001
def maybe_print(ans):
if random.random() < print_prob:
logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
batch_count = self.batch_count batch_count = self.batch_count
if batch_count <= self.schedule[0][0]: if batch_count <= self.schedule[0][0]:
return self.schedule[0][1] ans = self.schedule[0][1]
maybe_print(ans)
return ans
elif batch_count >= self.schedule[-1][0]: elif batch_count >= self.schedule[-1][0]:
return self.schedule[-1][1] ans = self.schedule[-1][1]
maybe_print(ans)
return ans
else: else:
cur_x, cur_y = self.schedule[0] cur_x, cur_y = self.schedule[0]
for i in range(1, len(self.schedule)): for i in range(1, len(self.schedule)):
next_x, next_y = self.schedule[i] next_x, next_y = self.schedule[i]
if batch_count >= cur_x and batch_count <= next_x: if batch_count >= cur_x and batch_count <= next_x:
return cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x) ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
maybe_print(ans)
return ans
cur_x, cur_y = next_x, next_y cur_x, cur_y = next_x, next_y
assert False assert False

View File

@ -95,9 +95,11 @@ def set_batch_count(
if isinstance(model, DDP): if isinstance(model, DDP):
# get underlying nn.Module # get underlying nn.Module
model = model.module model = model.module
for module in model.modules(): for name, module in model.named_modules():
if hasattr(module, 'batch_count'): if hasattr(module, 'batch_count'):
module.batch_count = batch_count module.batch_count = batch_count
if hasattr(module, 'name'):
module.name = name
def add_model_arguments(parser: argparse.ArgumentParser): def add_model_arguments(parser: argparse.ArgumentParser):