Introduce minimum probs in the SimpleCombiner

This commit is contained in:
Daniel Povey 2022-10-31 17:02:21 +08:00
parent efbb1d25c7
commit 5e51534fbc

View File

@ -178,7 +178,8 @@ class Zipformer(EncoderInterface):
f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.")
skip_layers.append(j)
skip_modules.append(SimpleCombiner(self.encoder_dims[j],
self.encoder_dims[i-1]))
self.encoder_dims[i-1],
min_weight=(0.0,0.25)))
break
self.skip_layers = skip_layers
self.skip_modules = nn.ModuleList(skip_modules)
@ -620,7 +621,8 @@ class DownsampledZipformerEncoder(nn.Module):
self.encoder = encoder
self.upsample = SimpleUpsample(output_dim, downsample)
self.out_combiner = SimpleCombiner(input_dim,
output_dim)
output_dim,
min_weight=(0.0, 0.25))
def forward(self,
@ -814,17 +816,18 @@ class SimpleCombiner(torch.nn.Module):
learned weighted combination in the shared part of the dim.
Args:
dim1: the dimension of the first input, e.g. 256
dim2: the dimension of the second input, e.g. 384. Require dim2 >= dim1.
dim2: the dimension of the second input, e.g. 384.
The output will have the same dimension as dim2.
"""
def __init__(self,
dim1: int,
dim2: int):
dim2: int,
min_weight: Tuple[float] = (0., 0.)):
super(SimpleCombiner, self).__init__()
assert dim2 >= dim1
self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
self.min_weight = min_weight
def forward(self,
src1: Tensor,
@ -844,15 +847,36 @@ class SimpleCombiner(torch.nn.Module):
logit = (weight1 + weight2)
if self.training and random.random() < 0.1:
logit = penalize_abs_values_gt(logit,
limit=25.0,
penalty=1.0e-04)
logit = penalize_abs_values_gt(logit,
limit=25.0,
penalty=1.0e-04)
# `weight` will be the weight on src1.
weight = logit.sigmoid()
src2_part1 = src2[...,:dim1]
part1 = src1 * weight + src2_part1 * (1.0 - weight)
part2 = src2[...,dim1:]
return torch.cat((part1, part2), dim=-1)
if self.training and self.min_weight != (0., 0.) and random.random() < 0.25:
weight = weight.clamp(min=self.min_weight[0],
max=1.0-self.min_weight[1])
src1 = src1 * weight
src2 = src2 * (1.0 - weight)
src1_dim = src1.shape[-1]
src2_dim = src2.shape[-1]
if src1_dim != src2_dim:
if src1_dim < src2_dim:
zeros_shape = list(src1.shape[:-1]) + [src2_dim - src1_dim]
src1 = torch.cat((src1, torch.zeros(*zeros_shape,
device=src1.device,
dtype=src1.dtype)),
dim=-1)
else:
src1 = src1[:src2_dim]
return src1 + src2