Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Bug fixes
This commit is contained in:
parent 5fe8cb134f
commit 81542832bf
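
The renamed variable is the substance of the fix: the old code created a random.Random instance named r, then rebound the same name to a float via r = r.random(), so the later call r.randint(1, num_layers // 3) could only fail. A minimal standalone reproduction of the failure mode (an illustrative sketch, not icefall code):

    import random

    r = random.Random(0)  # r is a random.Random instance
    r = r.random()        # rebinds r to a float, shadowing the generator
    # r.randint(1, 4)     # would raise AttributeError: 'float' object has no attribute 'randint'

Renaming the generator to rng, as the diff below does, keeps the float r and the generator rng from colliding.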
@@ -342,7 +342,7 @@ class ConformerEncoder(nn.Module):
         # This ensures that if we are using multiple worker processes, they all use the same
         # random numbers, so they will all take about the same amount of time to process
         # the batch.
-        r = random.Random(self.count)
+        rng = random.Random(self.count)
         self.count += 1

         def get_random_mask():
@@ -350,7 +350,7 @@ class ConformerEncoder(nn.Module):
             mask = torch.ones(num_layers, device='cpu')
             if self.training:
                 return mask
-            r = r.random()
+            r = rng.random()
             if r < 0.1:
                 # drop zero layers, to make sure that sometimes we see the complete network.
                 return mask
@@ -358,15 +358,16 @@ class ConformerEncoder(nn.Module):
             if r < 0.1 + 0.25:
                 # with prob 0.25: completely drop the last n layers. let n
                 # be a multiple of 3 (this is what we used to do with aux_layers).
-                final_layers_dropped = 3 * r.randint(1, num_layers // 3)
+                final_layers_dropped = 3 * rng.randint(1, num_layers // 3)
                 mask[-final_layers_dropped:] = 0.0

             layer_drop_prob = 0.075
             for i in range(final_layers_dropped):
-                mask[i] = (r.random() > layer_drop_prob)
+                mask[i] = (rng.random() > layer_drop_prob)
             if mask.sum() == 0.0:
                 mask[0] = 1.0
+            return mask

         mask = get_random_mask()
         device = self.to_layerdrop_scales[0].weight.device
         layerdrop_scales = 1.0 + self.to_layerdrop_scales(mask.to(device))
@@ -1257,6 +1258,13 @@ def _test_conformer_main():
         warmup=0.5,
     )
     f  # to remove flake8 warnings
+    c.eval()
+    f = c(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+        warmup=0.5,
+    )
+    f  # to remove flake8 warnings



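
Pulled out of the class for reference, the fixed mask-sampling logic reads roughly as below. This is a runnable sketch, not the actual icefall code: the free-function signature, the num_layers and count parameters, and the final_layers_dropped = 0 initialization are assumptions (the lines between the hunks fall outside the three-line diff context and are not visible here); the probabilities and structure follow the hunks above, and the self.training early return from the second hunk is omitted.

    import random

    import torch


    def get_random_mask(num_layers: int, count: int) -> torch.Tensor:
        # Seeding from a shared counter keeps dataloader worker processes in
        # sync, so they all take about the same time to process the batch.
        rng = random.Random(count)
        mask = torch.ones(num_layers, device='cpu')
        r = rng.random()
        if r < 0.1:
            # with prob 0.1: drop zero layers, so the complete network is sometimes seen
            return mask
        final_layers_dropped = 0  # assumed initialization, not visible in the diff
        if r < 0.1 + 0.25:
            # with prob 0.25: completely drop the last n layers, n a multiple of 3
            final_layers_dropped = 3 * rng.randint(1, num_layers // 3)
            mask[-final_layers_dropped:] = 0.0

        # as in the hunk above: keep each of the first final_layers_dropped
        # entries with prob 1 - layer_drop_prob
        layer_drop_prob = 0.075
        for i in range(final_layers_dropped):
            mask[i] = (rng.random() > layer_drop_prob)

        if mask.sum() == 0.0:
            mask[0] = 1.0  # never drop every layer
        return mask


    print(get_random_mask(num_layers=12, count=0))

The last hunk also re-runs the forward pass in the self-test after calling c.eval(), presumably to exercise the inference code path as well.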