mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add printing capability
This commit is contained in:
parent
4c8575878a
commit
b32dec1119
@ -502,7 +502,6 @@ class ActivationBalancer(torch.nn.Module):
|
||||
min_prob: float = 0.1,
|
||||
):
|
||||
super(ActivationBalancer, self).__init__()
|
||||
|
||||
# CAUTION: this code expects self.batch_count to be overwritten in the main training
|
||||
# loop.
|
||||
self.batch_count = 0
|
||||
@ -998,8 +997,9 @@ class ScheduledFloat(torch.nn.Module):
|
||||
def __init__(self,
|
||||
*args):
|
||||
super().__init__()
|
||||
# self.batch_count will be written to in the training loop.
|
||||
# self.batch_count and self.name will be written to in the training loop.
|
||||
self.batch_count = 0
|
||||
self.name = ''
|
||||
assert len(args) >= 1
|
||||
for (x,y) in args:
|
||||
assert x >= 0
|
||||
@ -1012,17 +1012,27 @@ class ScheduledFloat(torch.nn.Module):
|
||||
self.schedule)
|
||||
|
||||
def __float__(self):
|
||||
print_prob = 0.0001
|
||||
def maybe_print(ans):
|
||||
if random.random() < print_prob:
|
||||
logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
|
||||
batch_count = self.batch_count
|
||||
if batch_count <= self.schedule[0][0]:
|
||||
return self.schedule[0][1]
|
||||
ans = self.schedule[0][1]
|
||||
maybe_print(ans)
|
||||
return ans
|
||||
elif batch_count >= self.schedule[-1][0]:
|
||||
return self.schedule[-1][1]
|
||||
ans = self.schedule[-1][1]
|
||||
maybe_print(ans)
|
||||
return ans
|
||||
else:
|
||||
cur_x, cur_y = self.schedule[0]
|
||||
for i in range(1, len(self.schedule)):
|
||||
next_x, next_y = self.schedule[i]
|
||||
if batch_count >= cur_x and batch_count <= next_x:
|
||||
return cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
|
||||
ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
|
||||
maybe_print(ans)
|
||||
return ans
|
||||
cur_x, cur_y = next_x, next_y
|
||||
assert False
|
||||
|
||||
|
||||
@ -95,9 +95,11 @@ def set_batch_count(
|
||||
if isinstance(model, DDP):
|
||||
# get underlying nn.Module
|
||||
model = model.module
|
||||
for module in model.modules():
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, 'batch_count'):
|
||||
module.batch_count = batch_count
|
||||
if hasattr(module, 'name'):
|
||||
module.name = name
|
||||
|
||||
|
||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user