Fix various bugs

Daniel Povey 2023-05-15 15:20:02 +08:00
parent f740282a1a
commit 1b8be0744f
3 changed files with 102 additions and 163 deletions


@ -19,7 +19,6 @@
import torch
from torch import nn, Tensor
from chunk_decoder import ChunkDecoder
from zipformer import Zipformer2
@ -28,7 +27,7 @@ class Zipformer2LM(nn.Module):
def __init__(self,
encoder_embed: nn.Module,
encoder: Zipformer2,
decoder: ChunkDecoder):
decoder: nn.Module):
super().__init__()
self.encoder_embed = encoder_embed
self.encoder = encoder # does subsampling
@ -47,18 +46,17 @@ class Zipformer2LM(nn.Module):
"""
(batch_size, seq_len) = labels.shape
chunk_size = self.decoder.chunk_size
chunk_size = 1
labels_shifted = labels.t() # (time, batch)
labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:chunk_size]),
labels_shifted[:-chunk_size]),
labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:1]),
labels_shifted[:-1]),
dim=0)
x = self.encoder_embed(labels_shifted)
x_lens = torch.full((batch_size,), seq_len,
dtype=torch.long, device=labels.device)
# x_lens is after subsampling. Actually we don't need it.
(x, x_lens) = self.encoder(x, x_lens)
logprobs = self.decoder(labels, x)
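The shift above is plain next-token conditioning: position t of the decoder input holds label t-1, with position 0 zero-padded. A minimal sketch with made-up byte labels of shape (batch_size, seq_len):

```python
import torch

# hypothetical labels, shape (batch_size, seq_len)
labels = torch.tensor([[5, 9, 2, 7],
                       [3, 1, 4, 8]])

labels_shifted = labels.t()  # (seq_len, batch_size)
# prepend one zero frame and drop the last frame: position t now holds label t-1
labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:1]),
                            labels_shifted[:-1]), dim=0)
print(labels_shifted.t())
# tensor([[0, 5, 9, 2],
#         [0, 3, 1, 4]])
```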


@ -76,11 +76,7 @@ class Subformer2(EncoderInterface):
dropout (float): dropout rate
warmup_batches (float): number of batches to warm up over; this controls
dropout of encoder layers.
causal (bool): if True, support chunkwise causal convolution. This should
not hurt WER as no modeling power is lost, but the convolution modules will be
slightly slower and use more memory. Enables use of the chunk_size and
left_context_chunks options in forward(), which simulates streaming
decoding.
causal (bool): if True, use causal attention-mask.
memory_dim: if supplied and >0, will be the dimension of the memory embeddings
passed into the zipformer (e.g. this might be the output of another
Subformer used to create embedding vectors.)
@ -97,7 +93,6 @@ class Subformer2(EncoderInterface):
num_heads: Union[int, Tuple[int]] = 8,
feedforward_dim: Union[int, Tuple[int]] = 1536,
memory_dim: int = -1,
pos_emb_dim: int = 192,
pos_dim: int = 4,
dropout: FloatLike = None, # see code below for default
warmup_batches: float = 4000.0,
@ -129,6 +124,7 @@ class Subformer2(EncoderInterface):
value_head_dim = _to_tuple(value_head_dim)
num_heads = _to_tuple(num_heads)
feedforward_dim = _to_tuple(feedforward_dim)
self.causal = causal
for u,d in zip(encoder_unmasked_dim, encoder_dim):
assert u <= d
@ -156,7 +152,6 @@ class Subformer2(EncoderInterface):
encoder = Subformer2Encoder(
encoder_layer,
num_encoder_layers[i],
pos_dim=pos_dim,
dropout=dropout,
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
@ -173,14 +168,15 @@ class Subformer2(EncoderInterface):
encoders.append(encoder)
self.encoder_pos = CompactRelPositionalEncoding(pos_emb_dim, pos_dim, dropout_rate=0.15,
self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim,
dropout_rate=0.15,
length_factor=1.0)
self.encoders = nn.ModuleList(encoders)
self.downsample_output = SimpleDownsample(max(encoder_dim),
downsample=output_downsampling_factor,
dropout=dropout)
#self.downsample_output = SimpleDownsample(max(encoder_dim),
# downsample=output_downsampling_factor,
# dropout=dropout)
def get_feature_masks(
self,
@ -273,7 +269,7 @@ class Subformer2(EncoderInterface):
outputs = []
feature_masks = self.get_feature_masks(x)
attn_offset = self._get_attn_offset(x)
attn_offset = self._get_attn_offset(x, src_key_padding_mask)
if self.training and memory is not None:
batch_size = x.shape[1]
@ -286,15 +282,11 @@ class Subformer2(EncoderInterface):
pos_emb = self.encoder_pos(x)
for i, module in enumerate(self.encoders):
ds = self.downsampling_factor[i]
x = convert_num_channels(x, self.encoder_dim[i])
x = module(x,
pos_emb,
chunk_size=chunk_size,
feature_mask=feature_masks[i],
src_key_padding_mask=(None if src_key_padding_mask is None
else src_key_padding_mask[...,::ds]),
attn_offset=attn_offset,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
@ -321,37 +313,37 @@ class Subformer2(EncoderInterface):
# from different pieces of 'outputs', taking each dimension from the
# most recent output that has it present.
x = get_full_dim_output()
x = self.downsample_output(x)
#x = self.downsample_output(x)
d = self.output_downsampling_factor
lengths = (x_lens + d - 1) // d
return x, lengths
def _get_attn_offset(self, x: Tensor) -> Optional[Tensor]:
def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
"""
Return attention offset of shape (1, seq_len, seq_len), interpreted as (tgt_seq_len,
Return attention offset of shape (1 or batch_size, seq_len, seq_len), interpreted as (1 or batch_size, tgt_seq_len,
src_seq_len); this reflects masking, if causal == True, otherwise will be all zeros.
Args:
x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
chunk_size: chunk size, must divide
src_key_padding_mask: optional key-padding mask of shape (batch_size, seq_len) with True in masked positions.
"""
if not self.causal:
return None
seq_len = x.shape[0]
# t is frame index, shape (seq_len,)
t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
src_c = c
tgt_c = c.unsqueeze(-1)
attn_mask = (src_c > tgt_c)
seq_len, batch_size, _num_channels = x.shape
ans = torch.zeros(1, seq_len, seq_len, device=x.device)
ans.masked_fill(attn_mask, float('-inf'))
if self.causal:
# t is frame index, shape (seq_len,)
t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
src_t = t
tgt_t = t.unsqueeze(-1)
attn_mask = (src_t > tgt_t)
ans.masked_fill(attn_mask, float('-inf'))
if src_key_padding_mask is not None:
ans = ans * src_key_padding_mask.unsqueeze(1).logical_not()
# now ans: (batch_size, seq_len, seq_len).
return ans
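A standalone sketch of the additive offset this helper builds: zeros where attention is allowed, -inf above the diagonal when causal, and (as one option, assumed here) -inf at padded source positions applied via masked_fill rather than the multiply used above:

```python
from typing import Optional
import torch

def causal_attn_offset(seq_len: int,
                       src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # additive offset: 0 where attention is allowed, -inf where it is not.
    t = torch.arange(seq_len)
    attn_mask = t > t.unsqueeze(-1)      # (tgt, src): True strictly above the diagonal
    ans = torch.zeros(1, seq_len, seq_len)
    ans = ans.masked_fill(attn_mask, float('-inf'))
    if src_key_padding_mask is not None:  # (batch_size, seq_len), True = padding
        batch_size = src_key_padding_mask.shape[0]
        ans = ans.expand(batch_size, -1, -1).masked_fill(
            src_key_padding_mask.unsqueeze(1), float('-inf'))
    return ans

print(causal_attn_offset(3))
# tensor([[[0., -inf, -inf],
#          [0., 0., -inf],
#          [0., 0., 0.]]])
```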
@ -384,11 +376,10 @@ class Subformer2EncoderLayer(nn.Module):
def __init__(
self,
embed_dim: int,
pos_dim: int,
num_heads: int,
query_head_dim: int,
pos_dim: int,
value_head_dim: int,
pos_dim: int,
feedforward_dim: int,
dropout: FloatLike = 0.1,
causal: bool = False,
@ -431,14 +422,15 @@ class Subformer2EncoderLayer(nn.Module):
self.self_attn1 = Attention(embed_dim, embed_dim, num_heads,
value_head_dim)
value_head_dim)
self.self_attn2 = Attention(embed_dim, embed_dim, num_heads,
value_head_dim)
if memory_dim > 0:
self.attn_weights = MultiheadAttentionWeights(
memory_dim, embed_dim,
memory_dim,
embed_dim,
num_heads=num_heads,
head_dim=query_head_dim,
dropout=0.0,
@ -559,7 +551,6 @@ class Subformer2EncoderLayer(nn.Module):
self,
src: Tensor,
pos_emb: Tensor,
chunk_size: int = -1,
attn_offset: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
@ -570,7 +561,6 @@ class Subformer2EncoderLayer(nn.Module):
Args:
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len),
@ -591,7 +581,6 @@ class Subformer2EncoderLayer(nn.Module):
src,
pos_emb=pos_emb,
attn_offset=attn_offset,
key_padding_mask=src_key_padding_mask,
)
if memory is not None and hasattr(self, 'attn_weights'):
@ -662,7 +651,6 @@ class Subformer2Encoder(nn.Module):
Args:
encoder_layer: an instance of the Subformer2EncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
pos_dim: the dimension for the relative positional encoding
Examples::
>>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
@ -674,7 +662,6 @@ class Subformer2Encoder(nn.Module):
self,
encoder_layer: nn.Module,
num_layers: int,
pos_dim: int,
dropout: float,
warmup_begin: float,
warmup_end: float,
@ -701,10 +688,9 @@ class Subformer2Encoder(nn.Module):
def forward(
self,
src: Tensor,
chunk_size: int = -1,
pos_emb: Tensor,
feature_mask: Union[Tensor, float] = 1.0,
attn_offset: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
@ -712,14 +698,13 @@ class Subformer2Encoder(nn.Module):
Args:
src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
pos_emb: positional embedding tensor, of shape (batch_size, seq_len, seq_len, pos_dim),
e.g. pos_dim=4.
feature_mask: something that broadcasts with src, that we'll multiply `src`
by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
attn_offset: the attention offset (does masking and related tasks), of shape
broadcasting with (batch_size, seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len).
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
memory_key_padding_mask: optionally the mask for padding of memory input (for source-
attention), of shape (batch_size, memory_len); True means
@ -727,7 +712,6 @@ class Subformer2Encoder(nn.Module):
Returns: a Tensor with the same shape as src.
"""
pos_emb = self.encoder_pos(src)
output = src
rnd_seed = src.numel() + random.randint(0, 1000)
@ -738,9 +722,7 @@ class Subformer2Encoder(nn.Module):
output = mod(
output,
pos_emb,
chunk_size=chunk_size,
attn_offset=attn_offset,
src_key_padding_mask=src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
@ -827,12 +809,13 @@ class LearnedDownsamplingModule(nn.Module):
embed_dim: int,
downsampling_factor: int,
intermediate_rate: FloatLike = 0.2):
super().__init__()
self.to_scores = nn.Linear(embed_dim, 1, bias=False)
# score_balancer is just to keep the magnitudes of the scores in
# a fixed range and keep them balanced around zero, to stop
# these drifting around.
self.score_balancer = Balancer(1, channel_dim=-1,
min_positive=0.4, max_positive=0.6
min_positive=0.4, max_positive=0.6,
min_abs=1.0, max_abs=1.2,
prob=0.025)
@ -856,14 +839,14 @@ class LearnedDownsamplingModule(nn.Module):
corresponding to the kept frames; these will be between 0 and 1, but
mostly exactly 1.
"""
(seq_len, batch_size, _)
(seq_len, batch_size, _) = x.shape
scores = self.to_scores(x) # (seq_len, batch_size, 1)
scores = self.score_balancer(scores)
scores = scores.squeeze(-1).t() # (batch_size, seq_len)
# indexes, sscores: (batch_size, seq_len)
indexes, sscores = scores.sort(dim=-1, descending=True)
# sscores, indexes: (batch_size, seq_len)
sscores, indexes = scores.sort(dim=-1, descending=True)
d = self.downsampling_factor
seq_len_reduced = (seq_len + d - 1) // d
@ -883,10 +866,10 @@ class LearnedDownsamplingModule(nn.Module):
collar = max(1, int(seq_len_reduced * 0.5 * self.intermediate_rate))
# right_avg: shape (batch_size,), this is to be mapped to 0.0
right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1)
right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)
# left_avg: shape (batch_size,), this is to be mapped to 1.0
left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1)
left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)
# the + 0.001 is to avoid possible division by zero in case of ties.
weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)
@ -901,11 +884,11 @@ class LearnedDownsamplingModule(nn.Module):
indexes, reorder = indexes.sort(dim=-1)
weights = torch.gather(weights, dim=-1, index=reorder)
x_downsampled = downsample(indexes, x)
x_downsampled = self.downsample(x, indexes)
return indexes, weights, x_downsampled
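The frame selection in forward() boils down to: score each frame, sort descending, keep the top ceil(seq_len / downsampling_factor), then re-sort the survivors into time order, carrying their weights along. A minimal sketch with made-up shapes; the weight mapping via left_avg/right_avg is omitted:

```python
import torch

# hypothetical sizes: keep the top-scoring ceil(seq_len / d) frames per sequence
batch_size, seq_len, d = 2, 8, 4
scores = torch.randn(batch_size, seq_len)

seq_len_reduced = (seq_len + d - 1) // d
sscores, indexes = scores.sort(dim=-1, descending=True)

# keep the best-scoring frames; in the real module the sorted scores are mapped
# to soft keep-weights in [0, 1], here we just take them as-is
weights = sscores[:, :seq_len_reduced]
indexes = indexes[:, :seq_len_reduced]

# restore time order so downstream layers see frames in their original sequence
indexes, reorder = indexes.sort(dim=-1)
weights = torch.gather(weights, dim=-1, index=reorder)
print(indexes.shape, weights.shape)  # (batch_size, seq_len_reduced) each
```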
def downsample(x: Tensor, indexes: Tensor) -> Tensor:
def downsample(self, x: Tensor, indexes: Tensor) -> Tensor:
"""
Downsamples x via indexing with the indexes obtained from the
forward() function.
@ -917,19 +900,19 @@ class LearnedDownsamplingModule(nn.Module):
Returns:
x_downsampled, of shape (seq_len_reduced, batch_size, num_channels)
"""
indexes = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
# indexes now: (seq_len_reduced, batch_size, num_channels)
ans = torch.gather(x, dim=0, index=indexes)
indexes_expanded = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
# indexes_expanded: (seq_len_reduced, batch_size, num_channels)
ans = torch.gather(x, dim=0, index=indexes_expanded)
if __name__ == __main__:
if __name__ == '__main__':
# temp, for testing
x_reconstructed = upsample(x, ans, indexes)
x_reconstructed = self.upsample(x, ans, indexes)
assert torch.allclose(x, x_reconstructed)
return ans
def downsample_pos_emb(pos_emb: Tensor, indexes: Tensor) -> Tensor:
def downsample_pos_emb(self, pos_emb: Tensor, indexes: Tensor) -> Tensor:
"""
Downsample positional embedding tensor with the provided indexes.
Args:
@ -958,7 +941,8 @@ class LearnedDownsamplingModule(nn.Module):
return pos_emb
def downsample_attn_offset(attn_offset: Tensor,
def downsample_attn_offset(self,
attn_offset: Tensor,
indexes: Tensor,
weights: Tensor,
eps: float = 1.0e-05) -> Tensor:
@ -979,18 +963,17 @@ class LearnedDownsamplingModule(nn.Module):
assert len(attn_offset.shape) == 3 # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
attn_offset = attn_offset.gather(dim=1, src=indices.unsqueeze(-1).expand(
attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
batch_size, seq_len_reduced, seq_len))
attn_offset = attn_offset.gather(dim=2, src=indices.unsqueeze(1).expand(
attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
batch_size, seq_len_reduced, seq_len_reduced))
# unsqueeze at position 1 so the extra cost relates to the source position.
attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1)
return attn_offst
return attn_offset
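downsample_attn_offset() applies the same frame selection to both axes of the attention offset, then soft-masks partially kept source frames by adding log(weight) to their columns. A small sketch with hypothetical shapes and weights:

```python
import torch

batch_size, seq_len, seq_len_reduced = 2, 6, 3
attn_offset = torch.zeros(batch_size, seq_len, seq_len)
indexes = torch.tensor([[0, 2, 5],
                        [1, 3, 4]])                       # kept frames, (batch, seq_len_reduced)
weights = torch.full((batch_size, seq_len_reduced), 0.9)  # keep-weights of the chosen frames

# restrict the offset to the kept target rows and kept source columns
attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
    batch_size, seq_len_reduced, seq_len))
attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
    batch_size, seq_len_reduced, seq_len_reduced))

# soft-mask partially kept source frames by adding log(weight) to their columns
attn_offset = attn_offset + weights.clamp(min=1.0e-05).log().unsqueeze(1)
print(attn_offset.shape)  # (batch_size, seq_len_reduced, seq_len_reduced)
```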
def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
def upsample(self, x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
"""
Upsamples, reversing the downsample() operation and filling in
any not-chosen frames with their original value before downsampling
@ -1013,14 +996,14 @@ class LearnedDownsamplingModule(nn.Module):
not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool,
device=x.device)
not_kept.scatter_(src=False, dim=1, index=indexes)
not_kept.scatter_(dim=1, index=indexes, value=False)
indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
# indexes now: (seq_len_reduced, batch_size, num_channels)
ans = torch.zeros_like(x_orig)
ans.scatter_(x, dim=0, index=indexes)
ans.scatter_(dim=0, index=indexes, src=x)
# add in x_orig in the frames that were not originally kept.
return ans + x_orig * not_kept.t().unsqueeze(-1)
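downsample() and upsample() are a gather/scatter pair along the time axis: gather picks the kept frames, and scatter writes them back into a full-length tensor, with non-kept frames taking their original values. A self-contained sketch assuming (seq_len, batch, channels) layout and unit weights:

```python
import torch

seq_len, batch_size, num_channels = 6, 2, 3
x = torch.randn(seq_len, batch_size, num_channels)
# hypothetical kept frames per sequence, already time-ordered: (batch, seq_len_reduced)
indexes = torch.tensor([[0, 2, 5],
                        [1, 3, 4]])

idx = indexes.t().unsqueeze(-1).expand(-1, -1, num_channels)  # (seq_len_reduced, batch, channels)
x_down = torch.gather(x, dim=0, index=idx)                    # downsample: pick the kept frames

x_up = torch.zeros_like(x)
x_up.scatter_(dim=0, index=idx, src=x_down)                   # upsample: write them back
kept = torch.zeros(batch_size, seq_len, dtype=torch.bool)
kept.scatter_(dim=1, index=indexes, value=True)
x_up = x_up + x * kept.t().unsqueeze(-1).logical_not()        # fill non-kept frames from x

assert torch.allclose(x, x_up)  # with weights == 1 the round trip is exact
```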
@ -1051,7 +1034,6 @@ class DownsampledSubformer2Encoder(nn.Module):
pos_emb: Tensor,
feature_mask: Union[Tensor, float] = 1.0,
attn_offset: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
@ -1065,8 +1047,6 @@ class DownsampledSubformer2Encoder(nn.Module):
attn_offset: the attention offset, added to scores for attention of shape
(batch_size, seq_len, seq_len) or (seq_len, seq_len),
interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
masked position. May be None.
memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
memory_key_padding_mask: optionally the mask for padding of memory input (for source-
attention), of shape (batch_size, memory_len); True means
@ -1079,30 +1059,24 @@ class DownsampledSubformer2Encoder(nn.Module):
pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)
attn_offset = self.downsample.downsample_attn_offset(attn_offset,
indexes,
weights.clamp(min=1.0e-05))
attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
indexes,
weights)
src = self.encoder(
src,
os_emb,
pos_emb,
feature_mask=feature_mask,
attn_offset=attn_offset,
src_key_padding_mask=src_key_padding_mask,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
src = self.upsample(src)
# remove any extra frames that are not a multiple of downsample_factor
src = src[:src_orig.shape[0]]
src = self.downsampler.upsample(src_orig, src, indexes)
return self.out_combiner(src_orig, src)
class CompactRelPositionalEncoding(torch.nn.Module):
"""
Relative positional encoding module. This version is "compact" meaning it is able to encode
@ -1123,6 +1097,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
Args:
embed_dim: Temporary embedding dimension used inside this module
pos_dim: Smaller positional-encoding dim used after a projection.
dropout_rate: Dropout rate.
max_len: Maximum input length: just a heuristic for initialization.
length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
@ -1130,11 +1105,12 @@ class CompactRelPositionalEncoding(torch.nn.Module):
pos_dim: dimension at the output of this module.
"""
def __init__(
self, embed_dim: int,
self,
embed_dim: int,
pos_dim: int,
dropout_rate: FloatLike,
max_len: int = 1000,
length_factor: float = 1.0,
pos_dim: int = 4,
) -> None:
"""Construct a CompactRelPositionalEncoding object."""
super(CompactRelPositionalEncoding, self).__init__()
@ -1211,16 +1187,13 @@ class CompactRelPositionalEncoding(torch.nn.Module):
x (torch.Tensor): Input tensor (time, batch, num_channels_in)
Returns:
positional embedding, of shape (1, 2*time-1, pos_dim).
positional embedding, of shape (batch_size, 2*time-1, pos_dim).
"""
self.extend_pe(x)
seq_len = x.size(0)
pos_emb = self.pe[
self.pe.size(0) // 2
- seq_len,
+ 1 : self.pe.size(0) // 2 # noqa E203
+ seq_len,
self.pe.size(0) // 2 - seq_len + 1 : self.pe.size(0) // 2 + seq_len,
:
]
pos_emb = pos_emb.unsqueeze(0)
@ -1230,12 +1203,20 @@ class CompactRelPositionalEncoding(torch.nn.Module):
# currently pos_emb: (1, 2*seq_len-1, pos_dim)
pos_dim = pos_emb.shape[-1]
batch_size = x.size(1)
(_, seq_stride, channel_stride) = pos_emb.stride()
# it doesn't really matter which one we make positive and which negative here, it
# would just flip the meaning of the embedding.
# expand the '1' dimension to seq_len; this introduces a dimension that
# 'does nothing', just creates copies, as a workaround for lack of torch support
# for negative strides.
pos_emb = pos_emb.expand(seq_len, 2*seq_len-1, pos_dim).contiguous()
(useless_stride, seq_stride, channel_stride) = pos_emb.stride()
pos_emb = pos_emb.as_strided((batch_size, seq_len, seq_len, pos_dim),
(0, -seq_stride, seq_stride, channel_stride),
storage_offset=seq_stride * (seqs_len - 1))
(0, useless_stride-seq_stride, seq_stride, channel_stride),
storage_offset=seq_stride * (seq_len - 1))
return pos_emb # (batch_size, seq_len, seq_len, pos_dim)
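The as_strided() call turns the (2*seq_len-1)-entry bank of relative embeddings into a (seq_len, seq_len, pos_dim) grid whose [i, j] entry is the embedding of offset j - i (up to the sign convention noted above). A standalone sketch, dropping the batch dimension, that checks the stride trick against a direct indexing construction:

```python
import torch

seq_len, pos_dim = 4, 3
# relative embeddings indexed by (j - i) + (seq_len - 1), i.e. 2*seq_len-1 rows
rel = torch.randn(2 * seq_len - 1, pos_dim)

# direct (readable) construction: entry [i, j] is the embedding of relative offset j - i
i = torch.arange(seq_len).unsqueeze(-1)   # target index
j = torch.arange(seq_len).unsqueeze(0)    # source index
direct = rel[j - i + seq_len - 1]         # (seq_len, seq_len, pos_dim)

# the stride trick: expand a dummy leading dim, then re-stride so that moving one
# step along the target axis shifts the window of relative positions by -1
expanded = rel.unsqueeze(0).expand(seq_len, -1, -1).contiguous()
u, s, c = expanded.stride()
strided = expanded.as_strided((seq_len, seq_len, pos_dim),
                              (u - s, s, c),
                              storage_offset=s * (seq_len - 1))
assert torch.equal(direct, strided)
```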
@ -1326,8 +1307,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
x: Tensor,
pos_emb: Tensor,
attn_offset: Optional[Tensor] = None,
pos_emb: Tensor,
quadratic_pos_weight: Tensor,
) -> Tensor:
r"""
Args:
@ -1368,35 +1347,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
p = p.reshape(seq_len, batch_size, num_heads, pos_dim)
k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)
# time1 refers to target, time2 refers to source.
q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim)
k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
q = q.permute(2, 1, 0, 3) # (head, batch, tgt_seq_len, query_head_dim)
k = k.permute(2, 1, 3, 0) # (head, batch, d_k, src_seq_len)
# attn_scores: (num_heads, batch_size, tgt_seq_len, src_seq_len)
attn_scores = torch.matmul(q, k)
if not self.training or random.random() >= float(self.pos_emb_skip_rate):
pos_emb = self.linear_pos(pos_emb)
seq_len2 = 2 * seq_len - 1
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
# [where seq_len2 represents relative position.]
pos_scores = torch.matmul(p, pos_emb)
# the following .as_strided() expression converts the last axis of pos_scores from relative
# to absolute position. I don't know whether I might have got the time-offsets backwards or
# not, but let this code define which way round it is supposed to be.
pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
(pos_scores.stride(0),
pos_scores.stride(1),
pos_scores.stride(2)-pos_scores.stride(3),
pos_scores.stride(3)),
storage_offset=pos_scores.stride(3) * (seq_len - 1))
# pos_emb: (batch_size, tgt_seq_len, src_seq_len, pos_dim)
p = p.permute(1, 0, 3, 2) # (batch_size, tgt_seq_len, pos_dim, num_heads)
pos_scores = torch.matmul(pos_emb, p)
# pos_scores: (batch_size, tgt_seq_len, src_seq_len, num_heads)
pos_scores = pos_scores.permute(3, 0, 1, 2)
# pos_scores: (num_heads, batch_size, tgt_seq_len, src_seq_len)
attn_scores = attn_scores + pos_scores
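The replacement positional-score path is a single batched matmul over the last two dims, followed by a permute to match attn_scores. A shape-only sketch with hypothetical sizes:

```python
import torch

num_heads, batch_size, seq_len, pos_dim = 4, 2, 5, 3
p = torch.randn(batch_size, seq_len, pos_dim, num_heads)      # per-head positional queries
pos_emb = torch.randn(batch_size, seq_len, seq_len, pos_dim)  # (batch, tgt, src, pos_dim)

# matmul broadcasts over (batch, tgt): (..., src, pos_dim) x (..., pos_dim, heads)
pos_scores = torch.matmul(pos_emb, p)                          # (batch, tgt, src, heads)
pos_scores = pos_scores.permute(3, 0, 1, 2)                    # (heads, batch, tgt, src)
print(pos_scores.shape)  # torch.Size([4, 2, 5, 5])
```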
if self.training and random.random() < 0.1:
@ -1417,23 +1384,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
penalty=1.0e-04,
name=self.name)
# attn_offset includes key-padding mask and attention-mask, plus any weights
# from the subsampling.
attn_scores = attn_scores + attn_offset
assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)
if attn_mask is not None:
assert attn_mask.dtype == torch.bool
# use -1000 to avoid nan's where attn_mask and key_padding_mask make
# all scores zero. It's important that this be large enough that exp(-1000)
# is exactly zero, for reasons related to const_attention_rate, it
# compares the final weights with zero.
attn_scores = attn_scores.masked_fill(attn_mask, -1000)
if key_padding_mask is not None:
assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
attn_scores = attn_scores.masked_fill(
key_padding_mask.unsqueeze(1),
-1000,
)
# We use our own version of softmax, defined in scaling.py, which should
# save a little of the memory used in backprop by, if we are in
# automatic mixed precision mode (amp / autocast), by only storing the
@ -1617,9 +1573,9 @@ class MultiheadAttentionWeights(nn.Module):
q = q.reshape(query_len, batch_size, num_heads, head_dim)
k = k.reshape(key_len, batch_size, num_heads, head_dim)
# time1 refers to target, time2 refers to source.
q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
# tgt_seq_len refers to target, src_seq_len refers to source.
q = q.permute(2, 1, 0, 3) # (head, batch, tgt_seq_len, query_head_dim)
k = k.permute(2, 1, 3, 0) # (head, batch, d_k, src_seq_len)
attn_scores = torch.matmul(q, k)
@ -1842,8 +1798,6 @@ def _test_zipformer_main(causal: bool = False):
c = Subformer2(
encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
causal=causal,
chunk_size=(4,) if causal else (-1,),
left_context_frames=(64,),
memory_dim=memory_dim,
)
batch_size = 5


@ -63,7 +63,7 @@ from lm_datamodule import LmDataset, LmDataloader
from zipformer import Zipformer2
from scaling import ScheduledFloat
from lhotse.utils import fix_random_seed
from chunk_decoder import ChunkDecoder
from decoder import Decoder
from model import Zipformer2LM
from optim import Eden, ScaledAdam
from torch import Tensor
@ -176,13 +176,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
)
parser.add_argument(
"--pos-dim",
type=int,
default="48",
help="Positional-encoding embedding dimension"
)
parser.add_argument(
"--encoder-unmasked-dim",
type=str,
@ -505,9 +498,9 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
def get_encoder_model(params: AttributeDict) -> nn.Module:
chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
#chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
encoder = Zipformer2(
output_downsampling_factor=chunk_size,
#output_downsampling_factor=chunk_size,
downsampling_factor=_to_int_tuple(params.downsampling_factor),
num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
encoder_dim=_to_int_tuple(params.encoder_dim),
@ -515,10 +508,8 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
query_head_dim=_to_int_tuple(params.query_head_dim),
pos_head_dim=_to_int_tuple(params.pos_head_dim),
value_head_dim=_to_int_tuple(params.value_head_dim),
pos_dim=params.pos_dim,
num_heads=_to_int_tuple(params.num_heads),
feedforward_dim=_to_int_tuple(params.feedforward_dim),
cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
warmup_batches=4000.0,
causal=True,
@ -529,13 +520,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
def get_decoder_model(params: AttributeDict) -> nn.Module:
chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
decoder = ChunkDecoder(
decoder = Decoder(
embed_dim=max(_to_int_tuple(params.encoder_dim)),
chunk_size=chunk_size,
vocab_size=256, # bytes
hidden_size=params.decoder_hidden_size,
num_layers=params.decoder_num_layers,
)
return decoder