mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Introduce minimum probs in the SimpleCombiner
This commit is contained in:
parent
efbb1d25c7
commit
5e51534fbc
@ -178,7 +178,8 @@ class Zipformer(EncoderInterface):
|
||||
f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.")
|
||||
skip_layers.append(j)
|
||||
skip_modules.append(SimpleCombiner(self.encoder_dims[j],
|
||||
self.encoder_dims[i-1]))
|
||||
self.encoder_dims[i-1],
|
||||
min_weight=(0.0,0.25)))
|
||||
break
|
||||
self.skip_layers = skip_layers
|
||||
self.skip_modules = nn.ModuleList(skip_modules)
|
||||
@ -620,7 +621,8 @@ class DownsampledZipformerEncoder(nn.Module):
|
||||
self.encoder = encoder
|
||||
self.upsample = SimpleUpsample(output_dim, downsample)
|
||||
self.out_combiner = SimpleCombiner(input_dim,
|
||||
output_dim)
|
||||
output_dim,
|
||||
min_weight=(0.0, 0.25))
|
||||
|
||||
|
||||
def forward(self,
|
||||
@ -814,17 +816,18 @@ class SimpleCombiner(torch.nn.Module):
|
||||
learned weighted combination in the shared part of the dim.
|
||||
Args:
|
||||
dim1: the dimension of the first input, e.g. 256
|
||||
dim2: the dimension of the second input, e.g. 384. Require dim2 >= dim1.
|
||||
dim2: the dimension of the second input, e.g. 384.
|
||||
The output will have the same dimension as dim2.
|
||||
"""
|
||||
def __init__(self,
|
||||
dim1: int,
|
||||
dim2: int):
|
||||
dim2: int,
|
||||
min_weight: Tuple[float] = (0., 0.)):
|
||||
super(SimpleCombiner, self).__init__()
|
||||
assert dim2 >= dim1
|
||||
self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
|
||||
self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
|
||||
|
||||
self.min_weight = min_weight
|
||||
|
||||
def forward(self,
|
||||
src1: Tensor,
|
||||
@ -844,15 +847,36 @@ class SimpleCombiner(torch.nn.Module):
|
||||
logit = (weight1 + weight2)
|
||||
|
||||
if self.training and random.random() < 0.1:
|
||||
logit = penalize_abs_values_gt(logit,
|
||||
limit=25.0,
|
||||
penalty=1.0e-04)
|
||||
logit = penalize_abs_values_gt(logit,
|
||||
limit=25.0,
|
||||
penalty=1.0e-04)
|
||||
|
||||
# `weight` will be the weight on src1.
|
||||
weight = logit.sigmoid()
|
||||
|
||||
src2_part1 = src2[...,:dim1]
|
||||
part1 = src1 * weight + src2_part1 * (1.0 - weight)
|
||||
part2 = src2[...,dim1:]
|
||||
return torch.cat((part1, part2), dim=-1)
|
||||
if self.training and self.min_weight != (0., 0.) and random.random() < 0.25:
|
||||
weight = weight.clamp(min=self.min_weight[0],
|
||||
max=1.0-self.min_weight[1])
|
||||
|
||||
|
||||
src1 = src1 * weight
|
||||
src2 = src2 * (1.0 - weight)
|
||||
|
||||
src1_dim = src1.shape[-1]
|
||||
src2_dim = src2.shape[-1]
|
||||
if src1_dim != src2_dim:
|
||||
if src1_dim < src2_dim:
|
||||
zeros_shape = list(src1.shape[:-1]) + [src2_dim - src1_dim]
|
||||
src1 = torch.cat((src1, torch.zeros(*zeros_shape,
|
||||
device=src1.device,
|
||||
dtype=src1.dtype)),
|
||||
dim=-1)
|
||||
else:
|
||||
src1 = src1[:src2_dim]
|
||||
|
||||
|
||||
return src1 + src2
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user