mirror of https://github.com/k2-fsa/icefall.git
remove score sorting in test mode
parent 74bf02bba6
commit 42800f775e
@@ -23,6 +23,7 @@ import logging

import torch
import random
from encoder_interface import EncoderInterface
from icefall.utils import make_pad_mask
from scaling import (
    Balancer,
    BiasNorm,
@@ -924,14 +925,13 @@ class LearnedDownsamplingModule(nn.Module):

        scores = scores.squeeze(-1).t()  # (batch_size, seq_len)

        if self.training:
            # sscores, indexes: (batch_size, seq_len)
            sscores, indexes = scores.sort(dim=-1, descending=True)

            weights = sscores.clamp(min=0.0, max=1.0)
            weights = self.copy_weights1(weights)

        if self.training:
            d = self.downsampling_factor
            seq_len_reduced = (seq_len + d - 1) // d
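As a side note, the frame budget computed at the end of this hunk is a ceiling division: seq_len_reduced = ceil(seq_len / downsampling_factor). A minimal sketch with made-up values (the surrounding module, copy_weights1 and the penalty term are omitted):

    # Toy illustration of the training-mode frame budget: keep
    # seq_len_reduced = ceil(seq_len / d) of the highest-scoring frames.
    # All values below are made up for illustration.
    import torch

    seq_len, d = 7, 2
    seq_len_reduced = (seq_len + d - 1) // d                # ceil(7 / 2) == 4

    scores = torch.randn(3, seq_len)                        # (batch_size, seq_len)
    sscores, indexes = scores.sort(dim=-1, descending=True)
    weights = sscores.clamp(min=0.0, max=1.0)
    kept = weights[:, :seq_len_reduced]                     # the 4 best frames per sample
    print(seq_len_reduced, kept.shape)                      # 4 torch.Size([3, 4])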
@@ -947,27 +947,15 @@ class LearnedDownsamplingModule(nn.Module):

            if random.random() < 0.01 or __name__ == '__main__':
                logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")

            if random.random() < 0.5:
                # flipping it half the time increases the randomness, so gives an extra incentive
                # to avoid nonzero weights in the discarded half
                weights_discarded = weights_discarded.flip(dims=(1,))

            weights = weights[:, :seq_len_reduced] - weights_discarded
        else:
            # test mode. because the sequence might be short, we keep all nonzero scores;
            # and there is no need for any penalty.

            # need to work out seq_len_reduced.
            seq_len_reduced = max(1,
                                  (weights > 0.0).to(torch.int32).sum(dim=-1).max().item())
            if random.random() < 0.02:
                logging.info(f"seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
            weights = weights[:, :seq_len_reduced]

        indexes = indexes[:, :seq_len_reduced]

        weights = self.copy_weights2(weights)

        # re-sort the indexes we kept, on index value, so that
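The lines elided from this hunk define weights_discarded. A rough sketch of the flip-and-subtract penalty, assuming weights_discarded holds the sorted weights that fall outside the kept budget, arranged as (batch_size, seq_len_reduced); the tensors and that assumed definition are illustrative only:

    # Rough sketch of the training-time penalty.  Flipping weights_discarded
    # half the time pairs kept and discarded positions at varying offsets, so
    # any nonzero discarded weight reliably pulls down some kept weight.
    import random
    import torch

    seq_len_reduced = 4
    weights = torch.tensor([[1.0, 0.9, 0.6, 0.3, 0.2, 0.1, 0.0, 0.0],
                            [0.8, 0.7, 0.5, 0.4, 0.3, 0.0, 0.0, 0.0]])
    weights_discarded = weights[:, seq_len_reduced:]      # assumed definition

    if random.random() < 0.5:
        weights_discarded = weights_discarded.flip(dims=(1,))

    penalized = weights[:, :seq_len_reduced] - weights_discarded
    print(penalized)                                      # (batch_size, seq_len_reduced)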
@@ -976,11 +964,33 @@ class LearnedDownsamplingModule(nn.Module):

            # can remove this??)
            indexes, reorder = indexes.sort(dim=-1)
            weights = torch.gather(weights, dim=-1, index=reorder)
        else:
            # test mode. because the sequence might be short, we keep all nonzero scores;
            # and there is no need for any penalty.
            weights = scores.clamp(min=0.0, max=1.0)
            mask = weights > 0.0
            # The per-sample lengths we will keep
            count = mask.to(torch.int32).sum(dim=-1)
            # The columns we will keep
            indexes = mask.nonzero(as_tuple=True)[1]
            indexes = indexes.split(count.tolist())
            # Padding with the last elements, e.g., with index=seq_len-1
            indexes = torch.nn.utils.rnn.pad_sequence(
                indexes, batch_first=True, padding_value=seq_len - 1
            )  # (batch_size, seq_len_reduced)

            weights = torch.gather(weights, dim=-1, index=indexes)

            padding_mask = make_pad_mask(count)
            weights = weights.masked_fill(padding_mask, 0.0)

            if random.random() < 0.02:
                seq_len_reduced = indexes.shape[1]
                logging.info(f"seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")

        x_downsampled = self.downsample(x, indexes)
        return indexes, weights, x_downsampled

    def downsample(self, x: Tensor, indexes: Tensor) -> Tensor:
        """
        Downsamples x via indexing with the indexes obtained from the