Try removing weight_scale

This commit is contained in:
Daniel Povey 2023-05-18 18:41:39 +08:00
parent d2c198c072
commit c487f9a0ef

View File

@ -811,7 +811,7 @@ class LearnedDownsamplingModule(nn.Module):
def __init__(self,
embed_dim: int,
downsampling_factor: int,
weight_scale: float = 10.0):
weight_scale: float = 1.0):
super().__init__()
@ -966,7 +966,7 @@ class LearnedDownsamplingModule(nn.Module):
attn_offset: Tensor,
indexes: Tensor,
weights: Tensor,
eps: float = 1.0e-05) -> Tensor:
eps: float = 1.0e-03) -> Tensor:
"""
Downsamples attn_offset and also modifies it to account for the weights in `weights`.
Args: