Bug fixes

Daniel Povey 2022-10-04 22:34:24 +08:00
parent 5fe8cb134f
commit 81542832bf


@@ -342,7 +342,7 @@ class ConformerEncoder(nn.Module):
         # This ensures that if we are using multiple worker processes, they all use the same
         # random numbers, so they will all take about the same amount of time to process
         # the batch.
-        r = random.Random(self.count)
+        rng = random.Random(self.count)
         self.count += 1

         def get_random_mask():
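Seeding the generator from `self.count` is what makes the comment above hold: every worker process draws the same per-batch sequence, so their layer-drop decisions (and hence their per-batch compute) stay in step. A minimal standalone sketch of that property; `worker_a`, `worker_b` and the counter value are illustrative only:

import random

# Two workers seeded from the same batch counter draw identical sequences,
# so they make the same layer-drop choices for that batch.
count = 7  # stands in for self.count at some batch
worker_a = random.Random(count)
worker_b = random.Random(count)

draws_a = [worker_a.random() for _ in range(4)]
draws_b = [worker_b.random() for _ in range(4)]
assert draws_a == draws_b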
@@ -350,7 +350,7 @@ class ConformerEncoder(nn.Module):
             mask = torch.ones(num_layers, device='cpu')
             if not self.training:
                 return mask
-            r = r.random()
+            r = rng.random()
             if r < 0.1:
                 # drop zero layers, to make sure that sometimes we see the complete network.
                 return mask
@@ -358,15 +358,16 @@ class ConformerEncoder(nn.Module):
             if r < 0.1 + 0.25:
                 # with prob 0.25: completely drop the last n layers. let n
                 # be a multiple of 3 (this is what we used to do with aux_layers).
-                final_layers_dropped = 3 * r.randint(1, num_layers // 3)
+                final_layers_dropped = 3 * rng.randint(1, num_layers // 3)
                 mask[-final_layers_dropped:] = 0.0
                 layer_drop_prob = 0.075
                 for i in range(final_layers_dropped):
-                    mask[i] = (r.random() > layer_drop_prob)
+                    mask[i] = (rng.random() > layer_drop_prob)
             if mask.sum() == 0.0:
                 mask[0] = 1.0
             return mask

         mask = get_random_mask()
         device = self.to_layerdrop_scales[0].weight.device
         layerdrop_scales = 1.0 + self.to_layerdrop_scales(mask.to(device))
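The rename from `r` to `rng` is the substance of the fix: `get_random_mask` is a nested function, and rebinding `r` inside it with `r = r.random()` makes `r` local to that function, so the right-hand side refers to a variable that has no value yet and the call raises UnboundLocalError the first time the random path runs. A small standalone sketch of the broken and fixed patterns (function names here are illustrative only):

import random

def broken():
    r = random.Random(0)

    def inner():
        # Assigning to `r` makes it local to inner(), so the right-hand
        # side reads a local variable that has not been bound yet.
        r = r.random()
        return r

    return inner()  # raises UnboundLocalError

def fixed():
    rng = random.Random(0)

    def inner():
        # `rng` is only read, never rebound, so it resolves to the
        # enclosing scope; the drawn value gets its own name.
        r = rng.random()
        n = rng.randint(1, 4)
        return r, n

    return inner()

try:
    broken()
except UnboundLocalError as e:
    print(e)

print(fixed())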
@@ -1257,6 +1258,13 @@ def _test_conformer_main():
         warmup=0.5,
     )
     f  # to remove flake8 warnings
+    c.eval()
+    f = c(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+        warmup=0.5,
+    )
+    f  # to remove flake8 warnings