Remove flipping of weights; reduce eps.

This commit is contained in:
Daniel Povey 2023-05-18 19:50:16 +08:00
parent c487f9a0ef
commit 57a023902c

View File

@ -879,7 +879,7 @@ class LearnedDownsamplingModule(nn.Module):
if random.random() < 0.01 or __name__ == '__main__': if random.random() < 0.01 or __name__ == '__main__':
logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}") logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
weights_discarded = weights_discarded.flip(dims=(1,)) #weights_discarded = weights_discarded.flip(dims=(1,))
weights = (weights[:, :seq_len_reduced] - weights_discarded) weights = (weights[:, :seq_len_reduced] - weights_discarded)
else: else:
@ -966,7 +966,7 @@ class LearnedDownsamplingModule(nn.Module):
attn_offset: Tensor, attn_offset: Tensor,
indexes: Tensor, indexes: Tensor,
weights: Tensor, weights: Tensor,
eps: float = 1.0e-03) -> Tensor: eps: float = 2.0e-04) -> Tensor:
""" """
Downsamples attn_offset and also modifies it to account for the weights in `weights`. Downsamples attn_offset and also modifies it to account for the weights in `weights`.
Args: Args: