From 1b8be0744fa4ba2759690ceb95db59beb9260a78 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Mon, 15 May 2023 15:20:02 +0800 Subject: [PATCH] Fix various bugs --- egs/libriheavy/LM/zipformer1/model.py | 12 +- egs/libriheavy/LM/zipformer1/subformer.py | 232 +++++++++------------- egs/libriheavy/LM/zipformer1/train.py | 21 +- 3 files changed, 102 insertions(+), 163 deletions(-) diff --git a/egs/libriheavy/LM/zipformer1/model.py b/egs/libriheavy/LM/zipformer1/model.py index 43fc715dd..b33492868 100644 --- a/egs/libriheavy/LM/zipformer1/model.py +++ b/egs/libriheavy/LM/zipformer1/model.py @@ -19,7 +19,6 @@ import torch from torch import nn, Tensor -from chunk_decoder import ChunkDecoder from zipformer import Zipformer2 @@ -28,7 +27,7 @@ class Zipformer2LM(nn.Module): def __init__(self, encoder_embed: nn.Module, encoder: Zipformer2, - decoder: ChunkDecoder): + decoder: nn.Module): super().__init__() self.encoder_embed = encoder_embed self.encoder = encoder # does subsampling @@ -47,18 +46,17 @@ class Zipformer2LM(nn.Module): """ (batch_size, seq_len) = labels.shape - chunk_size = self.decoder.chunk_size + chunk_size = 1 labels_shifted = labels.t() # (time, batch) - labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:chunk_size]), - labels_shifted[:-chunk_size]), + labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:1]), + labels_shifted[:-1]), dim=0) x = self.encoder_embed(labels_shifted) x_lens = torch.full((batch_size,), seq_len, dtype=torch.long, device=labels.device) + # x_lens is after subsampling. Actually we don't need it. - - (x, x_lens) = self.encoder(x, x_lens) logprobs = self.decoder(labels, x) diff --git a/egs/libriheavy/LM/zipformer1/subformer.py b/egs/libriheavy/LM/zipformer1/subformer.py index c29f71d1a..f42a89a88 100644 --- a/egs/libriheavy/LM/zipformer1/subformer.py +++ b/egs/libriheavy/LM/zipformer1/subformer.py @@ -76,11 +76,7 @@ class Subformer2(EncoderInterface): dropout (float): dropout rate warmup_batches (float): number of batches to warm up over; this controls dropout of encoder layers. - causal (bool): if True, support chunkwise causal convolution. This should - not hurt WER as no modeling power is lost, but the convolution modules will be - slightly slower and use more memory. Enables use of the chunk_size and - left_context_chunks options in forward(), which simulates streaming - decoding. + causal (bool): if True, use causal attention-mask. memory_dim: if supplied and >0, will be the dimension of the memory embeddings passed into the zipformer (e.g. this might be the output of another Subformer used to create embedding vectors.) 
@@ -97,7 +93,6 @@ class Subformer2(EncoderInterface): num_heads: Union[int, Tuple[int]] = 8, feedforward_dim: Union[int, Tuple[int]] = 1536, memory_dim: int = -1, - pos_emb_dim: int = 192, pos_dim: int = 4, dropout: FloatLike = None, # see code below for default warmup_batches: float = 4000.0, @@ -129,6 +124,7 @@ class Subformer2(EncoderInterface): value_head_dim = _to_tuple(value_head_dim) num_heads = _to_tuple(num_heads) feedforward_dim = _to_tuple(feedforward_dim) + self.causal = causal for u,d in zip(encoder_unmasked_dim, encoder_dim): assert u <= d @@ -156,7 +152,6 @@ class Subformer2(EncoderInterface): encoder = Subformer2Encoder( encoder_layer, num_encoder_layers[i], - pos_dim=pos_dim, dropout=dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), @@ -173,14 +168,15 @@ class Subformer2(EncoderInterface): encoders.append(encoder) - self.encoder_pos = CompactRelPositionalEncoding(pos_emb_dim, pos_dim, dropout_rate=0.15, + self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim, + dropout_rate=0.15, length_factor=1.0) self.encoders = nn.ModuleList(encoders) - self.downsample_output = SimpleDownsample(max(encoder_dim), - downsample=output_downsampling_factor, - dropout=dropout) + #self.downsample_output = SimpleDownsample(max(encoder_dim), + # downsample=output_downsampling_factor, + # dropout=dropout) def get_feature_masks( self, @@ -273,7 +269,7 @@ class Subformer2(EncoderInterface): outputs = [] feature_masks = self.get_feature_masks(x) - attn_offset = self._get_attn_offset(x) + attn_offset = self._get_attn_offset(x, src_key_padding_mask) if self.training and memory is not None: batch_size = x.shape[1] @@ -286,15 +282,11 @@ class Subformer2(EncoderInterface): pos_emb = self.encoder_pos(x) for i, module in enumerate(self.encoders): - ds = self.downsampling_factor[i] x = convert_num_channels(x, self.encoder_dim[i]) x = module(x, pos_emb, - chunk_size=chunk_size, feature_mask=feature_masks[i], - src_key_padding_mask=(None if src_key_padding_mask is None - else src_key_padding_mask[...,::ds]), attn_offset=attn_offset, memory=memory, memory_key_padding_mask=memory_key_padding_mask, @@ -321,37 +313,37 @@ class Subformer2(EncoderInterface): # from different pieces of 'outputs', taking each dimension from the # most recent output that has it present. x = get_full_dim_output() - x = self.downsample_output(x) + #x = self.downsample_output(x) d = self.output_downsampling_factor lengths = (x_lens + d - 1) // d return x, lengths - def _get_attn_offset(self, x: Tensor) -> Optional[Tensor]: + def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]: """ - Return attention offset of shape (1, seq_len, seq_len), interpreted as (tgt_seq_len, + Return attention offset of shape (1 or batch_size, seq_len, seq_len), interpreted as (1 or batch_size, tgt_seq_len, src_seq_len); this reflects masking, if causal == True, otherwise will be all zeros. Args: x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim). - chunk_size: chunk size, must divide + src_key_padding_mask: optional key-padding mask of shape (batch_size, seq_len) with True in masked positions. 
""" - if not self.causal: - return None - - seq_len = x.shape[0] - - # t is frame index, shape (seq_len,) - t = torch.arange(seq_len, dtype=torch.int32, device=x.device) - src_c = c - tgt_c = c.unsqueeze(-1) - - attn_mask = (src_c > tgt_c) + seq_len, batch_size, _num_channels = x.shape ans = torch.zeros(1, seq_len, seq_len, device=x.device) - ans.masked_fill(attn_mask, float('-inf')) + if self.causal: + # t is frame index, shape (seq_len,) + t = torch.arange(seq_len, dtype=torch.int32, device=x.device) + src_t = t + tgt_t = t.unsqueeze(-1) + attn_mask = (src_t > tgt_t) + ans.masked_fill(attn_mask, float('-inf')) + + if src_key_padding_mask is not None: + ans = ans * src_key_padding_mask.unsqueeze(1).logical_not() + # now ans: (batch_size, seq_len, seq_len). return ans @@ -384,11 +376,10 @@ class Subformer2EncoderLayer(nn.Module): def __init__( self, embed_dim: int, - pos_dim: int, num_heads: int, query_head_dim: int, - pos_dim: int, value_head_dim: int, + pos_dim: int, feedforward_dim: int, dropout: FloatLike = 0.1, causal: bool = False, @@ -431,14 +422,15 @@ class Subformer2EncoderLayer(nn.Module): self.self_attn1 = Attention(embed_dim, embed_dim, num_heads, - value_head_dim) + value_head_dim) self.self_attn2 = Attention(embed_dim, embed_dim, num_heads, value_head_dim) if memory_dim > 0: self.attn_weights = MultiheadAttentionWeights( - memory_dim, embed_dim, + memory_dim, + embed_dim, num_heads=num_heads, head_dim=query_head_dim, dropout=0.0, @@ -559,7 +551,6 @@ class Subformer2EncoderLayer(nn.Module): self, src: Tensor, pos_emb: Tensor, - chunk_size: int = -1, attn_offset: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, memory: Optional[Tensor] = None, @@ -570,7 +561,6 @@ class Subformer2EncoderLayer(nn.Module): Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim) - chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. feature_mask: something that broadcasts with src, that we'll multiply `src` by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len), @@ -591,7 +581,6 @@ class Subformer2EncoderLayer(nn.Module): src, pos_emb=pos_emb, attn_offset=attn_offset, - key_padding_mask=src_key_padding_mask, ) if memory is not None and hasattr(self, 'attn_weights'): @@ -662,7 +651,6 @@ class Subformer2Encoder(nn.Module): Args: encoder_layer: an instance of the Subformer2EncoderLayer() class (required). num_layers: the number of sub-encoder-layers in the encoder (required). - pos_dim: the dimension for the relative positional encoding Examples:: >>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8) @@ -674,7 +662,6 @@ class Subformer2Encoder(nn.Module): self, encoder_layer: nn.Module, num_layers: int, - pos_dim: int, dropout: float, warmup_begin: float, warmup_end: float, @@ -701,10 +688,9 @@ class Subformer2Encoder(nn.Module): def forward( self, src: Tensor, - chunk_size: int = -1, + pos_emb: Tensor, feature_mask: Union[Tensor, float] = 1.0, attn_offset: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, memory: Optional[Tensor] = None, memory_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: @@ -712,14 +698,13 @@ class Subformer2Encoder(nn.Module): Args: src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim). 
- chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking. + pos_emb: positional embedding tensor, of shape (batch_size, seq_len, seq_len, pos_dim), + e.g. pos_dim=4. feature_mask: something that broadcasts with src, that we'll multiply `src` by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim) attn_offset: the attention offset (does masking and related tasks), of shape broadcasting with (batch_size, seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len). - src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means - masked position. May be None. memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) memory_key_padding_mask: optionally the mask for padding of memory input (for source- attention), of shape (batch_size, memory_len); True means @@ -727,7 +712,6 @@ class Subformer2Encoder(nn.Module): Returns: a Tensor with the same shape as src. """ - pos_emb = self.encoder_pos(src) output = src rnd_seed = src.numel() + random.randint(0, 1000) @@ -738,9 +722,7 @@ class Subformer2Encoder(nn.Module): output = mod( output, pos_emb, - chunk_size=chunk_size, attn_offset=attn_offset, - src_key_padding_mask=src_key_padding_mask, memory=memory, memory_key_padding_mask=memory_key_padding_mask, ) @@ -827,12 +809,13 @@ class LearnedDownsamplingModule(nn.Module): embed_dim: int, downsampling_factor: int, intermediate_rate: FloatLike = 0.2): + super().__init__() self.to_scores = nn.Linear(embed_dim, 1, bias=False) # score_balancer is just to keep the magnitudes of the scores in # a fixed range and keep them balanced around zero, to stop # these drifting around. self.score_balancer = Balancer(1, channel_dim=-1, - min_positive=0.4, max_positive=0.6 + min_positive=0.4, max_positive=0.6, min_abs=1.0, max_abs=1.2, prob=0.025) @@ -856,14 +839,14 @@ class LearnedDownsamplingModule(nn.Module): corresponding to the kept frames; these will be between 0 and 1, but mostly exactly 1. """ - (seq_len, batch_size, _) + (seq_len, batch_size, _) = x.shape scores = self.to_scores(x) # (seq_len, batch_size, 1) scores = self.score_balancer(scores) scores = scores.squeeze(-1).t() # (batch_size, seq_len) - # indexes, sscores: (batch_size, seq_len) - indexes, sscores = scores.sort(dim=-1, descending=True) + # sscores, indexes: (batch_size, seq_len) + sscores, indexes = scores.sort(dim=-1, descending=True) d = self.downsampling_factor seq_len_reduced = (seq_len + d - 1) // d @@ -883,10 +866,10 @@ class LearnedDownsamplingModule(nn.Module): collar = max(1, int(seq_len_reduced * 0.5 * self.intermediate_rate)) # right_avg: shape (batch_size,), this is to be mapped to 0.0 - right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1) + right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True) # left_avg: shape (batch_size,), this is to be mapped to 1.0 - left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1) + left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True) # the + 0.001 is to avoid possible division by zero in case of ties. 
         weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)

@@ -901,11 +884,11 @@ class LearnedDownsamplingModule(nn.Module):
         indexes, reorder = indexes.sort(dim=-1)
         weights = torch.gather(weights, dim=-1, index=reorder)

-        x_downsampled = downsample(indexes, x)
+        x_downsampled = self.downsample(x, indexes)

         return indexes, weights, x_downsampled


-    def downsample(x: Tensor, indexes: Tensor) -> Tensor:
+    def downsample(self, x: Tensor, indexes: Tensor) -> Tensor:
         """
         Downsamples x via indexing with the indexes obtained from the forward()
         function.
@@ -917,19 +900,19 @@ class LearnedDownsamplingModule(nn.Module):
         Returns:
            x_downsampled, of shape (seq_len_reduced, batch_size, num_channels)
        """
-        indexes = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
-        # indexes now: (seq_len_reduced, batch_size, num_channels)
-        ans = torch.gather(x, dim=0, index=indexes)
+        indexes_expanded = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
+        # indexes_expanded: (seq_len_reduced, batch_size, num_channels)
+        ans = torch.gather(x, dim=0, index=indexes_expanded)

-        if __name__ == __main__:
+        if __name__ == '__main__':
             # temp, for testing
-            x_reconstructed = upsample(x, ans, indexes)
+            x_reconstructed = self.upsample(x, ans, indexes)
             assert torch.allclose(x, x_reconstructed)

         return ans

-    def downsample_pos_emb(pos_emb: Tensor, indexes: Tensor) -> Tensor:
+    def downsample_pos_emb(self, pos_emb: Tensor, indexes: Tensor) -> Tensor:
         """
         Downsample positional embedding tensor with the provided indexes.
         Args:
@@ -958,7 +941,8 @@ class LearnedDownsamplingModule(nn.Module):

         return pos_emb

-    def downsample_attn_offset(attn_offset: Tensor,
+    def downsample_attn_offset(self,
+                               attn_offset: Tensor,
                                indexes: Tensor,
                                weights: Tensor,
                                eps: float = 1.0e-05) -> Tensor:
@@ -979,18 +963,17 @@ class LearnedDownsamplingModule(nn.Module):
         assert len(attn_offset.shape) == 3  # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
         attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)
-
-        attn_offset = attn_offset.gather(dim=1, src=indices.unsqueeze(-1).expand(
+        attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
             batch_size, seq_len_reduced, seq_len))
-        attn_offset = attn_offset.gather(dim=2, src=indices.unsqueeze(1).expand(
+        attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
             batch_size, seq_len_reduced, seq_len_reduced))

         # unsqueeze at position 1 so the extra cost relates to the source position.
         attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1)

-        return attn_offst
+        return attn_offset


-    def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
+    def upsample(self, x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
         """
         Upsamples, reversing the downsample() operation and filling in any not-chosen
         frames with their original value before downsampling
@@ -1013,14 +996,14 @@ class LearnedDownsamplingModule(nn.Module):
         not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool,
                               device=x.device)
-        not_kept.scatter_(src=False, dim=1, index=indexes)
+        not_kept.scatter_(dim=1, index=indexes, value=False)

         indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
         # indexes now: (seq_len_reduced, batch_size, num_channels)

         ans = torch.zeros_like(x_orig)

-        ans.scatter_(x, dim=0, index=indexes)
+        ans.scatter_(dim=0, index=indexes, src=x)

         # add in x_orig in the frames that were not originally kept.
         return ans + x_orig * not_kept.t().unsqueeze(-1)

@@ -1051,7 +1034,6 @@ class DownsampledSubformer2Encoder(nn.Module):
         pos_emb: Tensor,
         feature_mask: Union[Tensor, float] = 1.0,
         attn_offset: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
         memory: Optional[Tensor] = None,
         memory_key_padding_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
@@ -1065,8 +1047,6 @@ class DownsampledSubformer2Encoder(nn.Module):
            attn_offset: the attention offset, added to scores for attention of shape (batch_size, seq_len, seq_len)
               or (seq_len, seq_len), interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-           src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
-              masked position.  May be None.
            memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
            memory_key_padding_mask: optionally the mask for padding of memory input (for source-
               attention), of shape  (batch_size, memory_len); True means
@@ -1079,30 +1059,24 @@ class DownsampledSubformer2Encoder(nn.Module):

         pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)

-        attn_offset = self.downsample.downsample_attn_offset(attn_offset,
-                                                             indexes,
-                                                             weights.clamp(min=1.0e-05))
-
+        attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
+                                                              indexes,
+                                                              weights)
         src = self.encoder(
             src,
-            os_emb,
+            pos_emb,
             feature_mask=feature_mask,
             attn_offset=attn_offset,
-            src_key_padding_mask=src_key_padding_mask,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
         )

-        src = self.upsample(src)
-        # remove any extra frames that are not a multiple of downsample_factor
-        src = src[:src_orig.shape[0]]
+        src = self.downsampler.upsample(src_orig, src, indexes)

         return self.out_combiner(src_orig, src)

-
-
 class CompactRelPositionalEncoding(torch.nn.Module):
     """
     Relative positional encoding module.  This version is "compact" meaning it is able to encode
@@ -1123,6 +1097,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):
     Args:
         embed_dim: Temporary embedding dimension used inside this module
+        pos_dim: Smaller positional-encoding dimension used after a projection.
         dropout_rate: Dropout rate.
         max_len: Maximum input length: just a heuristic for initialization.
        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
@@ -1130,11 +1105,12 @@ class CompactRelPositionalEncoding(torch.nn.Module):
        pos_dim: dimension at the output of this module.
    """
    def __init__(
-        self, embed_dim: int,
+        self,
+        embed_dim: int,
+        pos_dim: int,
         dropout_rate: FloatLike,
         max_len: int = 1000,
         length_factor: float = 1.0,
-        pos_dim: int = 4,
    ) -> None:
        """Construct a CompactRelPositionalEncoding object."""
        super(CompactRelPositionalEncoding, self).__init__()
@@ -1211,16 +1187,13 @@ class CompactRelPositionalEncoding(torch.nn.Module):
            x (torch.Tensor): Input tensor (time, batch, num_channels_in)

        Returns:
-            positional embedding, of shape (1, 2*time-1, pos_dim).
+            positional embedding, of shape (batch_size, time, time, pos_dim).
""" self.extend_pe(x) seq_len = x.size(0) pos_emb = self.pe[ - self.pe.size(0) // 2 - - seq_len, - + 1 : self.pe.size(0) // 2 # noqa E203 - + seq_len, + self.pe.size(0) // 2 - seq_len + 1 : self.pe.size(0) // 2 + seq_len, : ] pos_emb = pos_emb.unsqueeze(0) @@ -1230,12 +1203,20 @@ class CompactRelPositionalEncoding(torch.nn.Module): # currenly pos_emb: (1, 2*seq_len-1, pos_dim) pos_dim = pos_emb.shape[-1] batch_size = x.size(1) - (_, seq_stride, channel_stride) = pos_emb.stride() # it doesn't really matter which one we make positive and which negative here, it # would just flip the meaning of the embedding. + + + # expand the '1' dimension to seq_len; this introduces a dimension that + # 'does nothing', just creates copies, as a workaround for lack of torch support + # for negative strides. + pos_emb = pos_emb.expand(seq_len, 2*seq_len-1, pos_dim).contiguous() + + (useless_stride, seq_stride, channel_stride) = pos_emb.stride() + pos_emb = pos_emb.as_strided((batch_size, seq_len, seq_len, pos_dim), - (0, -seq_stride, seq_stride, channel_stride), - storage_offset=seq_stride * (seqs_len - 1)) + (0, useless_stride-seq_stride, seq_stride, channel_stride), + storage_offset=seq_stride * (seq_len - 1)) return pos_emb # (batch_size, seq_len, seq_len, pos_dim) @@ -1326,8 +1307,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module): x: Tensor, pos_emb: Tensor, attn_offset: Optional[Tensor] = None, - pos_emb: Tensor, - quadratic_pos_weight: Tensor, ) -> Tensor: r""" Args: @@ -1368,35 +1347,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module): q = q.reshape(seq_len, batch_size, num_heads, query_head_dim) - p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim) + p = p.reshape(seq_len, batch_size, num_heads, pos_dim) k = k.reshape(seq_len, batch_size, num_heads, query_head_dim) - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - p = p.permute(2, 1, 0, 3) # (head, batch, time1, pos_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + q = q.permute(2, 1, 0, 3) # (head, batch, tgt_seq_len, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, src_seq_len) + # attn_scores: (num_heads, batch_size, tgt_seq_len, src_esq_len) attn_scores = torch.matmul(q, k) if not self.training or random.random() >= float(self.pos_emb_skip_rate): - pos_emb = self.linear_pos(pos_emb) - seq_len2 = 2 * seq_len - 1 - pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1) - # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2) - - # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) - # [where seq_len2 represents relative position.] - pos_scores = torch.matmul(p, pos_emb) - # the following .as_strided() expression converts the last axis of pos_scores from relative - # to absolute position. I don't know whether I might have got the time-offsets backwards or - # not, but let this code define which way round it is supposed to be. 
- pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len), - (pos_scores.stride(0), - pos_scores.stride(1), - pos_scores.stride(2)-pos_scores.stride(3), - pos_scores.stride(3)), - storage_offset=pos_scores.stride(3) * (seq_len - 1)) + # pos_emb: (batch_size, tgt_seq_len, src_seq_len, pos_dim) + p = p.permute(1, 0, 3, 2) # (batch_size, tgt_seq_len, pos_dim, num_heads) + pos_scores = torch.matmul(pos_emb, p) + # pos_scores: (batch_size, tgt_seq_len, src_seq_len, num_heads) + pos_scores = pos_scores.permute(3, 0, 1, 2) + # pos_scores: (num_heads, batch_size, tgt_seq_len, src_seq_len) attn_scores = attn_scores + pos_scores if self.training and random.random() < 0.1: @@ -1417,23 +1384,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module): penalty=1.0e-04, name=self.name) + # attn_offset includes key-padding mask and attention-mask, plus any weights + # from the subsampling. + attn_scores = attn_scores + attn_offset + assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len) - if attn_mask is not None: - assert attn_mask.dtype == torch.bool - # use -1000 to avoid nan's where attn_mask and key_padding_mask make - # all scores zero. It's important that this be large enough that exp(-1000) - # is exactly zero, for reasons related to const_attention_rate, it - # compares the final weights with zero. - attn_scores = attn_scores.masked_fill(attn_mask, -1000) - - if key_padding_mask is not None: - assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape - attn_scores = attn_scores.masked_fill( - key_padding_mask.unsqueeze(1), - -1000, - ) - # We use our own version of softmax, defined in scaling.py, which should # save a little of the memory used in backprop by, if we are in # automatic mixed precision mode (amp / autocast), by only storing the @@ -1617,9 +1573,9 @@ class MultiheadAttentionWeights(nn.Module): q = q.reshape(query_len, batch_size, num_heads, head_dim) k = k.reshape(key_len, batch_size, num_heads, head_dim) - # time1 refers to target, time2 refers to source. - q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim) - k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2) + # tgt_seq_len refers to target, src_seq_len refers to source. + q = q.permute(2, 1, 0, 3) # (head, batch, tgt_seq_len, query_head_dim) + k = k.permute(2, 1, 3, 0) # (head, batch, d_k, src_seq_len) attn_scores = torch.matmul(q, k) @@ -1842,8 +1798,6 @@ def _test_zipformer_main(causal: bool = False): c = Subformer2( encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4), causal=causal, - chunk_size=(4,) if causal else (-1,), - left_context_frames=(64,), memory_dim=memory_dim, ) batch_size = 5 diff --git a/egs/libriheavy/LM/zipformer1/train.py b/egs/libriheavy/LM/zipformer1/train.py index 1b9631ea5..08460b432 100755 --- a/egs/libriheavy/LM/zipformer1/train.py +++ b/egs/libriheavy/LM/zipformer1/train.py @@ -63,7 +63,7 @@ from lm_datamodule import LmDataset, LmDataloader from zipformer import Zipformer2 from scaling import ScheduledFloat from lhotse.utils import fix_random_seed -from chunk_decoder import ChunkDecoder +from decoder import Decoder from model import Zipformer2LM from optim import Eden, ScaledAdam from torch import Tensor @@ -176,13 +176,6 @@ def add_model_arguments(parser: argparse.ArgumentParser): help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list." 
     )

-    parser.add_argument(
-        "--pos-dim",
-        type=int,
-        default="48",
-        help="Positional-encoding embedding dimension"
-    )
-
     parser.add_argument(
         "--encoder-unmasked-dim",
         type=str,
@@ -505,9 +498,9 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:


 def get_encoder_model(params: AttributeDict) -> nn.Module:
-    chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
+    #chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
     encoder = Zipformer2(
-        output_downsampling_factor=chunk_size,
+        #output_downsampling_factor=chunk_size,
         downsampling_factor=_to_int_tuple(params.downsampling_factor),
         num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
         encoder_dim=_to_int_tuple(params.encoder_dim),
@@ -515,10 +508,8 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         query_head_dim=_to_int_tuple(params.query_head_dim),
         pos_head_dim=_to_int_tuple(params.pos_head_dim),
         value_head_dim=_to_int_tuple(params.value_head_dim),
-        pos_dim=params.pos_dim,
         num_heads=_to_int_tuple(params.num_heads),
         feedforward_dim=_to_int_tuple(params.feedforward_dim),
-        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
         dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
         warmup_batches=4000.0,
         causal=True,
@@ -529,13 +520,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:


 def get_decoder_model(params: AttributeDict) -> nn.Module:
-    chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
-    decoder = ChunkDecoder(
+    decoder = Decoder(
         embed_dim=max(_to_int_tuple(params.encoder_dim)),
-        chunk_size=chunk_size,
         vocab_size=256, # bytes
-        hidden_size=params.decoder_hidden_size,
-        num_layers=params.decoder_num_layers,
     )
     return decoder
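
Note for reviewers: the expand + as_strided construction in CompactRelPositionalEncoding.forward() is the subtlest part of this patch, and the in-code comment admits the time-offsets could plausibly be backwards. Below is a minimal, standalone sketch -- not part of the patch, with illustrative toy names and sizes -- that reproduces the construction on a small table and checks that pos_emb[b, tgt, src] picks out the row for relative position (src - tgt):

    import torch

    seq_len, pos_dim, batch_size = 5, 4, 3
    # rel[r] stands in for the embedding of relative offset r - (seq_len - 1),
    # i.e. rows 0 .. 2*seq_len-2 of the (2*seq_len-1, pos_dim) table.
    rel = torch.arange(2 * seq_len - 1, dtype=torch.float32).unsqueeze(-1).expand(-1, pos_dim)

    pos_emb = rel.unsqueeze(0)  # (1, 2*seq_len-1, pos_dim), as in forward()
    # expand the '1' dimension to seq_len, make it contiguous, then take a strided
    # view, exactly as the patched forward() does.
    pos_emb = pos_emb.expand(seq_len, 2 * seq_len - 1, pos_dim).contiguous()
    (useless_stride, seq_stride, channel_stride) = pos_emb.stride()
    pos_emb = pos_emb.as_strided((batch_size, seq_len, seq_len, pos_dim),
                                 (0, useless_stride - seq_stride, seq_stride, channel_stride),
                                 storage_offset=seq_stride * (seq_len - 1))

    for tgt in range(seq_len):
        for src in range(seq_len):
            # row (src - tgt + seq_len - 1) of the table <=> relative position src - tgt
            assert torch.equal(pos_emb[0, tgt, src], rel[src - tgt + seq_len - 1])
    print("pos_emb[b, tgt, src] indexes relative position src - tgt, as intended")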