diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
index 6d203e994..137273c56 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py
@@ -26,15 +26,16 @@
 from icefall.utils import add_sos, make_pad_mask
 from scaling import penalize_abs_values_gt, ScaledLinear
 
 
-class Transducer(nn.Module):
+class PromptedTransducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
     """
-
     def __init__(
         self,
         encoder_embed: nn.Module,
         encoder: EncoderInterface,
+        text_embed: nn.Module,
+        text_encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
         encoder_dim: int,
@@ -68,6 +69,8 @@ class Transducer(nn.Module):
 
         self.encoder_embed = encoder_embed
         self.encoder = encoder
+        self.text_embed = text_embed
+        self.text_encoder = text_encoder
         self.decoder = decoder
         self.joiner = joiner
 
@@ -86,6 +89,9 @@ class Transducer(nn.Module):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
+        text: torch.Tensor,
+        style_lens: torch.Tensor,
+        text_lens: torch.Tensor,
         y: k2.RaggedTensor,
         prune_range: int = 5,
         am_scale: float = 0.0,
@@ -98,6 +104,21 @@ class Transducer(nn.Module):
           x_lens:
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
             before padding.
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+          text:
+            A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
+            It is expected to contain the style prompt (first) and then the content
+            prompt.
+          style_lens:
+            A 1-D tensor of shape (N,), containing the number of elements (bytes)
+            within each row of `text` that correspond to the style prompt (these
+            are expected to come first).
+          text_lens:
+            A 1-D tensor of shape (N,). It contains the number of elements (bytes)
+            in `text` before padding, which will include the lengths of the
+            style plus the content prompt.
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
@@ -125,14 +146,25 @@ class Transducer(nn.Module):
         assert x.size(0) == x_lens.size(0) == y.dim0
 
-        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
         x, x_lens = self.encoder_embed(x, x_lens)
-        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
 
         src_key_padding_mask = make_pad_mask(x_lens)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
-        encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
+        text = text.t()  # now (T, N)
+        text = self.text_embed(text)  # now (T, N, C)
+        text_key_padding_mask = make_pad_mask(text_lens)
+
+        memory, text_lens = self.text_encoder(text, text_lens,
+                                              text_key_padding_mask)
+
+        memory = self._add_style_indicator(memory, style_lens)
+
+        memory_key_padding_mask = make_pad_mask(text_lens)
+
+        encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask,
+                                           memory=memory,
+                                           memory_key_padding_mask=memory_key_padding_mask)
 
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)
 
         assert torch.all(x_lens > 0)
@@ -217,3 +249,24 @@ class Transducer(nn.Module):
             )
 
         return (simple_loss, pruned_loss)
+
+
+    def _add_style_indicator(self, memory: torch.Tensor, style_lens: torch.Tensor):
+        """
+        Adds to `memory` an indicator that is 0.1 for positions that correspond to
+        the `style prompt` and 0 elsewhere.  The scale can be fixed because the
+        scale of the memory vector can adjust to compensate (within limits set
+        by the balancers).
+
+        Args:
+           memory: (memory_len, batch_size, embed_dim)
+           style_lens: (batch_size,), a vector of lengths of the style prompt.
+        """
+
+        (memory_len, batch_size, embed_dim) = memory.shape
+
+
+        indicator = torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
+        indicator = indicator.to(memory.dtype).unsqueeze(-1)
+
+        return memory + indicator
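Note (illustrative, not part of the patch): the `_add_style_indicator` helper above is just a broadcasted comparison against `style_lens`. A minimal standalone sketch of the same computation with toy shapes:

```python
import torch

# memory is (memory_len, batch_size, embed_dim), as in the docstring above.
memory_len, batch_size, embed_dim = 6, 2, 4
memory = torch.zeros(memory_len, batch_size, embed_dim)
style_lens = torch.tensor([2, 4])  # number of leading memory positions covered by the style prompt

# (memory_len, 1) < (batch_size,) broadcasts to (memory_len, batch_size).
indicator = torch.arange(memory_len).unsqueeze(-1) < style_lens
indicator = indicator.to(memory.dtype).unsqueeze(-1)  # (memory_len, batch_size, 1)
out = memory + indicator

print(indicator[:, :, 0])
# tensor([[1., 1.],
#         [1., 1.],
#         [0., 1.],
#         [0., 1.],
#         [0., 0.],
#         [0., 0.]])
```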
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 4a4288f61..822723ea5 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -71,7 +71,9 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
+from torch import nn
 from torch.cuda.amp import GradScaler
+
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
 
@@ -159,6 +161,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Embedding dimension in encoder stacks: a single int or comma-separated list."
     )
 
+    parser.add_argument(
+        "--text-encoder-dim",
+        type=str,
+        default="256,256,384,512",
+        help="Embedding dimension in text encoder stacks: a comma-separated list of 4 elements; "
+        "using a different number of elements requires changing other configs in the code."
+    )
+
     parser.add_argument(
         "--query-head-dim",
         type=str,
@@ -547,6 +557,32 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
     return encoder_embed
 
 
+def get_text_embed(params: AttributeDict) -> nn.Module:
+    return nn.Embedding(
+        num_embeddings=256,  # we encode the text as UTF-8 bytes
+        embedding_dim=_to_int_tuple(params.text_encoder_dim)[0],
+    )
+
+def get_text_encoder(params: AttributeDict) -> nn.Module:
+    return Zipformer2(
+        output_downsampling_factor=8,
+        downsampling_factor=(1,2,4,8),
+        num_encoder_layers=(2,4,6,6),
+        encoder_dim=_to_int_tuple(params.text_encoder_dim),
+        encoder_unmasked_dim=(196,196,256,256),
+        query_head_dim=(32,32,32,32),
+        pos_head_dim=(4,4,4,4),
+        value_head_dim=(12,12,12,12),
+        pos_dim=48,
+        num_heads=(4,4,4,8),
+        feedforward_dim=(384,512,768,1024),  # could increase this if there is enough data
+        cnn_module_kernel=(31,31,15,15),
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+        warmup_batches=4000.0,
+        causal=False,
+    )
+
+
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Zipformer2(
         output_downsampling_factor=2,
@@ -566,6 +602,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
+        memory_dim=_to_int_tuple(params.text_encoder_dim)[-1],
     )
     return encoder
 
@@ -593,12 +630,16 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
+    text_embed = get_text_embed(params)
+    text_encoder = get_text_encoder(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
 
     model = Transducer(
         encoder_embed=encoder_embed,
         encoder=encoder,
+        text_embed=text_embed,
+        text_encoder=text_encoder,
         decoder=decoder,
         joiner=joiner,
         encoder_dim=int(max(params.encoder_dim.split(','))),
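Note (hypothetical, not part of the patch): since `get_text_embed()` uses `nn.Embedding(num_embeddings=256, ...)`, the prompts are consumed as raw UTF-8 bytes. A sketch of how a batch of `text`, `style_lens` and `text_lens` could be assembled for `PromptedTransducer.forward()`; the data pipeline is not included in this diff, so the helper name and the use of 0 as the padding value are assumptions:

```python
import torch

def make_prompt_batch(style_prompts, content_prompts, pad_id=0):
    # Hypothetical helper: byte-encode the style prompt followed by the content
    # prompt for each utterance, then pad into an (N, T) int tensor.
    rows, style_lens, text_lens = [], [], []
    for style, content in zip(style_prompts, content_prompts):
        s = list(style.encode("utf-8"))
        c = list(content.encode("utf-8"))
        rows.append(s + c)              # style prompt first, then content prompt
        style_lens.append(len(s))
        text_lens.append(len(s) + len(c))
    text = torch.full((len(rows), max(text_lens)), pad_id, dtype=torch.int64)
    for i, r in enumerate(rows):
        text[i, : len(r)] = torch.tensor(r, dtype=torch.int64)
    return text, torch.tensor(style_lens), torch.tensor(text_lens)

text, style_lens, text_lens = make_prompt_batch(
    ["Mixed-case, fully punctuated. ", "lowercase unpunctuated "],
    ["HE HOPED THERE WOULD BE STEW FOR DINNER.", "he hoped there would be stew for dinner"],
)
# text: (N, T) int64; PromptedTransducer.forward() transposes it to (T, N) before embedding.
```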
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index 4ead015a4..a55b4bd57 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -89,6 +89,9 @@
            context chunks for causal training; will be rounded to a number of chunks.  Must not
            be less than cnn_module_kernel (after factoring in rounding and downsampling); an
            error will be thrown if this is violated.
+        memory_dim: if supplied and >0, will be the dimension of the memory embeddings
+            passed into the zipformer (e.g. this might be the output of another
+            Zipformer used to create embedding vectors.)
     """
     def __init__(
         self,
@@ -103,6 +106,7 @@ class Zipformer2(EncoderInterface):
         num_heads: Union[int, Tuple[int]] = 8,
         feedforward_dim: Union[int, Tuple[int]] = 1536,
         cnn_module_kernel: Union[int, Tuple[int]] = 31,
+        memory_dim: int = -1,
         pos_dim: int = 192,
         dropout: FloatLike = None,  # see code below for default
         warmup_batches: float = 4000.0,
@@ -160,6 +164,7 @@ class Zipformer2(EncoderInterface):
                 pos_head_dim=pos_head_dim[i],
                 value_head_dim=value_head_dim[i],
                 feedforward_dim=feedforward_dim[i],
+                memory_dim=memory_dim,
                 dropout=dropout,
                 cnn_module_kernel=cnn_module_kernel[i],
                 causal=causal,
@@ -271,9 +276,12 @@
 
 
     def forward(
-        self, x: torch.Tensor,
+        self,
+        x: torch.Tensor,
         x_lens: torch.Tensor,
         src_key_padding_mask: Optional[torch.Tensor] = None,
+        memory: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -284,7 +292,12 @@
             `x` before padding.
           src_key_padding_mask:
             The mask for padding, of shape (batch_size, seq_len); True means
-            masked position. May be None.
+            masked position. May be None.
+          memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
+          memory_key_padding_mask: optionally the mask for padding of memory input (for source-
+             attention), of shape (batch_size, memory_len); True means
+             masked position. May be None.
+
         Returns:
           Return a tuple containing 2 tensors:
             - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
@@ -298,6 +311,13 @@
 
         attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
 
+        if self.training and memory is not None:
+            batch_size = x.shape[1]
+            # setting memory to zero should be equivalent to not using the
+            # memory input at all, since the Attention module has no biases.
+            memory_dropout_rate = 0.05
+            memory = memory * (torch.rand(batch_size, 1, device=memory.device) > memory_dropout_rate)
+
         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]
             x = convert_num_channels(x, self.encoder_dim[i])
@@ -308,6 +328,8 @@
                        src_key_padding_mask=(None if src_key_padding_mask is None
                                              else src_key_padding_mask[...,::ds]),
                        attn_mask=attn_mask,
+                       memory=memory,
+                       memory_key_padding_mask=memory_key_padding_mask,
                        )
             outputs.append(x)
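Note (illustrative, not part of the patch): the per-utterance memory dropout added above zeroes the entire memory sequence for a random ~5% of batch elements; the (batch_size, 1) mask broadcasts over (memory_len, batch_size, memory_dim). A standalone sketch of the broadcast with toy shapes:

```python
import torch

memory_len, batch_size, memory_dim = 7, 4, 8
memory = torch.randn(memory_len, batch_size, memory_dim)

memory_dropout_rate = 0.05
# (batch_size, 1) boolean mask, created on memory.device so the multiply also works on GPU.
keep = torch.rand(batch_size, 1, device=memory.device) > memory_dropout_rate
dropped = memory * keep  # broadcasts to (memory_len, batch_size, memory_dim)

# Batch elements whose `keep` entry is False have their whole memory zeroed, which
# (since the Attention module has no biases) behaves like running without a prompt.
assert dropped.shape == memory.shape
```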
@@ -416,6 +438,7 @@ class Zipformer2EncoderLayer(nn.Module):
             dropout: FloatLike = 0.1,
             cnn_module_kernel: int = 31,
             causal: bool = False,
+            memory_dim: int = -1,
             attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
             conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
@@ -452,11 +475,25 @@ class Zipformer2EncoderLayer(nn.Module):
             dropout=0.0,
         )
 
-        self.self_attn1 = SelfAttention(embed_dim, num_heads,
+
+        self.self_attn1 = Attention(embed_dim, embed_dim, num_heads,
                                         value_head_dim)
 
-        self.self_attn2 = SelfAttention(embed_dim, num_heads,
-                                        value_head_dim)
+        self.self_attn2 = Attention(embed_dim, embed_dim, num_heads,
+                                    value_head_dim)
+
+        if memory_dim > 0:
+            self.attn_weights = MultiheadAttentionWeights(
+                memory_dim, embed_dim,
+                num_heads=num_heads,
+                head_dim=query_head_dim,
+                dropout=0.0,
+            )
+            self.src_attn1 = Attention(memory_dim, embed_dim, num_heads,
+                                       value_head_dim)
+            self.src_attn2 = Attention(memory_dim, embed_dim, num_heads,
+                                       value_head_dim)
+
         self.feed_forward1 = FeedforwardModule(embed_dim,
                                                (feedforward_dim * 3) // 4,
@@ -579,6 +616,8 @@ class Zipformer2EncoderLayer(nn.Module):
         chunk_size: int = -1,
         attn_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        memory: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """
         Pass the input through the encoder layer.
@@ -608,11 +647,14 @@ class Zipformer2EncoderLayer(nn.Module):
             pos_emb=pos_emb,
             attn_mask=attn_mask,
             key_padding_mask=src_key_padding_mask,
-        )
+            )
+
+        if memory is not None and hasattr(self, 'attn_weights'):
+            src_attn_weights = self.attn_weights(memory, src, memory_key_padding_mask)
 
         src = src + self.feed_forward1(src)
 
-        self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
+        attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
 
         if True:
             selected_attn_weights = attn_weights[0:2]
@@ -630,12 +672,16 @@ class Zipformer2EncoderLayer(nn.Module):
 
         na = self.balancer_na(self.nonlin_attention(src, selected_attn_weights[0:1]))
 
-        src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
+        src = src + (na if attn_dropout_mask is None else na * attn_dropout_mask)
 
         self_attn = self.self_attn1(
             src, attn_weights)
 
-        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+        src = src + (self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask)
+
+        if memory is not None and hasattr(self, 'attn_weights'):
+            src = src + self.sequence_dropout(self.src_attn1(memory, src_attn_weights),
+                                              attention_skip_rate)
 
         src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size,
                                                             src_key_padding_mask=src_key_padding_mask),
@@ -650,7 +696,11 @@ class Zipformer2EncoderLayer(nn.Module):
         self_attn = self.self_attn2(
             src, attn_weights)
 
-        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+        src = src + (self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask)
+
+        if memory is not None and hasattr(self, 'attn_weights'):
+            src = src + self.sequence_dropout(self.src_attn2(memory, src_attn_weights),
+                                              attention_skip_rate)
 
         src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size,
                                                             src_key_padding_mask=src_key_padding_mask),
@@ -673,9 +723,9 @@ class Zipformer2Encoder(nn.Module):
     r"""Zipformer2Encoder is a stack of N encoder layers
 
     Args:
-        encoder_layer: an instance of the Zipformer2EncoderLayer() class (required). 
+        encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
-        pos_dim: the dimension for the relative positional encoding 
+        pos_dim: the dimension for the relative positional encoding
 
     Examples::
         >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
@@ -721,6 +771,8 @@ class Zipformer2Encoder(nn.Module):
         feature_mask: Union[Tensor, float] = 1.0,
         attn_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        memory: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.
 
@@ -734,6 +786,10 @@ class Zipformer2Encoder(nn.Module):
                True means masked position. May be None.
             src_key_padding_mask:  the mask for padding, of shape (batch_size, seq_len); True means
                masked position. May be None.
+            memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
+            memory_key_padding_mask: optionally the mask for padding of memory input (for source-
+                attention), of shape (batch_size, memory_len); True means
+                masked position. May be None.
 
         Returns: a Tensor with the same shape as src.
""" @@ -751,6 +807,8 @@ class Zipformer2Encoder(nn.Module): chunk_size=chunk_size, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, ) output = output * feature_mask @@ -843,6 +901,8 @@ class DownsampledZipformer2Encoder(nn.Module): feature_mask: Union[Tensor, float] = 1.0, attn_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, + memory: Optional[Tensor] = None, + memory_key_padding_mask: Optional[Tensor] = None, ) -> Tuple[Tensor, Tensor]: r"""Downsample, go through encoder, upsample. @@ -855,6 +915,10 @@ class DownsampledZipformer2Encoder(nn.Module): True means masked position. May be None. src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means masked position. May be None. + memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim) + memory_key_padding_mask: optionally the mask for padding of memory input (for source- + attention), of shape (batch_size, memory_len); True means + masked position. May be None. Returns: a Tensor with the same shape as src. """ @@ -870,6 +934,8 @@ class DownsampledZipformer2Encoder(nn.Module): feature_mask=feature_mask, attn_mask=attn_mask, src_key_padding_mask=src_key_padding_mask, + memory=memory, + memory_key_padding_mask=memory_key_padding_mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor @@ -1185,7 +1251,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module): query_dim = query_head_dim * num_heads - # self-attention q = x[...,0:query_dim] k = x[...,query_dim:2*query_dim] # p is the position-encoding query @@ -1231,9 +1296,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module): attn_scores = attn_scores + pos_scores if self.training and random.random() < 0.1: - # This is a harder way of limiting the attention scores to not be + # This is away of limiting the attention scores to not be # too large. It incurs a penalty if any of them has an absolute - # value greater than 50.0. this should be outside the normal range + # value greater than 25.0. this should be outside the normal range # of the attention scores. We use this mechanism instead of, say, # something added to the loss function involving the entropy, # because once the entropy gets very small gradients through the @@ -1295,29 +1360,31 @@ class RelPositionMultiheadAttentionWeights(nn.Module): logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}") -class SelfAttention(nn.Module): +class Attention(nn.Module): """ The simplest possible attention module. This one works with already-computed attention weights, e.g. as computed by RelPositionMultiheadAttentionWeights. 
    Args:
-          embed_dim: the input and output embedding dimension
+          embed_dim_in: the input embedding dimension
+          embed_dim_out: the output embedding dimension (normally the same as input)
           num_heads: the number of attention heads
           value_head_dim: the value dimension per head
     """
     def __init__(
         self,
-        embed_dim: int,
+        embed_dim_in: int,
+        embed_dim_out: int,
         num_heads: int,
         value_head_dim: int,
     ) -> None:
         super().__init__()
-        self.in_proj = nn.Linear(embed_dim,
+        self.in_proj = nn.Linear(embed_dim_in,
                                  num_heads * value_head_dim,
-                                 bias=True)
+                                 bias=False)
 
         self.out_proj = ScaledLinear(num_heads * value_head_dim,
-                                     embed_dim, bias=True,
+                                     embed_dim_out, bias=False,
                                      initial_scale=0.05)
 
         self.whiten = Whiten(num_groups=1,
@@ -1334,35 +1401,182 @@ class SelfAttention(nn.Module):
         """
         Args:
            x: input tensor, of shape (seq_len, batch_size, embed_dim)
-          attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
-             with seq_len being interpreted as (tgt_seq_len, src_seq_len).  Expect
-             attn_weights.sum(dim=-1) == 1.
+          attn_weights: a tensor of shape (num_heads, batch_size, query_len, key_len),
+             Expect attn_weights.sum(dim=-1) == 1.
         Returns:
            a tensor with the same shape as x.
         """
-        (seq_len, batch_size, embed_dim) = x.shape
-        num_heads = attn_weights.shape[0]
-        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+        (num_heads, batch_size, query_len, key_len) = attn_weights.shape
 
-        x = self.in_proj(x)  # (seq_len, batch_size, num_heads * value_head_dim)
-        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
-        # now x: (num_heads, batch_size, seq_len, value_head_dim)
+        x = self.in_proj(x)  # (key_len, batch_size, num_heads * value_head_dim)
+        x = x.reshape(key_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+        # now x: (num_heads, batch_size, key_len, value_head_dim)
         value_head_dim = x.shape[-1]
 
         # todo: see whether there is benefit in overriding matmul
         x = torch.matmul(attn_weights, x)
-        # v: (num_heads, batch_size, seq_len, value_head_dim)
+        # v: (num_heads, batch_size, query_len, value_head_dim)
 
         x = x.permute(2, 1, 0, 3).contiguous().view(
-            seq_len, batch_size, num_heads * value_head_dim)
+            query_len, batch_size, num_heads * value_head_dim)
 
-        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+        # returned value is of shape (query_len, batch_size, embed_dim), like the input.
         x = self.out_proj(x)
         x = self.whiten(x)
 
         return x
 
 
+class MultiheadAttentionWeights(nn.Module):
+    r"""Module that computes multi-head cross-attention weights.  Allows src and target
+    to have different dims.
+
+    Args:
+          key_embed_dim: number of channels of the thing that we'll project to
+               make the keys (corresponds to the source). e.g. 256
+          query_embed_dim: number of channels of the thing that we'll project to
+               make the query (corresponds to the target). e.g. 256
+          num_heads:  number of heads to compute weights for, e.g. 8
+          head_dim: dimension of the query and key, per head.  e.g. 24.
+          dropout: dropout probability for attn_output_weights. Default: 0.0.
+    """
+
+    def __init__(
+            self,
+            key_embed_dim: int,
+            query_embed_dim: int,
+            num_heads: int,
+            head_dim: int,
+            dropout: float = 0.0,
+
+    ) -> None:
+        super().__init__()
+        self.key_embed_dim = key_embed_dim
+        self.query_embed_dim = query_embed_dim
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.dropout = dropout
+        self.name = None  # will be overwritten in training code; for diagnostics.
+
+
+        # the initial_scale is supposed to take over the "scaling" factor of
+        # head_dim ** -0.5 that has been used in previous forms of attention,
+        # dividing it between the query and key.  Note: this module is intended
+        # to be used with the ScaledAdam optimizer; with most other optimizers,
+        # it would be necessary to apply the scaling factor in the forward function.
+        self.query_in_proj = ScaledLinear(query_embed_dim,
+                                          head_dim * num_heads,
+                                          bias=True,
+                                          initial_scale=head_dim ** -0.25)
+
+        # weights produced by this module are invariant to adding a constant to
+        # the keys, so we don't need a bias for the keys.
+        self.key_in_proj = ScaledLinear(key_embed_dim,
+                                        head_dim * num_heads,
+                                        bias=False,
+                                        initial_scale=head_dim ** -0.25)
+
+        self.whiten_keys = Whiten(num_groups=num_heads,
+                                  whitening_limit=_whitening_schedule(3.0),
+                                  prob=(0.025, 0.25),
+                                  grad_scale=0.025)
+
+
+    def forward(
+        self,
+        key: Tensor,
+        query: Tensor,
+        key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""
+        Args:
+            key: input of shape (key_len, batch_size, key_embed_dim)
+            query: input of shape (query_len, batch_size, query_embed_dim)
+            key_padding_mask: an optional bool tensor of shape (batch_size, key_len).  Positions that
+               are True in this mask will be ignored as sources in the attention weighting.
+        Returns:
+           a tensor of attention weights, of shape (num_heads, batch_size, query_len, key_len)
+        """
+        q = self.query_in_proj(query)
+        k = self.key_in_proj(key)
+
+        head_dim = self.head_dim
+        num_heads = self.num_heads
+
+        query_len, batch_size, _ = q.shape
+        key_len, _batch_size, _ = k.shape
+        assert _batch_size == batch_size
+
+        k = self.whiten_keys(k)  # does nothing in the forward pass.
+
+        q = q.reshape(query_len, batch_size, num_heads, head_dim)
+        k = k.reshape(key_len, batch_size, num_heads, head_dim)
+
+        # time1 refers to target, time2 refers to source.
+        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
+        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
+
+        attn_scores = torch.matmul(q, k)
+
+        if self.training and random.random() < 0.1:
+            # This is a way of limiting the attention scores to not be
+            # too large.  It incurs a penalty if any of them has an absolute
+            # value greater than 25.0.  This should be outside the normal range
+            # of the attention scores.  We use this mechanism instead of, say,
+            # something added to the loss function involving the entropy,
+            # because once the entropy gets very small gradients through the
+            # softmax can become very small, and we'd get zero derivatives.  The
+            # choice of 1.0e-04 as the scale on the penalty makes this
+            # mechanism vulnerable to the absolute scale of the loss function,
+            # but we view this as a failsafe to avoid "implausible" parameter
+            # values rather than a regularization method that should be active
+            # under normal circumstances.
+            attn_scores = penalize_abs_values_gt(attn_scores,
+                                                 limit=25.0,
+                                                 penalty=1.0e-04,
+                                                 name=self.name)
+
+        assert attn_scores.shape == (num_heads, batch_size, query_len, key_len)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.shape == (batch_size, key_len), key_padding_mask.shape
+            attn_scores = attn_scores.masked_fill(
+                key_padding_mask.unsqueeze(1),
+                -1000,
+            )
+
+        # We use our own version of softmax, defined in scaling.py, which should
+        # save a little of the memory used in backprop if we are in
+        # automatic mixed precision mode (amp / autocast), by only storing the
+        # half-precision output for backprop purposes.
+        attn_weights = softmax(attn_scores, dim=-1)
+
+        if random.random() < 0.001:
+            self._print_attn_entropy(attn_weights)
+
+        attn_weights = nn.functional.dropout(
+            attn_weights, p=self.dropout, training=self.training
+        )
+
+        return attn_weights
+
+
+    def _print_attn_entropy(
+            self,
+            attn_weights: Tensor):
+        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
+
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                attn_weights = attn_weights.to(torch.float32)
+                attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
+                    dim=-1).mean(dim=(1,2))
+                logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}")
+
+
 class FeedforwardModule(nn.Module):
     """Feedforward module in Zipformer2 model.
     """
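Note (illustrative, not part of the patch): as in the self-attention path, one `MultiheadAttentionWeights` instance per layer computes the source-attention weights and the two `Attention` (value) modules reuse them. A shape-level sketch of that wiring, assuming `zipformer.py` and its `scaling.py` dependencies are importable; the dimensions below are arbitrary toy values:

```python
import torch
from zipformer import Attention, MultiheadAttentionWeights

memory_dim, embed_dim = 512, 384        # e.g. last text-encoder dim and one layer's embed_dim
num_heads, head_dim, value_head_dim = 4, 32, 12

attn_weights = MultiheadAttentionWeights(memory_dim, embed_dim,
                                         num_heads=num_heads, head_dim=head_dim)
src_attn1 = Attention(memory_dim, embed_dim, num_heads, value_head_dim)
src_attn2 = Attention(memory_dim, embed_dim, num_heads, value_head_dim)
for m in (attn_weights, src_attn1, src_attn2):
    m.eval()  # skip the training-only score penalty in this sketch

memory = torch.randn(50, 2, memory_dim)   # (key_len, batch_size, memory_dim)
src = torch.randn(100, 2, embed_dim)      # (query_len, batch_size, embed_dim)

w = attn_weights(memory, src)             # (num_heads, batch_size, query_len, key_len)
out1 = src_attn1(memory, w)               # (query_len, batch_size, embed_dim)
out2 = src_attn2(memory, w)               # same weights reused, as in Zipformer2EncoderLayer
assert w.shape == (num_heads, 2, 100, 50) and out1.shape == (100, 2, embed_dim)
```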
@@ -1650,12 +1864,14 @@ def _test_zipformer_main(causal: bool = False):
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
+    memory_dim = 100
     c = Zipformer2(
         encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
         causal=causal,
         chunk_size=(4,) if causal else (-1,),
-        left_context_frames=(64,)
+        left_context_frames=(64,),
+        memory_dim=memory_dim,
     )
     batch_size = 5
     seq_len = 20
@@ -1663,6 +1879,7 @@ def _test_zipformer_main(causal: bool = False):
     f = c(
         torch.randn(seq_len, batch_size, 64),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
+        memory=torch.randn(101, batch_size, memory_dim),
     )
     f[0].sum().backward()
     c.eval()