Mirror of https://github.com/k2-fsa/icefall.git
commit 42800f775e
parent 74bf02bba6

    remove score sorting in test mode

    In test mode, all frames with positive scores are now kept in their
    original order, so the sort over scores (and the later re-sort of the
    kept indexes back into time order) runs only under self.training.
@@ -23,6 +23,7 @@ import logging
 import torch
 import random
 from encoder_interface import EncoderInterface
+from icefall.utils import make_pad_mask
 from scaling import (
     Balancer,
     BiasNorm,
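The newly imported make_pad_mask is used further down to zero out weights at padded positions. For reference, a minimal standalone equivalent (the name make_pad_mask_sketch and this exact signature are illustrative, not icefall's definition):

import torch

def make_pad_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    # Bool mask of shape (batch, max_len): True at padding positions,
    # i.e. True at column j of row i whenever j >= lengths[i].
    max_len = int(lengths.max())
    cols = torch.arange(max_len, device=lengths.device)
    return cols.unsqueeze(0) >= lengths.unsqueeze(1)

print(make_pad_mask_sketch(torch.tensor([3, 1])))
# tensor([[False, False, False],
#         [False,  True,  True]])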
@@ -924,14 +925,13 @@ class LearnedDownsamplingModule(nn.Module):
 
         scores = scores.squeeze(-1).t()  # (batch_size, seq_len)
 
-        # sscores, indexes: (batch_size, seq_len)
-        sscores, indexes = scores.sort(dim=-1, descending=True)
-
-
-        weights = sscores.clamp(min=0.0, max=1.0)
-        weights = self.copy_weights1(weights)
-
         if self.training:
+            # sscores, indexes: (batch_size, seq_len)
+            sscores, indexes = scores.sort(dim=-1, descending=True)
+
+            weights = sscores.clamp(min=0.0, max=1.0)
+            weights = self.copy_weights1(weights)
+
             d = self.downsampling_factor
             seq_len_reduced = (seq_len + d - 1) // d
 
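With this change, ranking frames by score happens only in training mode. A minimal sketch of the training-branch arithmetic, with random tensors standing in for real frame scores (the copy_weights1 call is omitted):

import torch

batch_size, seq_len, d = 2, 8, 2
scores = torch.randn(batch_size, seq_len)

# Rank frames per sample, highest score first; `indexes` records which
# original frame each sorted slot came from.
sscores, indexes = scores.sort(dim=-1, descending=True)

# Clamp sorted scores into [0, 1] so they act as soft keep-weights.
weights = sscores.clamp(min=0.0, max=1.0)

# With downsampling factor d, ceil(seq_len / d) frames survive.
seq_len_reduced = (seq_len + d - 1) // d
print(seq_len_reduced)  # 4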
@@ -947,40 +947,50 @@ class LearnedDownsamplingModule(nn.Module):
             if random.random() < 0.01 or __name__ == '__main__':
                 logging.info(f"mean weight={weights.mean()}, mean-abs-scores={scores.abs().mean()} positive-scores={(scores>0).to(torch.float32).mean()}, discarded-weights={weights_discarded.mean()}, seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
 
-
             if random.random() < 0.5:
                 # flipping it half the time increases the randomness, so gives an extra incentive
                 # to avoid nonzero weights in the discarded half
                 weights_discarded = weights_discarded.flip(dims=(1,))
 
             weights = weights[:, :seq_len_reduced] - weights_discarded
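weights_discarded is defined in lines that fall outside this hunk; from the surrounding code it appears to hold the clamped weights of the frames past the seq_len_reduced cutoff. A sketch of the flip trick under that assumption (the slice defining weights_discarded below is a guess, not the actual code):

import torch

weights = torch.rand(2, 8).sort(dim=-1, descending=True).values
seq_len_reduced = 4

# Assumed definition: weights of the frames that will not be kept.
weights_discarded = weights[:, seq_len_reduced:]

if torch.rand(()) < 0.5:
    # Reversing the discarded half pairs the best kept frames with the
    # worst discarded ones half the time, adding randomness to the penalty.
    weights_discarded = weights_discarded.flip(dims=(1,))

# Subtracting gives an incentive to drive discarded weights to zero.
weights = weights[:, :seq_len_reduced] - weights_discarded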
+
+            indexes = indexes[:, :seq_len_reduced]
+
+            weights = self.copy_weights2(weights)
+
+            # re-sort the indexes we kept, on index value, so that
+            # masking for causal models will be in the correct order.
+            # (actually this may not really matter, TODO: see whether we
+            # can remove this??)
+            indexes, reorder = indexes.sort(dim=-1)
+            weights = torch.gather(weights, dim=-1, index=reorder)
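The re-sort puts the kept frames back into time order, and reorder lets the weights follow their frames. A small concrete example with toy values:

import torch

indexes = torch.tensor([[5, 1, 3]])   # kept frames, in score order
weights = torch.tensor([[0.9, 0.7, 0.4]])

# Sort kept indexes into time order; `reorder` maps each time-ordered
# slot back to its position in the score-ordered arrays.
indexes, reorder = indexes.sort(dim=-1)          # indexes -> [[1, 3, 5]]
weights = torch.gather(weights, dim=-1, index=reorder)
print(weights)  # tensor([[0.7000, 0.4000, 0.9000]]) -- 0.9 follows frame 5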
         else:
             # test mode. because the sequence might be short, we keep all nonzero scores;
             # and there is no need for any penalty.
+            weights = scores.clamp(min=0.0, max=1.0)
+            mask = weights > 0.0
+            # The per-sample lengths we will keep
+            count = mask.to(torch.int32).sum(dim=-1)
+            # The columns we will keep
+            indexes = mask.nonzero(as_tuple=True)[1]
+            indexes = indexes.split(count.tolist())
+            # Padding with the last elements, e.g., with index=seq_len-1
+            indexes = torch.nn.utils.rnn.pad_sequence(
+                indexes, batch_first=True, padding_value=seq_len - 1
+            )  # (batch_size, seq_len_reduced)
+
+            weights = torch.gather(weights, dim=-1, index=indexes)
+
+            padding_mask = make_pad_mask(count)
+            weights = weights.masked_fill(padding_mask, 0.0)
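The new test-mode path keeps every frame with a positive score, in original order, padding ragged rows so the batch stays rectangular. A self-contained walk-through of exactly these steps (make_pad_mask is inlined here as an arange comparison):

import torch
from torch.nn.utils.rnn import pad_sequence

scores = torch.tensor([[0.9, 0.0, 0.4, 0.0],
                       [0.7, 0.0, 0.0, 0.0]])
seq_len = scores.shape[1]

weights = scores.clamp(min=0.0, max=1.0)
mask = weights > 0.0
count = mask.to(torch.int32).sum(dim=-1)   # frames kept per sample: [2, 1]
indexes = mask.nonzero(as_tuple=True)[1]   # nonzero columns, rows concatenated: [0, 2, 0]
indexes = indexes.split(count.tolist())    # per-sample index tensors: ([0, 2], [0])
# Pad shorter rows with index seq_len - 1 so all rows have equal width.
indexes = pad_sequence(indexes, batch_first=True, padding_value=seq_len - 1)

weights = torch.gather(weights, dim=-1, index=indexes)
# Zero the weights at padded positions (the job make_pad_mask does above).
padding_mask = torch.arange(indexes.shape[1]).unsqueeze(0) >= count.unsqueeze(1)
weights = weights.masked_fill(padding_mask, 0.0)

print(indexes)  # tensor([[0, 2], [0, 3]])
print(weights)  # tensor([[0.9000, 0.4000], [0.7000, 0.0000]])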
-
-            # need to work out seq_len_reduced.
-            seq_len_reduced = max(1,
-                                  (weights > 0.0).to(torch.int32).sum(dim=-1).max().item())
         if random.random() < 0.02:
+            seq_len_reduced = indexes.shape[1]
             logging.info(f"seq_len={seq_len}, seq_len_reduced={seq_len_reduced}")
-        weights = weights[:, :seq_len_reduced]
-
-        indexes = indexes[:, :seq_len_reduced]
-
-
-        weights = self.copy_weights2(weights)
-
-        # re-sort the indexes we kept, on index value, so that
-        # masking for causal models will be in the correct order.
-        # (actually this may not really matter, TODO: see whether we
-        # can remove this??)
-        indexes, reorder = indexes.sort(dim=-1)
-        weights = torch.gather(weights, dim=-1, index=reorder)
 
         x_downsampled = self.downsample(x, indexes)
         return indexes, weights, x_downsampled
 
 
     def downsample(self, x: Tensor, indexes: Tensor) -> Tensor:
         """
         Downsamples x via indexing with the indexes obtained from the
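The docstring is cut off by the end of the diff context. As a rough illustration of index-based downsampling (not necessarily the actual body of downsample), assuming the (seq_len, batch_size, num_channels) layout used elsewhere in this file:

import torch

x = torch.randn(6, 2, 4)                        # (seq_len, batch, channels)
indexes = torch.tensor([[0, 2, 5], [1, 3, 4]])  # (batch, seq_len_reduced)

# Broadcast indexes over channels and gather along the time axis:
# out[i, b, c] = x[indexes[b, i], b, c]
ind = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
x_downsampled = torch.gather(x, dim=0, index=ind)
print(x_downsampled.shape)  # torch.Size([3, 2, 4])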