mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add printing capability
This commit is contained in:
parent
4c8575878a
commit
b32dec1119
@ -502,7 +502,6 @@ class ActivationBalancer(torch.nn.Module):
|
|||||||
min_prob: float = 0.1,
|
min_prob: float = 0.1,
|
||||||
):
|
):
|
||||||
super(ActivationBalancer, self).__init__()
|
super(ActivationBalancer, self).__init__()
|
||||||
|
|
||||||
# CAUTION: this code expects self.batch_count to be overwritten in the main training
|
# CAUTION: this code expects self.batch_count to be overwritten in the main training
|
||||||
# loop.
|
# loop.
|
||||||
self.batch_count = 0
|
self.batch_count = 0
|
||||||
@ -998,8 +997,9 @@ class ScheduledFloat(torch.nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
*args):
|
*args):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# self.batch_count will be written to in the training loop.
|
# self.batch_count and self.name will be written to in the training loop.
|
||||||
self.batch_count = 0
|
self.batch_count = 0
|
||||||
|
self.name = ''
|
||||||
assert len(args) >= 1
|
assert len(args) >= 1
|
||||||
for (x,y) in args:
|
for (x,y) in args:
|
||||||
assert x >= 0
|
assert x >= 0
|
||||||
@ -1012,17 +1012,27 @@ class ScheduledFloat(torch.nn.Module):
|
|||||||
self.schedule)
|
self.schedule)
|
||||||
|
|
||||||
def __float__(self):
|
def __float__(self):
|
||||||
|
print_prob = 0.0001
|
||||||
|
def maybe_print(ans):
|
||||||
|
if random.random() < print_prob:
|
||||||
|
logging.info(f"ScheduledFloat: name={self.name}, batch_count={self.batch_count}, ans={ans}")
|
||||||
batch_count = self.batch_count
|
batch_count = self.batch_count
|
||||||
if batch_count <= self.schedule[0][0]:
|
if batch_count <= self.schedule[0][0]:
|
||||||
return self.schedule[0][1]
|
ans = self.schedule[0][1]
|
||||||
|
maybe_print(ans)
|
||||||
|
return ans
|
||||||
elif batch_count >= self.schedule[-1][0]:
|
elif batch_count >= self.schedule[-1][0]:
|
||||||
return self.schedule[-1][1]
|
ans = self.schedule[-1][1]
|
||||||
|
maybe_print(ans)
|
||||||
|
return ans
|
||||||
else:
|
else:
|
||||||
cur_x, cur_y = self.schedule[0]
|
cur_x, cur_y = self.schedule[0]
|
||||||
for i in range(1, len(self.schedule)):
|
for i in range(1, len(self.schedule)):
|
||||||
next_x, next_y = self.schedule[i]
|
next_x, next_y = self.schedule[i]
|
||||||
if batch_count >= cur_x and batch_count <= next_x:
|
if batch_count >= cur_x and batch_count <= next_x:
|
||||||
return cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
|
ans = cur_y + (next_y - cur_y) * (batch_count - cur_x) / (next_x - cur_x)
|
||||||
|
maybe_print(ans)
|
||||||
|
return ans
|
||||||
cur_x, cur_y = next_x, next_y
|
cur_x, cur_y = next_x, next_y
|
||||||
assert False
|
assert False
|
||||||
|
|
||||||
|
|||||||
@ -95,9 +95,11 @@ def set_batch_count(
|
|||||||
if isinstance(model, DDP):
|
if isinstance(model, DDP):
|
||||||
# get underlying nn.Module
|
# get underlying nn.Module
|
||||||
model = model.module
|
model = model.module
|
||||||
for module in model.modules():
|
for name, module in model.named_modules():
|
||||||
if hasattr(module, 'batch_count'):
|
if hasattr(module, 'batch_count'):
|
||||||
module.batch_count = batch_count
|
module.batch_count = batch_count
|
||||||
|
if hasattr(module, 'name'):
|
||||||
|
module.name = name
|
||||||
|
|
||||||
|
|
||||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user