Add memory to model

Daniel Povey 2023-05-01 20:47:09 +08:00
parent 6f5c4688ef
commit fa696e919b
3 changed files with 350 additions and 39 deletions


@@ -26,15 +26,16 @@ from icefall.utils import add_sos, make_pad_mask
 from scaling import penalize_abs_values_gt, ScaledLinear
-class Transducer(nn.Module):
+class PromptedTransducer(nn.Module):
 """It implements https://arxiv.org/pdf/1211.3711.pdf
 "Sequence Transduction with Recurrent Neural Networks"
 """
 def __init__(
 self,
 encoder_embed: nn.Module,
 encoder: EncoderInterface,
+text_embed: nn.Module,
+text_encoder: EncoderInterface,
 decoder: nn.Module,
 joiner: nn.Module,
 encoder_dim: int,
@@ -68,6 +69,8 @@ class Transducer(nn.Module):
 self.encoder_embed = encoder_embed
 self.encoder = encoder
+self.text_embed = text_embed
+self.text_encoder = text_encoder
 self.decoder = decoder
 self.joiner = joiner
@@ -86,6 +89,9 @@ class Transducer(nn.Module):
 self,
 x: torch.Tensor,
 x_lens: torch.Tensor,
+text: torch.Tensor,
+style_lens: torch.Tensor,
+text_lens: torch.Tensor,
 y: k2.RaggedTensor,
 prune_range: int = 5,
 am_scale: float = 0.0,
@@ -98,6 +104,21 @@ class Transducer(nn.Module):
 x_lens:
 A 1-D tensor of shape (N,). It contains the number of frames in `x`
 before padding.
+x_lens:
+A 1-D tensor of shape (N,). It contains the number of frames in `x`
+before padding.
+text:
+A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
+It is expected to contain the style prompt (first) and then the content
+prompt.
+style_lens:
+A 1-D tensor of shape (N,), containing the number of elements (bytes)
+within each row of `text` that correspond to the style prompt (these
+are expected to come first).
+text_lens:
+A 1-D tensor of shape (N,). It contains the number of elements (bytes)
+in `text` before padding, which will include the lengths of the
+style plus the content prompt.
 y:
 A ragged tensor with 2 axes [utt][label]. It contains labels of each
 utterance.
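
For concreteness, the `text`, `style_lens` and `text_lens` tensors described above could be built along the following lines. This is only an illustrative sketch (the helper name is made up, not part of this commit); it assumes the style and content prompts arrive as Python strings and are zero-padded to a common length:

    import torch

    def make_prompt_tensors(style_prompts, content_prompts):
        # Encode each (style, content) pair as UTF-8 bytes; the style prompt
        # comes first, matching the layout the docstring above describes.
        rows, style_lens = [], []
        for style, content in zip(style_prompts, content_prompts):
            s = list(style.encode("utf-8"))
            c = list(content.encode("utf-8"))
            rows.append(s + c)
            style_lens.append(len(s))
        text_lens = [len(r) for r in rows]
        # Zero-pad every row to the longest prompt in the batch -> shape (N, T).
        text = torch.zeros(len(rows), max(text_lens), dtype=torch.int64)
        for i, r in enumerate(rows):
            text[i, : len(r)] = torch.tensor(r, dtype=torch.int64)
        return text, torch.tensor(style_lens), torch.tensor(text_lens)
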
@@ -125,14 +146,25 @@ class Transducer(nn.Module):
 assert x.size(0) == x_lens.size(0) == y.dim0
+# logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
 x, x_lens = self.encoder_embed(x, x_lens)
+# logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")
 src_key_padding_mask = make_pad_mask(x_lens)
 x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
-encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
+text = text.t() # now (T, N)
+text = self.text_embed(text) # now (T, N, C)
+text_key_padding_mask = make_pad_mask(text_lens)
+memory, text_lens = self.text_encoder(text, text_lens,
+text_key_padding_mask)
+memory = self._add_style_indicator(memory, style_lens)
+memory_key_padding_mask = make_pad_mask(text_lens)
+encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask,
+memory=memory,
+memory_key_padding_mask=memory_key_padding_mask)
 encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) -> (N, T, C)
 assert torch.all(x_lens > 0)
@@ -217,3 +249,24 @@ class Transducer(nn.Module):
 )
 return (simple_loss, pruned_loss)
+def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
+"""
+Adds to `memory` an indicator that is 0.1 for positions that correspond to
+the `style prompt` and 0 elsewhere. The scale can be fixed because the
+scale of the memory vector can adjust to compensate (within limits set
+by the balancers).
+Args:
+memory: (memory_len, batch_size, embed_dim)
+style_lens: (batch_size,), a vector of lengths of the style prompt.
+"""
+(memory_len, batch_size, embed_dim) = memory.shape
+indicator = torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
+indicator = indicator.to(memory.dtype).unsqueeze(-1)
+return memory + indicator
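
As a quick, standalone illustration of the indicator logic in `_add_style_indicator` (not part of the diff): with memory_len = 5 and style_lens = [2, 4], the comparison marks exactly the first style_lens[b] memory positions of each utterance, and the trailing unsqueeze lets the result broadcast against (memory_len, batch_size, embed_dim):

    import torch

    memory_len = 5
    style_lens = torch.tensor([2, 4])
    indicator = torch.arange(memory_len).unsqueeze(-1) < style_lens  # (memory_len, batch_size)
    print(indicator.t())
    # tensor([[ True,  True, False, False, False],
    #         [ True,  True,  True,  True, False]])
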


@@ -71,7 +71,9 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
+from torch import nn
 from torch.cuda.amp import GradScaler
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter
@@ -159,6 +161,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
 help="Embedding dimension in encoder stacks: a single int or comma-separated list."
 )
+parser.add_argument(
+"--text-encoder-dim",
+type=str,
+default="256,256,384,512",
+help="Embedding dimension in text encoder stacks: a comma-separated list of 4 elements, "
+"or you should change other configs in the code."
+)
 parser.add_argument(
 "--query-head-dim",
 type=str,
@@ -547,6 +557,32 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
 return encoder_embed
+def get_text_embed(params: AttributeDict) -> nn.Module:
+return nn.Embedding(
+num_embeddings=256, # we encode the text as UTF-8 bytes
+embedding_dim=_to_int_tuple(params.text_encoder_dim)[0],
+)
+def get_text_encoder(params: AttributeDict) -> nn.Module:
+return Zipformer2(
+output_downsampling_factor=8,
+downsampling_factor=(1,2,4,8),
+num_encoder_layers=(2,4,6,6),
+encoder_dim=_to_int_tuple(params.text_encoder_dim),
+encoder_unmasked_dim=(196,196,256,256),
+query_head_dim=(32,32,32,32),
+pos_head_dim=(4,4,4,4),
+value_head_dim=(12,12,12,12),
+pos_dim=48,
+num_heads=(4,4,4,8),
+feedforward_dim=(384,512,768,1024), # could increase this if there is enough data
+cnn_module_kernel=(31,31,15,15),
+dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+warmup_batches=4000.0,
+causal=False,
+)
 def get_encoder_model(params: AttributeDict) -> nn.Module:
 encoder = Zipformer2(
 output_downsampling_factor=2,
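
The comma-separated `--text-encoder-dim` string is parsed per stack by the recipe's `_to_int_tuple` helper; a minimal equivalent, shown here only to make the parsing explicit (assuming the usual icefall-style definition), is:

    def _to_int_tuple(s: str):
        # "256,256,384,512" -> (256, 256, 384, 512), one entry per text-encoder stack
        return tuple(map(int, s.split(",")))

    # get_text_embed() uses the first entry as the byte-embedding dimension,
    # get_text_encoder() passes the whole tuple as encoder_dim, and
    # get_encoder_model() feeds the last entry in as memory_dim (see below).
    assert _to_int_tuple("256,256,384,512") == (256, 256, 384, 512)
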
@@ -566,6 +602,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
 causal=params.causal,
 chunk_size=_to_int_tuple(params.chunk_size),
 left_context_frames=_to_int_tuple(params.left_context_frames),
+memory_dim=_to_int_tuple(params.text_encoder_dim)[-1],
 )
 return encoder
@@ -593,12 +630,16 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 def get_transducer_model(params: AttributeDict) -> nn.Module:
 encoder_embed = get_encoder_embed(params)
 encoder = get_encoder_model(params)
+text_embed = get_text_embed(params)
+text_encoder = get_text_encoder(params)
 decoder = get_decoder_model(params)
 joiner = get_joiner_model(params)
 model = Transducer(
 encoder_embed=encoder_embed,
 encoder=encoder,
+text_embed=text_embed,
+text_encoder=text_encoder,
 decoder=decoder,
 joiner=joiner,
 encoder_dim=int(max(params.encoder_dim.split(','))),


@@ -89,6 +89,9 @@ class Zipformer2(EncoderInterface):
 context chunks for causal training; will be rounded to a number of
 chunks. Must not be less than cnn_module_kernel (after factoring in
 rounding and downsampling); an error will be thrown if this is violated.
+memory_dim: if supplied and >0, will be the dimension of the memory embeddings
+passed into the zipformer (e.g. this might be the output of another
+Zipformer used to create embedding vectors.)
 """
 def __init__(
 self,
@@ -103,6 +106,7 @@ class Zipformer2(EncoderInterface):
 num_heads: Union[int, Tuple[int]] = 8,
 feedforward_dim: Union[int, Tuple[int]] = 1536,
 cnn_module_kernel: Union[int, Tuple[int]] = 31,
+memory_dim: int = -1,
 pos_dim: int = 192,
 dropout: FloatLike = None, # see code below for default
 warmup_batches: float = 4000.0,
@@ -160,6 +164,7 @@ class Zipformer2(EncoderInterface):
 pos_head_dim=pos_head_dim[i],
 value_head_dim=value_head_dim[i],
 feedforward_dim=feedforward_dim[i],
+memory_dim=memory_dim,
 dropout=dropout,
 cnn_module_kernel=cnn_module_kernel[i],
 causal=causal,
@@ -271,9 +276,12 @@ class Zipformer2(EncoderInterface):
 def forward(
-self, x: torch.Tensor,
+self,
+x: torch.Tensor,
 x_lens: torch.Tensor,
 src_key_padding_mask: Optional[torch.Tensor] = None,
+memory: Optional[Tensor] = None,
+memory_key_padding_mask: Optional[Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
 """
 Args:
@@ -284,7 +292,12 @@ class Zipformer2(EncoderInterface):
 `x` before padding.
 src_key_padding_mask:
 The mask for padding, of shape (batch_size, seq_len); True means
 masked position. May be None.
+memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
+memory_key_padding_mask: optionally the mask for padding of memory input (for source-
+attention), of shape (batch_size, memory_len); True means
+masked position. May be None.
 Returns:
 Return a tuple containing 2 tensors:
 - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
@@ -298,6 +311,13 @@ class Zipformer2(EncoderInterface):
 attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)
+if self.training and memory is not None:
+batch_size = x.shape[1]
+# setting memory to zero should be equivalent to not using the
+# memory input at all, since the Attention module has no biases.
+memory_dropout_rate = 0.05
+memory = memory * (torch.rand(batch_size, 1, device=memory.device) > memory_dropout_rate)
 for i, module in enumerate(self.encoders):
 ds = self.downsampling_factor[i]
 x = convert_num_channels(x, self.encoder_dim[i])
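
The per-utterance memory dropout above relies on broadcasting: a (batch_size, 1) mask multiplied into a (memory_len, batch_size, embed_dim) tensor zeroes out the entire memory sequence for the dropped utterances. A small self-contained check (illustrative only):

    import torch

    memory_len, batch_size, embed_dim = 7, 4, 2
    memory = torch.randn(memory_len, batch_size, embed_dim)
    keep = torch.rand(batch_size, 1) > 0.05   # True = keep this utterance's memory
    dropped = memory * keep                   # broadcasts over memory_len and embed_dim
    for b in range(batch_size):
        assert keep[b, 0] or torch.all(dropped[:, b, :] == 0)
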
@@ -308,6 +328,8 @@ class Zipformer2(EncoderInterface):
 src_key_padding_mask=(None if src_key_padding_mask is None
 else src_key_padding_mask[...,::ds]),
 attn_mask=attn_mask,
+memory=memory,
+memory_key_padding_mask=memory_key_padding_mask,
 )
 outputs.append(x)
@@ -416,6 +438,7 @@ class Zipformer2EncoderLayer(nn.Module):
 dropout: FloatLike = 0.1,
 cnn_module_kernel: int = 31,
 causal: bool = False,
+memory_dim: int = -1,
 attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
 conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
 const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
@@ -452,11 +475,25 @@ class Zipformer2EncoderLayer(nn.Module):
 dropout=0.0,
 )
-self.self_attn1 = SelfAttention(embed_dim, num_heads,
-value_head_dim)
+self.self_attn1 = Attention(embed_dim, embed_dim, num_heads,
+value_head_dim)
-self.self_attn2 = SelfAttention(embed_dim, num_heads,
-value_head_dim)
+self.self_attn2 = Attention(embed_dim, embed_dim, num_heads,
+value_head_dim)
+if memory_dim > 0:
+self.attn_weights = MultiheadAttentionWeights(
+memory_dim, embed_dim,
+num_heads=num_heads,
+head_dim=query_head_dim,
+dropout=0.0,
+)
+self.src_attn1 = Attention(memory_dim, embed_dim, num_heads,
+value_head_dim)
+self.src_attn2 = Attention(memory_dim, embed_dim, num_heads,
+value_head_dim)
 self.feed_forward1 = FeedforwardModule(embed_dim,
 (feedforward_dim * 3) // 4,
@@ -579,6 +616,8 @@ class Zipformer2EncoderLayer(nn.Module):
 chunk_size: int = -1,
 attn_mask: Optional[Tensor] = None,
 src_key_padding_mask: Optional[Tensor] = None,
+memory: Optional[Tensor] = None,
+memory_key_padding_mask: Optional[Tensor] = None,
 ) -> Tensor:
 """
 Pass the input through the encoder layer.
@@ -608,11 +647,14 @@ class Zipformer2EncoderLayer(nn.Module):
 pos_emb=pos_emb,
 attn_mask=attn_mask,
 key_padding_mask=src_key_padding_mask,
 )
+if memory is not None and hasattr(self, 'attn_weights'):
+src_attn_weights = self.attn_weights(memory, src, memory_key_padding_mask)
 src = src + self.feed_forward1(src)
-self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
+attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
 if True:
 selected_attn_weights = attn_weights[0:2]
@@ -630,12 +672,16 @@ class Zipformer2EncoderLayer(nn.Module):
 na = self.balancer_na(self.nonlin_attention(src,
 selected_attn_weights[0:1]))
-src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
+src = src + (na if attn_dropout_mask is None else na * attn_dropout_mask)
 self_attn = self.self_attn1(
 src, attn_weights)
-src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+src = src + (self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask)
+if memory is not None and hasattr(self, 'attn_weights'):
+src = src + self.sequence_dropout(self.src_attn1(memory, src_attn_weights),
+attention_skip_rate)
 src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size,
 src_key_padding_mask=src_key_padding_mask),
@@ -650,7 +696,11 @@ class Zipformer2EncoderLayer(nn.Module):
 self_attn = self.self_attn2(
 src, attn_weights)
-src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+src = src + (self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask)
+if memory is not None and hasattr(self, 'attn_weights'):
+src = src + self.sequence_dropout(self.src_attn2(memory, src_attn_weights),
+attention_skip_rate)
 src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size,
 src_key_padding_mask=src_key_padding_mask),
@@ -673,9 +723,9 @@ class Zipformer2Encoder(nn.Module):
 r"""Zipformer2Encoder is a stack of N encoder layers
 Args:
 encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
 num_layers: the number of sub-encoder-layers in the encoder (required).
 pos_dim: the dimension for the relative positional encoding
 Examples::
 >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
@@ -721,6 +771,8 @@ class Zipformer2Encoder(nn.Module):
 feature_mask: Union[Tensor, float] = 1.0,
 attn_mask: Optional[Tensor] = None,
 src_key_padding_mask: Optional[Tensor] = None,
+memory: Optional[Tensor] = None,
+memory_key_padding_mask: Optional[Tensor] = None,
 ) -> Tensor:
 r"""Pass the input through the encoder layers in turn.
@@ -734,6 +786,10 @@ class Zipformer2Encoder(nn.Module):
 True means masked position. May be None.
 src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
 masked position. May be None.
+memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
+memory_key_padding_mask: optionally the mask for padding of memory input (for source-
+attention), of shape (batch_size, memory_len); True means
+masked position. May be None.
 Returns: a Tensor with the same shape as src.
 """
@@ -751,6 +807,8 @@ class Zipformer2Encoder(nn.Module):
 chunk_size=chunk_size,
 attn_mask=attn_mask,
 src_key_padding_mask=src_key_padding_mask,
+memory=memory,
+memory_key_padding_mask=memory_key_padding_mask,
 )
 output = output * feature_mask
@@ -843,6 +901,8 @@ class DownsampledZipformer2Encoder(nn.Module):
 feature_mask: Union[Tensor, float] = 1.0,
 attn_mask: Optional[Tensor] = None,
 src_key_padding_mask: Optional[Tensor] = None,
+memory: Optional[Tensor] = None,
+memory_key_padding_mask: Optional[Tensor] = None,
 ) -> Tuple[Tensor, Tensor]:
 r"""Downsample, go through encoder, upsample.
@@ -855,6 +915,10 @@ class DownsampledZipformer2Encoder(nn.Module):
 True means masked position. May be None.
 src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
 masked position. May be None.
+memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
+memory_key_padding_mask: optionally the mask for padding of memory input (for source-
+attention), of shape (batch_size, memory_len); True means
+masked position. May be None.
 Returns: a Tensor with the same shape as src.
 """
@@ -870,6 +934,8 @@ class DownsampledZipformer2Encoder(nn.Module):
 feature_mask=feature_mask,
 attn_mask=attn_mask,
 src_key_padding_mask=src_key_padding_mask,
+memory=memory,
+memory_key_padding_mask=memory_key_padding_mask,
 )
 src = self.upsample(src)
 # remove any extra frames that are not a multiple of downsample_factor
@@ -1185,7 +1251,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 query_dim = query_head_dim * num_heads
-# self-attention
 q = x[...,0:query_dim]
 k = x[...,query_dim:2*query_dim]
 # p is the position-encoding query
@@ -1231,9 +1296,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 attn_scores = attn_scores + pos_scores
 if self.training and random.random() < 0.1:
-# This is a harder way of limiting the attention scores to not be
+# This is a way of limiting the attention scores to not be
 # too large. It incurs a penalty if any of them has an absolute
-# value greater than 50.0. this should be outside the normal range
+# value greater than 25.0. this should be outside the normal range
 # of the attention scores. We use this mechanism instead of, say,
 # something added to the loss function involving the entropy,
 # because once the entropy gets very small gradients through the
@@ -1295,29 +1360,31 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
 logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}")
-class SelfAttention(nn.Module):
+class Attention(nn.Module):
 """
 The simplest possible attention module. This one works with already-computed attention
 weights, e.g. as computed by RelPositionMultiheadAttentionWeights.
 Args:
-embed_dim: the input and output embedding dimension
+embed_dim_in: the input embedding dimension
+embed_dim_out: the output embedding dimension (normally the same as input)
 num_heads: the number of attention heads
 value_head_dim: the value dimension per head
 """
 def __init__(
 self,
-embed_dim: int,
+embed_dim_in: int,
+embed_dim_out: int,
 num_heads: int,
 value_head_dim: int,
 ) -> None:
 super().__init__()
-self.in_proj = nn.Linear(embed_dim,
+self.in_proj = nn.Linear(embed_dim_in,
 num_heads * value_head_dim,
-bias=True)
+bias=False)
 self.out_proj = ScaledLinear(num_heads * value_head_dim,
-embed_dim, bias=True,
+embed_dim_out, bias=False,
 initial_scale=0.05)
 self.whiten = Whiten(num_groups=1,
@@ -1334,35 +1401,182 @@ class SelfAttention(nn.Module):
 """
 Args:
 x: input tensor, of shape (seq_len, batch_size, embed_dim)
-attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
-with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
-attn_weights.sum(dim=-1) == 1.
+attn_weights: a tensor of shape (num_heads, batch_size, query_len, key_len),
+Expect attn_weights.sum(dim=-1) == 1.
 Returns:
 a tensor with the same shape as x.
 """
-(seq_len, batch_size, embed_dim) = x.shape
-num_heads = attn_weights.shape[0]
-assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
-x = self.in_proj(x) # (seq_len, batch_size, num_heads * value_head_dim)
-x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
-# now x: (num_heads, batch_size, seq_len, value_head_dim)
+(num_heads, batch_size, query_len, key_len) = attn_weights.shape
+x = self.in_proj(x) # (key_len, batch_size, num_heads * value_head_dim)
+x = x.reshape(key_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+# now x: (num_heads, batch_size, key_len, value_head_dim)
 value_head_dim = x.shape[-1]
 # todo: see whether there is benefit in overriding matmul
 x = torch.matmul(attn_weights, x)
-# v: (num_heads, batch_size, seq_len, value_head_dim)
+# v: (num_heads, batch_size, query_len, value_head_dim)
 x = x.permute(2, 1, 0, 3).contiguous().view(
-seq_len, batch_size, num_heads * value_head_dim)
+query_len, batch_size, num_heads * value_head_dim)
-# returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+# returned value is of shape (query_len, batch_size, embed_dim), like the input.
 x = self.out_proj(x)
 x = self.whiten(x)
 return x
+class MultiheadAttentionWeights(nn.Module):
+r"""Module that computes multi-head cross-attention weights. Allows src and target
+to have different dims.
+Args:
+key_embed_dim: number of channels of the thing that we'll project to
+make the key (corresponds to source). e.g. 256
+query_embed_dim: number of channels of the thing that we'll project to
+make the query (corresponds to target). e.g. 256
+num_heads: number of heads to compute weights for, e.g. 8
+head_dim: dimension of the query and key, per head. e.g. 24.
+dropout: dropout probability for attn_output_weights. Default: 0.0.
+"""
+def __init__(
+self,
+key_embed_dim: int,
+query_embed_dim: int,
+num_heads: int,
+head_dim: int,
+dropout: float = 0.0,
+) -> None:
+super().__init__()
+self.key_embed_dim = key_embed_dim
+self.query_embed_dim = query_embed_dim
+self.num_heads = num_heads
+self.head_dim = head_dim
+self.dropout = dropout
+self.name = None # will be overwritten in training code; for diagnostics.
+# the initial_scale is supposed to take over the "scaling" factor of
+# head_dim ** -0.5 that has been used in previous forms of attention,
+# dividing it between the query and key. Note: this module is intended
+# to be used with the ScaledAdam optimizer; with most other optimizers,
+# it would be necessary to apply the scaling factor in the forward function.
+self.query_in_proj = ScaledLinear(query_embed_dim,
+head_dim * num_heads,
+bias=True,
+initial_scale=head_dim ** -0.25)
+# weights produced by this module are invariant to adding a constant to
+# the keys, so we don't need a bias for the keys.
+self.key_in_proj = ScaledLinear(key_embed_dim,
+head_dim * num_heads,
+bias=False,
+initial_scale=head_dim ** -0.25)
+self.whiten_keys = Whiten(num_groups=num_heads,
+whitening_limit=_whitening_schedule(3.0),
+prob=(0.025, 0.25),
+grad_scale=0.025)
+def forward(
+self,
+key: Tensor,
+query: Tensor,
+key_padding_mask: Optional[Tensor] = None,
+) -> Tensor:
+r"""
+Args:
+key: input of shape (key_len, batch_size, key_embed_dim)
+query: input of shape (query_len, batch_size, query_embed_dim)
+key_padding_mask: an optional bool tensor of shape (batch_size, key_len). Positions that
+are True in this mask will be ignored as sources in the attention weighting.
+Returns:
+a tensor of attention weights, of shape (num_heads, batch_size, query_len, key_len)
+"""
+q = self.query_in_proj(query)
+k = self.key_in_proj(key)
+head_dim = self.head_dim
+num_heads = self.num_heads
+query_len, batch_size, _ = q.shape
+key_len, _batch_size, _ = k.shape
+assert _batch_size == batch_size
+k = self.whiten_keys(k) # does nothing in the forward pass.
+q = q.reshape(query_len, batch_size, num_heads, head_dim)
+k = k.reshape(key_len, batch_size, num_heads, head_dim)
+# time1 refers to target, time2 refers to source.
+q = q.permute(2, 1, 0, 3) # (head, batch, time1, query_head_dim)
+k = k.permute(2, 1, 3, 0) # (head, batch, d_k, time2)
+attn_scores = torch.matmul(q, k)
+if self.training and random.random() < 0.1:
+# This is a way of limiting the attention scores to not be
+# too large. It incurs a penalty if any of them has an absolute
+# value greater than 25.0. this should be outside the normal range
+# of the attention scores. We use this mechanism instead of, say,
+# something added to the loss function involving the entropy,
+# because once the entropy gets very small gradients through the
+# softmax can become very small, and we'd get zero derivatives. The
+# choice of 1.0e-04 as the scale on the penalty makes this
+# mechanism vulnerable to the absolute scale of the loss function,
+# but we view this as a failsafe to avoid "implausible" parameter
+# values rather than a regularization method that should be active
+# under normal circumstances.
+attn_scores = penalize_abs_values_gt(attn_scores,
+limit=25.0,
+penalty=1.0e-04,
+name=self.name)
+assert attn_scores.shape == (num_heads, batch_size, query_len, key_len)
+if key_padding_mask is not None:
+assert key_padding_mask.shape == (batch_size, key_len), key_padding_mask.shape
+attn_scores = attn_scores.masked_fill(
+key_padding_mask.unsqueeze(1),
+-1000,
+)
+# We use our own version of softmax, defined in scaling.py, which should
+# save a little of the memory used in backprop if we are in
+# automatic mixed precision mode (amp / autocast), by only storing the
+# half-precision output for backprop purposes.
+attn_weights = softmax(attn_scores, dim=-1)
+if random.random() < 0.001:
+self._print_attn_entropy(attn_weights)
+attn_weights = nn.functional.dropout(
+attn_weights, p=self.dropout, training=self.training
+)
+return attn_weights
+def _print_attn_entropy(
+self,
+attn_weights: Tensor):
+# attn_weights: (num_heads, batch_size, seq_len, seq_len)
+(num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
+with torch.no_grad():
+with torch.cuda.amp.autocast(enabled=False):
+attn_weights = attn_weights.to(torch.float32)
+attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
+dim=-1).mean(dim=(1,2))
+logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}")
 class FeedforwardModule(nn.Module):
 """Feedforward module in Zipformer2 model.
 """
@@ -1650,12 +1864,14 @@ def _test_zipformer_main(causal: bool = False):
 batch_size = 5
 seq_len = 20
 # Just make sure the forward pass runs.
+memory_dim = 100
 c = Zipformer2(
 encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
 causal=causal,
 chunk_size=(4,) if causal else (-1,),
-left_context_frames=(64,)
+left_context_frames=(64,),
+memory_dim=memory_dim,
 )
 batch_size = 5
 seq_len = 20
@@ -1663,6 +1879,7 @@ def _test_zipformer_main(causal: bool = False):
 f = c(
 torch.randn(seq_len, batch_size, 64),
 torch.full((batch_size,), seq_len, dtype=torch.int64),
+memory=torch.randn(101, batch_size, memory_dim),
 )
 f[0].sum().backward()
 c.eval()