mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Remove flipping of weights; reduce eps.
This commit is contained in:
parent
c487f9a0ef
commit
57a023902c
@ -879,7 +879,7 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
if random.random() < 0.01 or __name__ == '__main__':
|
if random.random() < 0.01 or __name__ == '__main__':
|
||||||
logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
|
logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
|
||||||
|
|
||||||
weights_discarded = weights_discarded.flip(dims=(1,))
|
#weights_discarded = weights_discarded.flip(dims=(1,))
|
||||||
|
|
||||||
weights = (weights[:, :seq_len_reduced] - weights_discarded)
|
weights = (weights[:, :seq_len_reduced] - weights_discarded)
|
||||||
else:
|
else:
|
||||||
@ -966,7 +966,7 @@ class LearnedDownsamplingModule(nn.Module):
|
|||||||
attn_offset: Tensor,
|
attn_offset: Tensor,
|
||||||
indexes: Tensor,
|
indexes: Tensor,
|
||||||
weights: Tensor,
|
weights: Tensor,
|
||||||
eps: float = 1.0e-03) -> Tensor:
|
eps: float = 2.0e-04) -> Tensor:
|
||||||
"""
|
"""
|
||||||
Downsamples attn_offset and also modifies it to account for the weights in `weights`.
|
Downsamples attn_offset and also modifies it to account for the weights in `weights`.
|
||||||
Args:
|
Args:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user