Fix various bugs

Daniel Povey 2023-05-15 15:20:02 +08:00
parent f740282a1a
commit 1b8be0744f
3 changed files with 102 additions and 163 deletions

View File

@ -19,7 +19,6 @@
import torch import torch
from torch import nn, Tensor from torch import nn, Tensor
from chunk_decoder import ChunkDecoder
from zipformer import Zipformer2 from zipformer import Zipformer2
@ -28,7 +27,7 @@ class Zipformer2LM(nn.Module):
def __init__(self, def __init__(self,
encoder_embed: nn.Module, encoder_embed: nn.Module,
encoder: Zipformer2, encoder: Zipformer2,
decoder: ChunkDecoder): decoder: nn.Module):
super().__init__() super().__init__()
self.encoder_embed = encoder_embed self.encoder_embed = encoder_embed
self.encoder = encoder # does subsampling self.encoder = encoder # does subsampling
@ -47,18 +46,17 @@ class Zipformer2LM(nn.Module):
""" """
(batch_size, seq_len) = labels.shape (batch_size, seq_len) = labels.shape
chunk_size = self.decoder.chunk_size chunk_size = 1
labels_shifted = labels.t() # (time, batch) labels_shifted = labels.t() # (time, batch)
labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:chunk_size]), labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:1]),
labels_shifted[:-chunk_size]), labels_shifted[:-1]),
dim=0) dim=0)
x = self.encoder_embed(labels_shifted) x = self.encoder_embed(labels_shifted)
x_lens = torch.full((batch_size,), seq_len, x_lens = torch.full((batch_size,), seq_len,
dtype=torch.long, device=labels.device) dtype=torch.long, device=labels.device)
# x_lens is after subsampling. Actually we don't need it. # x_lens is after subsampling. Actually we don't need it.
(x, x_lens) = self.encoder(x, x_lens) (x, x_lens) = self.encoder(x, x_lens)
logprobs = self.decoder(labels, x) logprobs = self.decoder(labels, x)
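As a standalone illustration (toy values, not part of the recipe), the shift above delays the labels by one position so the decoder input at step t is the label at step t-1, with a zero acting as a start token:

    import torch

    labels = torch.tensor([[5, 7, 9, 2]])        # (batch=1, seq_len=4), made-up byte values
    shifted = labels.t()                         # (seq_len, batch)
    shifted = torch.cat((torch.zeros_like(shifted[:1]), shifted[:-1]), dim=0)
    print(shifted.squeeze(1).tolist())           # [0, 5, 7, 9]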

View File

@ -76,11 +76,7 @@ class Subformer2(EncoderInterface):
dropout (float): dropout rate dropout (float): dropout rate
warmup_batches (float): number of batches to warm up over; this controls warmup_batches (float): number of batches to warm up over; this controls
dropout of encoder layers. dropout of encoder layers.
causal (bool): if True, support chunkwise causal convolution. This should causal (bool): if True, use causal attention-mask.
not hurt WER as no modeling power is lost, but the convolution modules will be
slightly slower and use more memory. Enables use of the chunk_size and
left_context_chunks options in forward(), which simulates streaming
decoding.
memory_dim: if supplied and >0, will be the dimension of the memory embeddings memory_dim: if supplied and >0, will be the dimension of the memory embeddings
passed into the zipformer (e.g. this might be the output of another passed into the zipformer (e.g. this might be the output of another
Subformer used to create embedding vectors.) Subformer used to create embedding vectors.)
@ -97,7 +93,6 @@ class Subformer2(EncoderInterface):
num_heads: Union[int, Tuple[int]] = 8, num_heads: Union[int, Tuple[int]] = 8,
feedforward_dim: Union[int, Tuple[int]] = 1536, feedforward_dim: Union[int, Tuple[int]] = 1536,
memory_dim: int = -1, memory_dim: int = -1,
pos_emb_dim: int = 192,
pos_dim: int = 4, pos_dim: int = 4,
dropout: FloatLike = None, # see code below for default dropout: FloatLike = None, # see code below for default
warmup_batches: float = 4000.0, warmup_batches: float = 4000.0,
@ -129,6 +124,7 @@ class Subformer2(EncoderInterface):
value_head_dim = _to_tuple(value_head_dim) value_head_dim = _to_tuple(value_head_dim)
num_heads = _to_tuple(num_heads) num_heads = _to_tuple(num_heads)
feedforward_dim = _to_tuple(feedforward_dim) feedforward_dim = _to_tuple(feedforward_dim)
self.causal = causal
for u,d in zip(encoder_unmasked_dim, encoder_dim): for u,d in zip(encoder_unmasked_dim, encoder_dim):
assert u <= d assert u <= d
@ -156,7 +152,6 @@ class Subformer2(EncoderInterface):
encoder = Subformer2Encoder( encoder = Subformer2Encoder(
encoder_layer, encoder_layer,
num_encoder_layers[i], num_encoder_layers[i],
pos_dim=pos_dim,
dropout=dropout, dropout=dropout,
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
@ -173,14 +168,15 @@ class Subformer2(EncoderInterface):
encoders.append(encoder) encoders.append(encoder)
self.encoder_pos = CompactRelPositionalEncoding(pos_emb_dim, pos_dim, dropout_rate=0.15, self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim,
dropout_rate=0.15,
length_factor=1.0) length_factor=1.0)
self.encoders = nn.ModuleList(encoders) self.encoders = nn.ModuleList(encoders)
self.downsample_output = SimpleDownsample(max(encoder_dim), #self.downsample_output = SimpleDownsample(max(encoder_dim),
downsample=output_downsampling_factor, # downsample=output_downsampling_factor,
dropout=dropout) # dropout=dropout)
def get_feature_masks( def get_feature_masks(
self, self,
@ -273,7 +269,7 @@ class Subformer2(EncoderInterface):
outputs = [] outputs = []
feature_masks = self.get_feature_masks(x) feature_masks = self.get_feature_masks(x)
attn_offset = self._get_attn_offset(x) attn_offset = self._get_attn_offset(x, src_key_padding_mask)
if self.training and memory is not None: if self.training and memory is not None:
batch_size = x.shape[1] batch_size = x.shape[1]
@ -286,15 +282,11 @@ class Subformer2(EncoderInterface):
pos_emb = self.encoder_pos(x) pos_emb = self.encoder_pos(x)
for i, module in enumerate(self.encoders): for i, module in enumerate(self.encoders):
ds = self.downsampling_factor[i]
x = convert_num_channels(x, self.encoder_dim[i]) x = convert_num_channels(x, self.encoder_dim[i])
x = module(x, x = module(x,
pos_emb, pos_emb,
chunk_size=chunk_size,
feature_mask=feature_masks[i], feature_mask=feature_masks[i],
src_key_padding_mask=(None if src_key_padding_mask is None
else src_key_padding_mask[...,::ds]),
attn_offset=attn_offset, attn_offset=attn_offset,
memory=memory, memory=memory,
memory_key_padding_mask=memory_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask,
@ -321,37 +313,37 @@ class Subformer2(EncoderInterface):
# from different pieces of 'outputs', taking each dimension from the # from different pieces of 'outputs', taking each dimension from the
# most recent output that has it present. # most recent output that has it present.
x = get_full_dim_output() x = get_full_dim_output()
x = self.downsample_output(x) #x = self.downsample_output(x)
d = self.output_downsampling_factor d = self.output_downsampling_factor
lengths = (x_lens + d - 1) // d lengths = (x_lens + d - 1) // d
return x, lengths return x, lengths
def _get_attn_offset(self, x: Tensor) -> Optional[Tensor]: def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
""" """
Return attention offset of shape (1, seq_len, seq_len), interpreted as (tgt_seq_len, Return attention offset of shape (1 or batch_size, seq_len, seq_len), interpreted as (1 or batch_size, tgt_seq_len,
src_seq_len); this reflects masking, if causal == True, otherwise will be all zeros. src_seq_len); this reflects masking, if causal == True, otherwise will be all zeros.
Args: Args:
x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
chunk_size: chunk size, must divide src_key_padding_mask: optional key-padding mask of shape (batch_size, seq_len) with True in masked positions.
""" """
if not self.causal: seq_len, batch_size, _num_channels = x.shape
return None
seq_len = x.shape[0]
# t is frame index, shape (seq_len,)
t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
src_c = c
tgt_c = c.unsqueeze(-1)
attn_mask = (src_c > tgt_c)
ans = torch.zeros(1, seq_len, seq_len, device=x.device) ans = torch.zeros(1, seq_len, seq_len, device=x.device)
ans.masked_fill(attn_mask, float('-inf')) if self.causal:
# t is frame index, shape (seq_len,)
t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
src_t = t
tgt_t = t.unsqueeze(-1)
attn_mask = (src_t > tgt_t)
ans = ans.masked_fill(attn_mask, float('-inf'))
if src_key_padding_mask is not None:
# padded key positions get -inf so they are excluded from attention
ans = ans.expand(batch_size, seq_len, seq_len).masked_fill(src_key_padding_mask.unsqueeze(1), float('-inf'))
# now ans: (batch_size, seq_len, seq_len).
return ans return ans
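For reference, a minimal self-contained sketch of the same idea (illustrative names, not this module's code): an additive attention offset that combines a causal mask with an optional key-padding mask, using -inf where attention is disallowed.

    import torch

    def make_attn_offset(seq_len, key_padding_mask=None, causal=True, device='cpu'):
        # 0 keeps a (tgt, src) pair; float('-inf') removes it from the softmax
        offset = torch.zeros(1, seq_len, seq_len, device=device)
        if causal:
            t = torch.arange(seq_len, device=device)
            future = t.unsqueeze(0) > t.unsqueeze(-1)   # src position later than tgt position
            offset = offset.masked_fill(future, float('-inf'))
        if key_padding_mask is not None:                # (batch_size, seq_len), True = padded
            batch_size = key_padding_mask.shape[0]
            offset = offset.expand(batch_size, seq_len, seq_len).masked_fill(
                key_padding_mask.unsqueeze(1), float('-inf'))
        return offset                                   # (1 or batch_size, seq_len, seq_len)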
@ -384,11 +376,10 @@ class Subformer2EncoderLayer(nn.Module):
def __init__( def __init__(
self, self,
embed_dim: int, embed_dim: int,
pos_dim: int,
num_heads: int, num_heads: int,
query_head_dim: int, query_head_dim: int,
pos_dim: int,
value_head_dim: int, value_head_dim: int,
pos_dim: int,
feedforward_dim: int, feedforward_dim: int,
dropout: FloatLike = 0.1, dropout: FloatLike = 0.1,
causal: bool = False, causal: bool = False,
@ -431,14 +422,15 @@ class Subformer2EncoderLayer(nn.Module):
self.self_attn1 = Attention(embed_dim, embed_dim, num_heads, self.self_attn1 = Attention(embed_dim, embed_dim, num_heads,
value_head_dim) value_head_dim)
self.self_attn2 = Attention(embed_dim, embed_dim, num_heads, self.self_attn2 = Attention(embed_dim, embed_dim, num_heads,
value_head_dim) value_head_dim)
if memory_dim > 0: if memory_dim > 0:
self.attn_weights = MultiheadAttentionWeights( self.attn_weights = MultiheadAttentionWeights(
memory_dim, embed_dim, memory_dim,
embed_dim,
num_heads=num_heads, num_heads=num_heads,
head_dim=query_head_dim, head_dim=query_head_dim,
dropout=0.0, dropout=0.0,
@ -559,7 +551,6 @@ class Subformer2EncoderLayer(nn.Module):
self, self,
src: Tensor, src: Tensor,
pos_emb: Tensor, pos_emb: Tensor,
chunk_size: int = -1,
attn_offset: Optional[Tensor] = None, attn_offset: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None, memory: Optional[Tensor] = None,
@ -570,7 +561,6 @@ class Subformer2EncoderLayer(nn.Module):
Args: Args:
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
feature_mask: something that broadcasts with src, that we'll multiply `src` feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len), attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len),
@ -591,7 +581,6 @@ class Subformer2EncoderLayer(nn.Module):
src, src,
pos_emb=pos_emb, pos_emb=pos_emb,
attn_offset=attn_offset, attn_offset=attn_offset,
key_padding_mask=src_key_padding_mask,
) )
if memory is not None and hasattr(self, 'attn_weights'): if memory is not None and hasattr(self, 'attn_weights'):
@ -662,7 +651,6 @@ class Subformer2Encoder(nn.Module):
Args: Args:
encoder_layer: an instance of the Subformer2EncoderLayer() class (required). encoder_layer: an instance of the Subformer2EncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required). num_layers: the number of sub-encoder-layers in the encoder (required).
pos_dim: the dimension for the relative positional encoding
Examples:: Examples::
>>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8) >>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
@ -674,7 +662,6 @@ class Subformer2Encoder(nn.Module):
self, self,
encoder_layer: nn.Module, encoder_layer: nn.Module,
num_layers: int, num_layers: int,
pos_dim: int,
dropout: float, dropout: float,
warmup_begin: float, warmup_begin: float,
warmup_end: float, warmup_end: float,
@ -701,10 +688,9 @@ class Subformer2Encoder(nn.Module):
def forward( def forward(
self, self,
src: Tensor, src: Tensor,
chunk_size: int = -1, pos_emb: Tensor,
feature_mask: Union[Tensor, float] = 1.0, feature_mask: Union[Tensor, float] = 1.0,
attn_offset: Optional[Tensor] = None, attn_offset: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None, memory: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
) -> Tensor: ) -> Tensor:
@ -712,14 +698,13 @@ class Subformer2Encoder(nn.Module):
Args: Args:
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. pos_emb: positional embedding tensor, of shape (batch_size, seq_len, seq_len, pos_dim),
e.g. pos_dim=4.
feature_mask: something that broadcasts with src, that we'll multiply `src` feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_offset: the attention offset (does masking and related tasks), of shape attn_offset: the attention offset (does masking and related tasks), of shape
broadcasting with (batch_size, seq_len, seq_len), broadcasting with (batch_size, seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len). interpreted as (batch_size, tgt_seq_len, src_seq_len).
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
memory_key_padding_mask: optionally the mask for padding of memory input (for source- memory_key_padding_mask: optionally the mask for padding of memory input (for source-
attention), of shape (batch_size, memory_len); True means attention), of shape (batch_size, memory_len); True means
@ -727,7 +712,6 @@ class Subformer2Encoder(nn.Module):
Returns: a Tensor with the same shape as src. Returns: a Tensor with the same shape as src.
""" """
pos_emb = self.encoder_pos(src)
output = src output = src
rnd_seed = src.numel() + random.randint(0, 1000) rnd_seed = src.numel() + random.randint(0, 1000)
@ -738,9 +722,7 @@ class Subformer2Encoder(nn.Module):
output = mod( output = mod(
output, output,
pos_emb, pos_emb,
chunk_size=chunk_size,
attn_offset=attn_offset, attn_offset=attn_offset,
src_key_padding_mask=src_key_padding_mask,
memory=memory, memory=memory,
memory_key_padding_mask=memory_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask,
) )
@ -827,12 +809,13 @@ class LearnedDownsamplingModule(nn.Module):
embed_dim: int, embed_dim: int,
downsampling_factor: int, downsampling_factor: int,
intermediate_rate: FloatLike = 0.2): intermediate_rate: FloatLike = 0.2):
super().__init__()
self.to_scores = nn.Linear(embed_dim, 1, bias=False) self.to_scores = nn.Linear(embed_dim, 1, bias=False)
# score_balancer is just to keep the magnitudes of the scores in # score_balancer is just to keep the magnitudes of the scores in
# a fixed range and keep them balanced around zero, to stop # a fixed range and keep them balanced around zero, to stop
# these drifting around. # these drifting around.
self.score_balancer = Balancer(1, channel_dim=-1, self.score_balancer = Balancer(1, channel_dim=-1,
min_positive=0.4, max_positive=0.6 min_positive=0.4, max_positive=0.6,
min_abs=1.0, max_abs=1.2, min_abs=1.0, max_abs=1.2,
prob=0.025) prob=0.025)
@ -856,14 +839,14 @@ class LearnedDownsamplingModule(nn.Module):
corresponding to the kept frames; these will be between 0 and 1, but corresponding to the kept frames; these will be between 0 and 1, but
mostly exactly 1. mostly exactly 1.
""" """
(seq_len, batch_size, _) (seq_len, batch_size, _) = x.shape
scores = self.to_scores(x) # (seq_len, batch_size, 1) scores = self.to_scores(x) # (seq_len, batch_size, 1)
scores = self.score_balancer(scores) scores = self.score_balancer(scores)
scores = scores.squeeze(-1).t() # (batch_size, seq_len) scores = scores.squeeze(-1).t() # (batch_size, seq_len)
# indexes, sscores: (batch_size, seq_len) # sscores, indexes: (batch_size, seq_len)
indexes, sscores = scores.sort(dim=-1, descending=True) sscores, indexes = scores.sort(dim=-1, descending=True)
d = self.downsampling_factor d = self.downsampling_factor
seq_len_reduced = (seq_len + d - 1) // d seq_len_reduced = (seq_len + d - 1) // d
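The selection that follows keeps the ceil(seq_len / d) highest-scoring frames per sequence and later restores time order; a hedged toy sketch of just that step (illustrative helper, not the module's code):

    import torch

    def select_frames(scores, d):
        # scores: (batch_size, seq_len); keep the top ceil(seq_len / d) frames of each sequence
        batch_size, seq_len = scores.shape
        seq_len_reduced = (seq_len + d - 1) // d
        sscores, indexes = scores.sort(dim=-1, descending=True)
        indexes = indexes[:, :seq_len_reduced]      # kept frames, best-scoring first
        indexes, _ = indexes.sort(dim=-1)           # restore time order
        return indexes                              # (batch_size, seq_len_reduced)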
@ -883,10 +866,10 @@ class LearnedDownsamplingModule(nn.Module):
collar = max(1, int(seq_len_reduced * 0.5 * self.intermediate_rate)) collar = max(1, int(seq_len_reduced * 0.5 * self.intermediate_rate))
# right_avg: shape (batch_size,), this is to be mapped to 0.0 # right_avg: shape (batch_size,), this is to be mapped to 0.0
right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1) right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
# left_avg: shape (batch_size,), this is to be mapped to 1.0 # left_avg: shape (batch_size,), this is to be mapped to 1.0
left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1) left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)
# the + 0.001 is to avoid possible division by zero in case of ties. # the + 0.001 is to avoid possible division by zero in case of ties.
weights = (sscores - right_avg) / (left_avg - right_avg + 0.001) weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)
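Numerically this is a linear ramp between the two local averages around the keep boundary, later limited to [0, 1]; a small example with made-up numbers:

    import torch

    sscores = torch.tensor([[3.0, 2.0, 1.0, 0.0]])   # sorted scores, (batch=1, seq_len=4)
    left_avg = torch.tensor([[2.5]])                 # average just inside the kept region
    right_avg = torch.tensor([[0.5]])                # average just outside it
    weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)
    print(weights.clamp(0.0, 1.0))                   # approximately [1.00, 0.75, 0.25, 0.00]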
@ -901,11 +884,11 @@ class LearnedDownsamplingModule(nn.Module):
indexes, reorder = indexes.sort(dim=-1) indexes, reorder = indexes.sort(dim=-1)
weights = torch.gather(weights, dim=-1, index=reorder) weights = torch.gather(weights, dim=-1, index=reorder)
x_downsampled = downsample(indexes, x) x_downsampled = self.downsample(x, indexes)
return indexes, weights, x_downsampled return indexes, weights, x_downsampled
def downsample(x: Tensor, indexes: Tensor) -> Tensor: def downsample(self, x: Tensor, indexes: Tensor) -> Tensor:
""" """
Downsamples x via indexing with the indexes obtained from the Downsamples x via indexing with the indexes obtained from the
forward() function. forward() function.
@ -917,19 +900,19 @@ class LearnedDownsamplingModule(nn.Module):
Returns: Returns:
x_downsampled, of shape (seq_len_reduced, batch_size, num_channels) x_downsampled, of shape (seq_len_reduced, batch_size, num_channels)
""" """
indexes = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1]) indexes_expanded = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
# indexes now: (seq_len_reduced, batch_size, num_channels) # indexes_expanded: (seq_len_reduced, batch_size, num_channels)
ans = torch.gather(x, dim=0, index=indexes) ans = torch.gather(x, dim=0, index=indexes_expanded)
if __name__ == __main__: if __name__ == '__main__':
# temp, for testing # temp, for testing
x_reconstructed = upsample(x, ans, indexes) x_reconstructed = self.upsample(x, ans, indexes)
assert torch.allclose(x, x_reconstructed) assert torch.allclose(x, x_reconstructed)
return ans return ans
def downsample_pos_emb(pos_emb: Tensor, indexes: Tensor) -> Tensor: def downsample_pos_emb(self, pos_emb: Tensor, indexes: Tensor) -> Tensor:
""" """
Downsample positional embedding tensor with the provided indexes. Downsample positional embedding tensor with the provided indexes.
Args: Args:
@ -958,7 +941,8 @@ class LearnedDownsamplingModule(nn.Module):
return pos_emb return pos_emb
def downsample_attn_offset(attn_offset: Tensor, def downsample_attn_offset(self,
attn_offset: Tensor,
indexes: Tensor, indexes: Tensor,
weights: Tensor, weights: Tensor,
eps: float = 1.0e-05) -> Tensor: eps: float = 1.0e-05) -> Tensor:
@ -979,18 +963,17 @@ class LearnedDownsamplingModule(nn.Module):
assert len(attn_offset.shape) == 3 # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len) assert len(attn_offset.shape) == 3 # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
attn_offset = attn_offset.expand(batch_size, seq_len, seq_len) attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
attn_offset = attn_offset.gather(dim=1, src=indices.unsqueeze(-1).expand(
batch_size, seq_len_reduced, seq_len)) batch_size, seq_len_reduced, seq_len))
attn_offset = attn_offset.gather(dim=2, src=indices.unsqueeze(1).expand( attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
batch_size, seq_len_reduced, seq_len_reduced)) batch_size, seq_len_reduced, seq_len_reduced))
# unsqueeze at position 1 so the extra cost relates to the source position. # unsqueeze at position 1 so the extra cost relates to the source position.
attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1) attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1)
return attn_offst return attn_offset
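Taken in isolation, the two gathers above keep the selected target rows and source columns of the offset and then add the keep-weights in log space on the source axis; a hedged standalone sketch (illustrative names):

    import torch

    def downsample_offset(attn_offset, indexes, weights, eps=1.0e-05):
        # attn_offset: (batch_size, seq_len, seq_len); indexes, weights: (batch_size, seq_len_reduced)
        batch_size, seq_len_reduced = indexes.shape
        seq_len = attn_offset.shape[-1]
        rows = indexes.unsqueeze(-1).expand(batch_size, seq_len_reduced, seq_len)
        offset = attn_offset.gather(dim=1, index=rows)   # keep selected target rows
        cols = indexes.unsqueeze(1).expand(batch_size, seq_len_reduced, seq_len_reduced)
        offset = offset.gather(dim=2, index=cols)        # keep selected source columns
        return offset + weights.clamp(min=eps).log().unsqueeze(1)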
def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor: def upsample(self, x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
""" """
Upsamples, reversing the downsample() operation and filling in Upsamples, reversing the downsample() operation and filling in
any not-chosen frames with their original value before downsampling any not-chosen frames with their original value before downsampling
@ -1013,14 +996,14 @@ class LearnedDownsamplingModule(nn.Module):
not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool, not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool,
device=x.device) device=x.device)
not_kept.scatter_(src=False, dim=1, index=indexes) not_kept.scatter_(dim=1, index=indexes, value=False)
indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels) indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
# indexes now: (seq_len_reduced, batch_size, num_channels) # indexes now: (seq_len_reduced, batch_size, num_channels)
ans = torch.zeros_like(x_orig) ans = torch.zeros_like(x_orig)
ans.scatter_(x, dim=0, index=indexes) ans.scatter_(dim=0, index=indexes, src=x)
# add in x_orig in the frames that were not originally kept. # add in x_orig in the frames that were not originally kept.
return ans + x_orig * not_kept.t().unsqueeze(-1) return ans + x_orig * not_kept.t().unsqueeze(-1)
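A quick hedged round-trip check of the gather/scatter pair behind downsample() and upsample(), with toy shapes (standalone, not the module's code):

    import torch

    seq_len, keep, batch, ch = 8, 4, 2, 3
    x = torch.randn(seq_len, batch, ch)
    # per-sequence kept frame indexes, in time order
    indexes = torch.stack([torch.randperm(seq_len)[:keep].sort().values for _ in range(batch)])

    idx = indexes.t().unsqueeze(-1).expand(-1, -1, ch)     # (keep, batch, ch)
    x_ds = torch.gather(x, dim=0, index=idx)               # downsample: pick kept frames
    x_up = torch.zeros(seq_len, batch, ch).scatter_(dim=0, index=idx, src=x_ds)
    not_kept = torch.ones(batch, seq_len, dtype=torch.bool).scatter_(dim=1, index=indexes, value=False)
    x_rt = x_up + x * not_kept.t().unsqueeze(-1)           # not-kept frames fall back to the original
    assert torch.allclose(x_rt, x)                         # kept frames pass through unchanged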
@ -1051,7 +1034,6 @@ class DownsampledSubformer2Encoder(nn.Module):
pos_emb: Tensor, pos_emb: Tensor,
feature_mask: Union[Tensor, float] = 1.0, feature_mask: Union[Tensor, float] = 1.0,
attn_offset: Optional[Tensor] = None, attn_offset: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None, memory: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]: ) -> Tuple[Tensor, Tensor]:
@ -1065,8 +1047,6 @@ class DownsampledSubformer2Encoder(nn.Module):
attn_offset: the attention offset, added to scores for attention of shape attn_offset: the attention offset, added to scores for attention of shape
(batch_size, seq_len, seq_len) or (seq_len, seq_len), (batch_size, seq_len, seq_len) or (seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len). interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
memory_key_padding_mask: optionally the mask for padding of memory input (for source- memory_key_padding_mask: optionally the mask for padding of memory input (for source-
attention), of shape (batch_size, memory_len); True means attention), of shape (batch_size, memory_len); True means
@ -1079,30 +1059,24 @@ class DownsampledSubformer2Encoder(nn.Module):
pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes) pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
attn_offset = self.downsample.downsample_attn_offset(attn_offset, attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
indexes, indexes,
weights.clamp(min=1.0e-05)) weights)
src = self.encoder( src = self.encoder(
src, src,
os_emb, pos_emb,
feature_mask=feature_mask, feature_mask=feature_mask,
attn_offset=attn_offset, attn_offset=attn_offset,
src_key_padding_mask=src_key_padding_mask,
memory=memory, memory=memory,
memory_key_padding_mask=memory_key_padding_mask, memory_key_padding_mask=memory_key_padding_mask,
) )
src = self.upsample(src) src = self.downsampler.upsample(src_orig, src, indexes)
# remove any extra frames that are not a multiple of downsample_factor
src = src[:src_orig.shape[0]]
return self.out_combiner(src_orig, src) return self.out_combiner(src_orig, src)
class CompactRelPositionalEncoding(torch.nn.Module): class CompactRelPositionalEncoding(torch.nn.Module):
""" """
Relative positional encoding module. This version is "compact" meaning it is able to encode Relative positional encoding module. This version is "compact" meaning it is able to encode
@ -1123,6 +1097,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
Args: Args:
embed_dim: Temporary embedding dimension used inside this module embed_dim: Temporary embedding dimension used inside this module
pos_dim: Smaller positional-encoding dimension used after a projection.
dropout_rate: Dropout rate. dropout_rate: Dropout rate.
max_len: Maximum input length: just a heuristic for initialization. max_len: Maximum input length: just a heuristic for initialization.
length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
@ -1130,11 +1105,12 @@ class CompactRelPositionalEncoding(torch.nn.Module):
pos_dim: dimension at the output of this module. pos_dim: dimension at the output of this module.
""" """
def __init__( def __init__(
self, embed_dim: int, self,
embed_dim: int,
pos_dim: int,
dropout_rate: FloatLike, dropout_rate: FloatLike,
max_len: int = 1000, max_len: int = 1000,
length_factor: float = 1.0, length_factor: float = 1.0,
pos_dim: int = 4,
) -> None: ) -> None:
"""Construct a CompactRelPositionalEncoding object.""" """Construct a CompactRelPositionalEncoding object."""
super(CompactRelPositionalEncoding, self).__init__() super(CompactRelPositionalEncoding, self).__init__()
@ -1211,16 +1187,13 @@ class CompactRelPositionalEncoding(torch.nn.Module):
x (torch.Tensor): Input tensor (time, batch, num_channels_in) x (torch.Tensor): Input tensor (time, batch, num_channels_in)
Returns: Returns:
positional embedding, of shape (1, 2*time-1, pos_dim). positional embedding, of shape (batch_size, time, time, pos_dim).
""" """
self.extend_pe(x) self.extend_pe(x)
seq_len = x.size(0) seq_len = x.size(0)
pos_emb = self.pe[ pos_emb = self.pe[
self.pe.size(0) // 2 self.pe.size(0) // 2 - seq_len + 1 : self.pe.size(0) // 2 + seq_len,
- seq_len,
+ 1 : self.pe.size(0) // 2 # noqa E203
+ seq_len,
: :
] ]
pos_emb = pos_emb.unsqueeze(0) pos_emb = pos_emb.unsqueeze(0)
@ -1230,12 +1203,20 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# currently pos_emb: (1, 2*seq_len-1, pos_dim) # currently pos_emb: (1, 2*seq_len-1, pos_dim)
pos_dim = pos_emb.shape[-1] pos_dim = pos_emb.shape[-1]
batch_size = x.size(1) batch_size = x.size(1)
(_, seq_stride, channel_stride) = pos_emb.stride()
# it doesn't really matter which one we make positive and which negative here, it # it doesn't really matter which one we make positive and which negative here, it
# would just flip the meaning of the embedding. # would just flip the meaning of the embedding.
# expand the '1' dimension to seq_len; this introduces a dimension that
# 'does nothing', just creates copies, as a workaround for lack of torch support
# for negative strides.
pos_emb = pos_emb.expand(seq_len, 2*seq_len-1, pos_dim).contiguous()
(useless_stride, seq_stride, channel_stride) = pos_emb.stride()
pos_emb = pos_emb.as_strided((batch_size, seq_len, seq_len, pos_dim), pos_emb = pos_emb.as_strided((batch_size, seq_len, seq_len, pos_dim),
(0, -seq_stride, seq_stride, channel_stride), (0, useless_stride-seq_stride, seq_stride, channel_stride),
storage_offset=seq_stride * (seqs_len - 1)) storage_offset=seq_stride * (seq_len - 1))
return pos_emb # (batch_size, seq_len, seq_len, pos_dim) return pos_emb # (batch_size, seq_len, seq_len, pos_dim)
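The strided view above avoids materialising a (seq_len, seq_len) relative-position lookup by hand; a hedged equivalence check against plain indexing on toy sizes (same construction as the expand/as_strided workaround, illustrative only):

    import torch

    seq_len, pos_dim = 5, 4
    rel = torch.randn(2 * seq_len - 1, pos_dim)     # rel[k] corresponds to relative offset k - (seq_len - 1)

    i = torch.arange(seq_len).unsqueeze(-1)         # target positions
    j = torch.arange(seq_len).unsqueeze(0)          # source positions
    ref = rel[j - i + seq_len - 1]                  # (seq_len, seq_len, pos_dim) by plain indexing

    expanded = rel.unsqueeze(0).expand(seq_len, -1, -1).contiguous()
    u, s, c = expanded.stride()
    view = expanded.as_strided((seq_len, seq_len, pos_dim), (u - s, s, c),
                               storage_offset=s * (seq_len - 1))
    assert torch.equal(view, ref)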
@ -1326,8 +1307,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
x: Tensor, x: Tensor,
pos_emb: Tensor, pos_emb: Tensor,
attn_offset: Optional[Tensor] = None, attn_offset: Optional[Tensor] = None,
pos_emb: Tensor,
quadratic_pos_weight: Tensor,
) -> Tensor: ) -> Tensor:
r""" r"""
Args: Args:
@ -1368,35 +1347,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) p = p.reshape(seq_len, batch_size, num_heads, pos_dim)
k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
# time1 refers to target, time2 refers to source. q = q.permute(2, 1, 0, 3) # (head, batch, tgt_seq_len, query_head_dim)
q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) k = k.permute(2, 1, 3, 0) # (head, batch, d_k, src_seq_len)
p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
# attn_scores: (num_heads, batch_size, tgt_seq_len, src_seq_len)
attn_scores = torch.matmul(q, k) attn_scores = torch.matmul(q, k)
if not self.training or random.random() >= float(self.pos_emb_skip_rate): if not self.training or random.random() >= float(self.pos_emb_skip_rate):
pos_emb = self.linear_pos(pos_emb) # pos_emb: (batch_size, tgt_seq_len, src_seq_len, pos_dim)
seq_len2 = 2 * seq_len - 1 p = p.permute(1, 0, 3, 2) # (batch_size, tgt_seq_len, pos_dim, num_heads)
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
# [where seq_len2 represents relative position.]
pos_scores = torch.matmul(p, pos_emb)
# the following .as_strided() expression converts the last axis of pos_scores from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
(pos_scores.stride(0),
pos_scores.stride(1),
pos_scores.stride(2)-pos_scores.stride(3),
pos_scores.stride(3)),
storage_offset=pos_scores.stride(3) * (seq_len - 1))
pos_scores = torch.matmul(pos_emb, p)
# pos_scores: (batch_size, tgt_seq_len, src_seq_len, num_heads)
pos_scores = pos_scores.permute(3, 0, 1, 2)
# pos_scores: (num_heads, batch_size, tgt_seq_len, src_seq_len)
attn_scores = attn_scores + pos_scores attn_scores = attn_scores + pos_scores
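As a sanity check (toy sizes, illustrative): the content and positional matmuls above reduce to the einsums one would expect, ending with shape (num_heads, batch_size, tgt_seq_len, src_seq_len):

    import torch

    heads, batch, seq, qd, pd = 2, 3, 5, 8, 4
    q = torch.randn(heads, batch, seq, qd)          # (head, batch, tgt, query_head_dim)
    k = torch.randn(heads, batch, qd, seq)          # (head, batch, query_head_dim, src)
    pos_emb = torch.randn(batch, seq, seq, pd)      # (batch, tgt, src, pos_dim)
    p = torch.randn(batch, seq, pd, heads)          # (batch, tgt, pos_dim, head)

    attn = torch.matmul(q, k)                               # (head, batch, tgt, src)
    pos = torch.matmul(pos_emb, p).permute(3, 0, 1, 2)      # (head, batch, tgt, src)
    assert torch.allclose(attn, torch.einsum('hbtq,hbqs->hbts', q, k), atol=1e-5)
    assert torch.allclose(pos, torch.einsum('btsp,btph->hbts', pos_emb, p), atol=1e-5)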
if self.training and random.random() < 0.1: if self.training and random.random() < 0.1:
@ -1417,23 +1384,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
penalty=1.0e-04, penalty=1.0e-04,
name=self.name) name=self.name)
# attn_offset includes key-padding mask and attention-mask, plus any weights
# from the subsampling.
attn_scores = attn_scores + attn_offset
assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
if attn_mask is not None:
assert attn_mask.dtype == torch.bool
# use -1000 to avoid nan's where attn_mask and key_padding_mask make
# all scores zero. It's important that this be large enough that exp(-1000)
# is exactly zero, for reasons related to const_attention_rate, it
# compares the final weights with zero.
attn_scores = attn_scores.masked_fill(attn_mask, -1000)
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
attn_scores = attn_scores.masked_fill(
key_padding_mask.unsqueeze(1),
-1000,
)
# We use our own version of softmax, defined in scaling.py, which should # We use our own version of softmax, defined in scaling.py, which should
# save a little of the memory used in backprop by, if we are in # save a little of the memory used in backprop by, if we are in
# automatic mixed precision mode (amp / autocast), by only storing the # automatic mixed precision mode (amp / autocast), by only storing the
@ -1617,9 +1573,9 @@ class MultiheadAttentionWeights(nn.Module):
q = q.reshape(query_len, batch_size, num_heads, head_dim) q = q.reshape(query_len, batch_size, num_heads, head_dim)
k = k.reshape(key_len, batch_size, num_heads, head_dim) k = k.reshape(key_len, batch_size, num_heads, head_dim)
# time1 refers to target, time2 refers to source. # tgt_seq_len refers to target, src_seq_len refers to source.
q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) q = q.permute(2, 1, 0, 3) # (head, batch, tgt_seq_len, query_head_dim)
k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) k = k.permute(2, 1, 3, 0) # (head, batch, d_k, src_seq_len)
attn_scores = torch.matmul(q, k) attn_scores = torch.matmul(q, k)
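The permutes above only move heads and time into matmul-friendly positions; a short hedged shape check with differing query/key lengths (as in source attention):

    import torch

    query_len, key_len, batch, heads, head_dim = 7, 9, 2, 4, 16
    q = torch.randn(query_len, batch, heads, head_dim).permute(2, 1, 0, 3)  # (head, batch, tgt_seq_len, head_dim)
    k = torch.randn(key_len, batch, heads, head_dim).permute(2, 1, 3, 0)    # (head, batch, head_dim, src_seq_len)
    assert torch.matmul(q, k).shape == (heads, batch, query_len, key_len)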
@ -1842,8 +1798,6 @@ def _test_zipformer_main(causal: bool = False):
c = Subformer2( c = Subformer2(
encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4), encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
causal=causal, causal=causal,
chunk_size=(4,) if causal else (-1,),
left_context_frames=(64,),
memory_dim=memory_dim, memory_dim=memory_dim,
) )
batch_size = 5 batch_size = 5

View File

@ -63,7 +63,7 @@ from lm_datamodule import LmDataset, LmDataloader
from zipformer import Zipformer2 from zipformer import Zipformer2
from scaling import ScheduledFloat from scaling import ScheduledFloat
from lhotse.utils import fix_random_seed from lhotse.utils import fix_random_seed
from chunk_decoder import ChunkDecoder from decoder import Decoder
from model import Zipformer2LM from model import Zipformer2LM
from optim import Eden, ScaledAdam from optim import Eden, ScaledAdam
from torch import Tensor from torch import Tensor
@ -176,13 +176,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list." help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
) )
parser.add_argument(
"--pos-dim",
type=int,
default="48",
help="Positional-encoding embedding dimension"
)
parser.add_argument( parser.add_argument(
"--encoder-unmasked-dim", "--encoder-unmasked-dim",
type=str, type=str,
@ -505,9 +498,9 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
def get_encoder_model(params: AttributeDict) -> nn.Module: def get_encoder_model(params: AttributeDict) -> nn.Module:
chunk_size = _to_int_tuple(params.downsampling_factor)[-1] #chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
encoder = Zipformer2( encoder = Zipformer2(
output_downsampling_factor=chunk_size, #output_downsampling_factor=chunk_size,
downsampling_factor=_to_int_tuple(params.downsampling_factor), downsampling_factor=_to_int_tuple(params.downsampling_factor),
num_encoder_layers=_to_int_tuple(params.num_encoder_layers), num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
encoder_dim=_to_int_tuple(params.encoder_dim), encoder_dim=_to_int_tuple(params.encoder_dim),
@ -515,10 +508,8 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
query_head_dim=_to_int_tuple(params.query_head_dim), query_head_dim=_to_int_tuple(params.query_head_dim),
pos_head_dim=_to_int_tuple(params.pos_head_dim), pos_head_dim=_to_int_tuple(params.pos_head_dim),
value_head_dim=_to_int_tuple(params.value_head_dim), value_head_dim=_to_int_tuple(params.value_head_dim),
pos_dim=params.pos_dim,
num_heads=_to_int_tuple(params.num_heads), num_heads=_to_int_tuple(params.num_heads),
feedforward_dim=_to_int_tuple(params.feedforward_dim), feedforward_dim=_to_int_tuple(params.feedforward_dim),
cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)), dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0, warmup_batches=4000.0,
causal=True, causal=True,
@ -529,13 +520,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
def get_decoder_model(params: AttributeDict) -> nn.Module: def get_decoder_model(params: AttributeDict) -> nn.Module:
chunk_size = _to_int_tuple(params.downsampling_factor)[-1] decoder = Decoder(
decoder = ChunkDecoder(
embed_dim=max(_to_int_tuple(params.encoder_dim)), embed_dim=max(_to_int_tuple(params.encoder_dim)),
chunk_size=chunk_size,
vocab_size=256, # bytes vocab_size=256, # bytes
hidden_size=params.decoder_hidden_size,
num_layers=params.decoder_num_layers,
) )
return decoder return decoder