mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Bug fixes
parent 5fe8cb134f
commit 81542832bf
@@ -342,7 +342,7 @@ class ConformerEncoder(nn.Module):
         # This ensures that if we are using multiple worker processes, they all use the same
         # random numbers, so they will all take about the same amount of time to process
         # the batch.
-        r = random.Random(self.count)
+        rng = random.Random(self.count)
         self.count += 1

         def get_random_mask():
@@ -350,7 +350,7 @@ class ConformerEncoder(nn.Module):
             mask = torch.ones(num_layers, device='cpu')
             if self.training:
                 return mask
-            r = r.random()
+            r = rng.random()
             if r < 0.1:
                 # drop zero layers, to make sure that sometimes we see the complete network.
                 return mask
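The r → rng rename above is the actual bug fix: get_random_mask is a nested function, so the assignment r = r.random() makes r local to it, and the read on the right-hand side raises UnboundLocalError before the outer random.Random object is ever reached. A minimal standalone sketch of the failure and the fix (illustrative functions, not the icefall code itself):

import random

def broken():
    r = random.Random(0)

    def inner():
        # Assigning to `r` here makes `r` local to `inner`, so reading it on
        # the right-hand side raises UnboundLocalError.
        r = r.random()
        return r

    return inner()

def fixed():
    rng = random.Random(0)

    def inner():
        # `rng` is only read, never assigned, so it resolves to the enclosing
        # scope; `r` is an ordinary local float.
        r = rng.random()
        return r

    return inner()

print(fixed())  # prints a float in [0, 1); broken() would raise instead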
@@ -358,15 +358,16 @@ class ConformerEncoder(nn.Module):
             if r < 0.1 + 0.25:
                 # with prob 0.25: completely drop the last n layers. let n
                 # be a multiple of 3 (this is what we used to do with aux_layers).
-                final_layers_dropped = 3 * r.randint(1, num_layers // 3)
+                final_layers_dropped = 3 * rng.randint(1, num_layers // 3)
                 mask[-final_layers_dropped:] = 0.0
+
                 layer_drop_prob = 0.075
                 for i in range(final_layers_dropped):
-                    mask[i] = (r.random() > layer_drop_prob)
+                    mask[i] = (rng.random() > layer_drop_prob)
             if mask.sum() == 0.0:
                 mask[0] = 1.0
             return mask

         mask = get_random_mask()
         device = self.to_layerdrop_scales[0].weight.device
         layerdrop_scales = 1.0 + self.to_layerdrop_scales(mask.to(device))
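For context, this hunk samples a per-layer keep/drop mask and feeds it to to_layerdrop_scales to obtain learned per-layer scales. Below is a rough, self-contained sketch of the sampling scheme; the name sample_layer_mask is mine, and the final branch is simplified to per-layer dropout over all layers so the sketch runs on its own (the icefall control flow differs slightly):

import random

import torch


def sample_layer_mask(num_layers: int, seed: int) -> torch.Tensor:
    """Illustrative stand-in for get_random_mask(); not the icefall code.

    Returns a (num_layers,) tensor of keep flags:
      * with prob 0.10, keep every layer (the complete network is seen),
      * with prob 0.25, drop the last n layers, n a multiple of 3,
      * otherwise, drop each layer independently with prob 0.075,
    always keeping at least one layer.
    """
    rng = random.Random(seed)
    mask = torch.ones(num_layers)
    r = rng.random()
    if r < 0.1:
        return mask
    if r < 0.1 + 0.25:
        final_layers_dropped = 3 * rng.randint(1, num_layers // 3)
        mask[-final_layers_dropped:] = 0.0
    else:
        layer_drop_prob = 0.075
        for i in range(num_layers):
            mask[i] = float(rng.random() > layer_drop_prob)
    if mask.sum() == 0.0:
        mask[0] = 1.0
    return mask


print(sample_layer_mask(num_layers=12, seed=1))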
@@ -1257,6 +1258,13 @@ def _test_conformer_main():
         warmup=0.5,
     )
     f  # to remove flake8 warnings
+    c.eval()
+    f = c(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+        warmup=0.5,
+    )
+    f  # to remove flake8 warnings
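The last hunk makes _test_conformer_main run the forward pass again after switching the model to eval mode, so that mode-dependent code paths are exercised as well. The general pattern, sketched here with a stand-in module rather than the real Conformer, whose construction is not shown in this diff:

import torch
import torch.nn as nn

# Stand-in for the Conformer built in _test_conformer_main().
model = nn.Sequential(nn.Linear(80, 256), nn.ReLU(), nn.Linear(256, 80))
x = torch.randn(5, 20, 80)  # (batch, time, feature)

model.train()
y_train = model(x)          # forward in training mode

model.eval()
with torch.no_grad():
    y_eval = model(x)       # forward in eval mode, as the added test lines do

assert y_train.shape == y_eval.shape == (5, 20, 80)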