mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Try removing weight_scale
This commit is contained in:
parent
d2c198c072
commit
c487f9a0ef
@ -811,7 +811,7 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
embed_dim: int,
|
embed_dim: int,
|
||||||
downsampling_factor: int,
|
downsampling_factor: int,
|
||||||
weight_scale: float = 10.0):
|
weight_scale: float = 1.0):
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -966,7 +966,7 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
attn_offset: Tensor,
|
attn_offset: Tensor,
|
||||||
indexes: Tensor,
|
indexes: Tensor,
|
||||||
weights: Tensor,
|
weights: Tensor,
|
||||||
eps: float = 1.0e-05) -> Tensor:
|
eps: float = 1.0e-03) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Downsamples attn_offset and also modifies it to account for the weights in `weights`.
|
Downsamples attn_offset and also modifies it to account for the weights in `weights`.
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user