Add printing capability

This commit is contained in:
Daniel Povey 2022-11-14 14:14:01 +08:00
parent 4c8575878a
commit b32dec1119
2 changed files with 18 additions and 6 deletions

View File

@ -502,7 +502,6 @@ class ActivationBalancer(torch.nn.Module):
min_prob: float = 0.1,
):
super(ActivationBalancer, self).__init__()
# CAUTION: this code expects self.batch_count to be overwritten in the main training
# loop.
self.batch_count = 0
@ -998,8 +997,9 @@ class ScheduledFloat(torch.nn.Module):
def __init__(self,
*args):
super().__init__()
# self.batch_count will be written to in the training loop.
# self.batch_count and self.name will be written to in the training loop.
self.batch_count = 0
self.name = ''
assert len(args) >= 1
for (x,y) in args:
assert x >= 0
@ -1012,17 +1012,27 @@ class ScheduledFloat(torch.nn.Module):
self.schedule)
def __float__(self):
    """Return the scheduled value for the current ``self.batch_count``.

    ``self.schedule`` is read as a sequence of ``(x, y)`` pairs.  At or
    before the first ``x`` the first ``y`` is returned; at or after the
    last ``x`` the last ``y`` is returned; in between, the result is the
    linear interpolation of the segment that contains ``batch_count``.

    With a small probability per call, the result is also logged together
    with ``self.name`` and ``self.batch_count``.

    NOTE(review): assumes the ``x`` values in ``self.schedule`` are in
    ascending order; the check for that is not visible in this excerpt.
    """
    # Log only a tiny random sample of calls so the log is not flooded.
    print_prob = 0.0001

    def maybe_print(ans):
        # Draws exactly one random number per call to __float__.
        if random.random() < print_prob:
            logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")

    batch_count = self.batch_count
    first_x, first_y = self.schedule[0]
    last_x, last_y = self.schedule[-1]

    if batch_count <= first_x:
        # Clamp to the first value at or before the start of the schedule.
        ans = first_y
        maybe_print(ans)
        return ans
    if batch_count >= last_x:
        # Clamp to the last value at or after the end of the schedule.
        ans = last_y
        maybe_print(ans)
        return ans

    # Strictly inside the schedule: walk consecutive breakpoints and
    # interpolate within the first segment that contains batch_count.
    cur_x, cur_y = first_x, first_y
    for next_x, next_y in self.schedule[1:]:
        if cur_x <= batch_count <= next_x:
            ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
            maybe_print(ans)
            return ans
        cur_x, cur_y = next_x, next_y

    # Some segment should always match for an ordered schedule.  Raise
    # explicitly so the check survives `python -O`.
    raise AssertionError(
        f"ScheduledFloat: no schedule segment contains batch_count={batch_count}"
    )

View File

@ -95,9 +95,11 @@ def set_batch_count(
if isinstance(model, DDP):
# get underlying nn.Module
model = model.module
for module in model.modules():
for name, module in model.named_modules():
if hasattr(module, 'batch_count'):
module.batch_count = batch_count
if hasattr(module, 'name'):
module.name = name
def add_model_arguments(parser: argparse.ArgumentParser):