remove score sorting in test mode

yaozengwei 2023-08-02 19:26:48 +08:00
parent 74bf02bba6
commit 42800f775e


@@ -23,6 +23,7 @@ import logging
 import torch
 import random
 from encoder_interface import EncoderInterface
+from icefall.utils import make_pad_mask
 from scaling import (
     Balancer,
     BiasNorm,
@@ -924,14 +925,13 @@ class LearnedDownsamplingModule(nn.Module):
         scores = scores.squeeze(-1).t()  # (batch_size, seq_len)

+        if self.training:
             # sscores, indexes: (batch_size, seq_len)
             sscores, indexes = scores.sort(dim=-1, descending=True)

             weights = sscores.clamp(min=0.0, max=1.0)
             weights = self.copy_weights1(weights)

-        if self.training:
             d = self.downsampling_factor
             seq_len_reduced = (seq_len + d - 1) // d
@@ -947,27 +947,15 @@ class LearnedDownsamplingModule(nn.Module):
             if random.random() < 0.01 or __name__ == '__main__':
                 logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")

             if random.random() < 0.5:
                 # flipping it half the time increases the randomness, so gives an extra incentive
                 # to avoid nonzero weights in the discarded half
                 weights_discarded = weights_discarded.flip(dims=(1,))

             weights = weights[:, :seq_len_reduced] - weights_discarded
-        else:
-            # test mode. because the sequence might be short, we keep all nonzero scores;
-            # and there is no need for any penalty.
-            # need to work out seq_len_reduced.
-            seq_len_reduced = max(1,
-                                  (weights > 0.0).to(torch.int32).sum(dim=-1).max().item())
-            if random.random() < 0.02:
-                logging.info(f"seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
-
-        weights = weights[:, :seq_len_reduced]
             indexes = indexes[:, :seq_len_reduced]

             weights = self.copy_weights2(weights)
             # re-sort the indexes we kept, on index value, so that
@@ -976,11 +964,33 @@ class LearnedDownsamplingModule(nn.Module):
             # can remove this??)
             indexes, reorder = indexes.sort(dim=-1)
             weights = torch.gather(weights, dim=-1, index=reorder)
+        else:
+            # test mode. because the sequence might be short, we keep all nonzero scores;
+            # and there is no need for any penalty.
+            weights = scores.clamp(min=0.0, max=1.0)
+            mask = weights > 0.0
+            # The per-sample lengths we will keep
+            count = mask.to(torch.int32).sum(dim=-1)
+            # The columns we will keep
+            indexes = mask.nonzero(as_tuple=True)[1]
+            indexes = indexes.split(count.tolist())
+            # Padding with the last elements, e.g., with index=seq_len-1
+            indexes = torch.nn.utils.rnn.pad_sequence(
+                indexes, batch_first=True, padding_value=seq_len - 1
+            )  # (batch_size, seq_len_reduced)
+            weights = torch.gather(weights, dim=-1, index=indexes)
+            padding_mask = make_pad_mask(count)
+            weights = weights.masked_fill(padding_mask, 0.0)
+
+            if random.random() < 0.02:
+                seq_len_reduced = indexes.shape[1]
+                logging.info(f"seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")

         x_downsampled = self.downsample(x, indexes)
         return indexes, weights, x_downsampled

     def downsample(self, x: Tensor, indexes: Tensor) -> Tensor:
         """
         Downsamples x via indexing with the indexes obtained from the
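
The training branch is only partly visible in this view: the computation of weights_discarded falls in the lines between hunks. As a hedged illustration only, assuming weights_discarded is the next block of descending-sorted weights beyond the kept region, zero-padded to shape (batch_size, seq_len_reduced), the visible training-mode logic amounts to something like:

import random
import torch

batch_size, seq_len, d = 2, 6, 2   # d: hypothetical downsampling_factor
scores = torch.randn(batch_size, seq_len)

# sscores, indexes: (batch_size, seq_len), scores sorted best-first
sscores, indexes = scores.sort(dim=-1, descending=True)
weights = sscores.clamp(min=0.0, max=1.0)

seq_len_reduced = (seq_len + d - 1) // d

# ASSUMPTION: weights_discarded is the next seq_len_reduced sorted weights,
# zero-padded; the real computation is in lines not shown in this diff.
weights_discarded = weights[:, seq_len_reduced:2 * seq_len_reduced]
pad = seq_len_reduced - weights_discarded.shape[1]
if pad > 0:
    weights_discarded = torch.nn.functional.pad(weights_discarded, (0, pad))

if random.random() < 0.5:
    # flipping it half the time increases the randomness (see the comment in the diff)
    weights_discarded = weights_discarded.flip(dims=(1,))

# penalize nonzero weights among the frames that were discarded
weights = weights[:, :seq_len_reduced] - weights_discarded
indexes = indexes[:, :seq_len_reduced]

# re-sort the kept indexes into time order
indexes, reorder = indexes.sort(dim=-1)
weights = torch.gather(weights, dim=-1, index=reorder)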
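
For the new test-mode path, here is a minimal standalone sketch (not part of the commit): every frame with a positive score is kept, the kept column indices are gathered per sample, the ragged index lists are padded to a common length with index seq_len - 1, and the weights at padded positions are zeroed. The make_pad_mask stub below only mimics icefall.utils.make_pad_mask (True at padding positions), and the toy scores tensor is illustrative.

import torch

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # Stub with icefall.utils.make_pad_mask semantics: True at padding positions.
    max_len = int(lengths.max())
    return (torch.arange(max_len, device=lengths.device).unsqueeze(0)
            >= lengths.unsqueeze(1))

batch_size, seq_len = 2, 6
scores = torch.tensor([[0.9, -0.2, 0.4, -0.1, 0.7, -0.3],
                       [0.5, 0.1, -0.6, -0.8, -0.2, -0.4]])

weights = scores.clamp(min=0.0, max=1.0)
mask = weights > 0.0
# Per-sample number of kept frames (here: 3 and 2).
count = mask.to(torch.int32).sum(dim=-1)
# Column indices of kept frames, flattened row-major, then split per sample.
indexes = mask.nonzero(as_tuple=True)[1].split(count.tolist())
# Pad the ragged index lists with seq_len - 1 to a rectangular tensor.
indexes = torch.nn.utils.rnn.pad_sequence(
    indexes, batch_first=True, padding_value=seq_len - 1
)  # (batch_size, seq_len_reduced) == (2, 3)
weights = torch.gather(weights, dim=-1, index=indexes)
# Zero the weights at the padded positions.
weights = weights.masked_fill(make_pad_mask(count), 0.0)

print(indexes)  # kept columns: [[0, 2, 4], [0, 1, 5]]
print(weights)  # [[0.9, 0.4, 0.7], [0.5, 0.1, 0.0]]

Unlike the removed test-mode code, which reused the descending-sorted indexes and truncated them at a single global seq_len_reduced, this path yields each sample's frames already in time order, which is why the index re-sort now lives only in the training branch.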