Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

Commit fa696e919b: Add memory to model
Parent: 6f5c4688ef
@@ -26,15 +26,16 @@ from icefall.utils import add_sos, make_pad_mask

from scaling import penalize_abs_values_gt, ScaledLinear


-class Transducer(nn.Module):
+class PromptedTransducer(nn.Module):
    """It implements https://arxiv.org/pdf/1211.3711.pdf
    "Sequence Transduction with Recurrent Neural Networks"
    """

    def __init__(
        self,
        encoder_embed: nn.Module,
        encoder: EncoderInterface,
        text_embed: nn.Module,
        text_encoder: EncoderInterface,
        decoder: nn.Module,
        joiner: nn.Module,
        encoder_dim: int,
@@ -68,6 +69,8 @@ class Transducer(nn.Module):

        self.encoder_embed = encoder_embed
        self.encoder = encoder
        self.text_embed = text_embed
        self.text_encoder = text_encoder
        self.decoder = decoder
        self.joiner = joiner

@@ -86,6 +89,9 @@ class Transducer(nn.Module):
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        text: torch.Tensor,
        style_lens: torch.Tensor,
        text_lens: torch.Tensor,
        y: k2.RaggedTensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
@@ -98,6 +104,21 @@ class Transducer(nn.Module):
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          text:
            A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
            It is expected to contain the style prompt (first) and then the content
            prompt.
          style_lens:
            A 1-D tensor of shape (N,), containing the number of elements (bytes)
            within each row of `text` that correspond to the style prompt (these
            are expected to come first).
          text_lens:
            A 1-D tensor of shape (N,). It contains the number of elements (bytes)
            in `text` before padding, which will include the lengths of the
            style plus the content prompt.
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
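The prompt tensors documented above can be built by UTF-8-encoding the style and content prompts, concatenating them style-first, and padding each row of text. A rough sketch follows; make_prompt_tensors and its padding value are hypothetical and not part of this commit.

import torch

def make_prompt_tensors(style_prompts, content_prompts, pad_id: int = 0):
    # Hypothetical helper: encode prompts as UTF-8 bytes, style prompt first,
    # then the content prompt, and pad each row of `text` to a common length.
    rows, style_lens, text_lens = [], [], []
    for style, content in zip(style_prompts, content_prompts):
        style_bytes = list(style.encode("utf-8"))
        content_bytes = list(content.encode("utf-8"))
        rows.append(style_bytes + content_bytes)
        style_lens.append(len(style_bytes))
        text_lens.append(len(style_bytes) + len(content_bytes))
    max_len = max(text_lens)
    text = torch.full((len(rows), max_len), pad_id, dtype=torch.int64)
    for i, row in enumerate(rows):
        text[i, : len(row)] = torch.tensor(row, dtype=torch.int64)
    return text, torch.tensor(style_lens), torch.tensor(text_lens)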
@@ -125,14 +146,25 @@ class Transducer(nn.Module):

        assert x.size(0) == x_lens.size(0) == y.dim0

        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
        x, x_lens = self.encoder_embed(x, x_lens)
        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")

        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

-       encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
        text = text.t()  # now (T, N)
        text = self.text_embed(text)  # now (T, N, C)
        text_key_padding_mask = make_pad_mask(text_lens)

        memory, text_lens = self.text_encoder(text, text_lens,
                                              text_key_padding_mask)

        memory = self._add_style_indicator(memory, style_lens)

        memory_key_padding_mask = make_pad_mask(text_lens)

+       encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask,
+                                          memory=memory,
+                                          memory_key_padding_mask=memory_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        assert torch.all(x_lens > 0)
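A shape walk-through of the new memory path, using stand-in modules so it runs on its own (the real text_embed / text_encoder come from the training-script changes further below):

import torch
import torch.nn as nn

# Minimal shape sketch of the prompt/memory path above (sizes are illustrative).
N, Tt, C_text = 2, 50, 256
text = torch.randint(0, 256, (N, Tt))       # UTF-8 bytes, (N, Tt)
text_lens = torch.tensor([50, 30])
style_lens = torch.tensor([10, 5])

text = text.t()                              # (Tt, N)
emb = nn.Embedding(256, C_text)(text)        # (Tt, N, C_text), stand-in for text_embed
# The real text_encoder (a Zipformer2) maps (Tt, N, C_text) -> (Tt', N, C_mem)
# and returns downsampled text_lens; its output is the `memory` that the
# acoustic encoder cross-attends to, masked by make_pad_mask(text_lens).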
@@ -217,3 +249,24 @@ class Transducer(nn.Module):
        )

        return (simple_loss, pruned_loss)

    def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
        """
        Adds to `memory` an indicator that is 0.1 for positions that correspond to
        the `style prompt` and 0 elsewhere.  The scale can be fixed because the
        scale of the memory vector can adjust to compensate (within limits set
        by the balancers).

        Args:
           memory: (memory_len, batch_size, embed_dim)
           style_lens: (batch_size,), a vector of lengths of the style prompt.
        """

        (memory_len, batch_size, embed_dim) = memory.shape

        indicator = torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
        indicator = indicator.to(memory.dtype).unsqueeze(-1)

        return memory + indicator
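A standalone sketch of the indicator computation above, with small illustrative sizes:

import torch

# Positions before style_lens[b] get the indicator added; the rest are unchanged.
memory_len, batch_size, embed_dim = 6, 2, 4
memory = torch.zeros(memory_len, batch_size, embed_dim)
style_lens = torch.tensor([2, 4])

indicator = torch.arange(memory_len).unsqueeze(-1) < style_lens   # (memory_len, batch_size) bool
indicator = indicator.to(memory.dtype).unsqueeze(-1)              # (memory_len, batch_size, 1)

out = memory + indicator
# out[:2, 0] and out[:4, 1] now carry the indicator; later positions are unchanged.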
@@ -71,7 +71,9 @@ from lhotse.utils import fix_random_seed
from model import Transducer
from optim import Eden, ScaledAdam
from torch import Tensor
from torch import nn
from torch.cuda.amp import GradScaler

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

@@ -159,6 +161,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Embedding dimension in encoder stacks: a single int or comma-separated list."
    )

    parser.add_argument(
        "--text-encoder-dim",
        type=str,
        default="256,256,384,512",
        help="Embedding dimension in text encoder stacks: a comma-separated list of 4 elements, "
        "or you should change other configs in the code."
    )

    parser.add_argument(
        "--query-head-dim",
        type=str,
@@ -547,6 +557,32 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
    return encoder_embed


def get_text_embed(params: AttributeDict) -> nn.Module:
    return nn.Embedding(
        num_embeddings=256,  # we encode the text as UTF-8 bytes
        embedding_dim=_to_int_tuple(params.text_encoder_dim)[0],
    )


def get_text_encoder(params: AttributeDict) -> nn.Module:
    return Zipformer2(
        output_downsampling_factor=8,
        downsampling_factor=(1,2,4,8),
        num_encoder_layers=(2,4,6,6),
        encoder_dim=_to_int_tuple(params.text_encoder_dim),
        encoder_unmasked_dim=(196,196,256,256),
        query_head_dim=(32,32,32,32),
        pos_head_dim=(4,4,4,4),
        value_head_dim=(12,12,12,12),
        pos_dim=48,
        num_heads=(4,4,4,8),
        feedforward_dim=(384,512,768,1024),  # could increase this if there is enough data
        cnn_module_kernel=(31,31,15,15),
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        warmup_batches=4000.0,
        causal=False,
    )


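get_text_embed and get_text_encoder parse the comma-separated --text-encoder-dim value with _to_int_tuple, whose definition is not shown in this diff; a minimal stand-in with the assumed behaviour:

from typing import Tuple

def _to_int_tuple(s: str) -> Tuple[int, ...]:
    # Minimal stand-in for the helper used above (assumed behaviour):
    # "256,256,384,512" -> (256, 256, 384, 512)
    return tuple(int(x) for x in s.split(","))

# e.g. the text embedding dimension is the first stack's dimension:
# _to_int_tuple("256,256,384,512")[0] == 256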
def get_encoder_model(params: AttributeDict) -> nn.Module:
    encoder = Zipformer2(
        output_downsampling_factor=2,

@@ -566,6 +602,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
        causal=params.causal,
        chunk_size=_to_int_tuple(params.chunk_size),
        left_context_frames=_to_int_tuple(params.left_context_frames),
        memory_dim=_to_int_tuple(params.text_encoder_dim)[-1],
    )
    return encoder

@@ -593,12 +630,16 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
def get_transducer_model(params: AttributeDict) -> nn.Module:
    encoder_embed = get_encoder_embed(params)
    encoder = get_encoder_model(params)
    text_embed = get_text_embed(params)
    text_encoder = get_text_encoder(params)
    decoder = get_decoder_model(params)
    joiner = get_joiner_model(params)

    model = Transducer(
        encoder_embed=encoder_embed,
        encoder=encoder,
        text_embed=text_embed,
        text_encoder=text_encoder,
        decoder=decoder,
        joiner=joiner,
        encoder_dim=int(max(params.encoder_dim.split(','))),

@@ -89,6 +89,9 @@ class Zipformer2(EncoderInterface):
           context chunks for causal training; will be rounded to a number of
           chunks.  Must not be less than cnn_module_kernel (after factoring in
           rounding and downsampling); an error will be thrown if this is violated.
        memory_dim: if supplied and >0, will be the dimension of the memory embeddings
            passed into the zipformer (e.g. this might be the output of another
            Zipformer used to create embedding vectors.)
    """
    def __init__(
        self,
@@ -103,6 +106,7 @@ class Zipformer2(EncoderInterface):
        num_heads: Union[int, Tuple[int]] = 8,
        feedforward_dim: Union[int, Tuple[int]] = 1536,
        cnn_module_kernel: Union[int, Tuple[int]] = 31,
        memory_dim: int = -1,
        pos_dim: int = 192,
        dropout: FloatLike = None,  # see code below for default
        warmup_batches: float = 4000.0,
@@ -160,6 +164,7 @@ class Zipformer2(EncoderInterface):
                pos_head_dim=pos_head_dim[i],
                value_head_dim=value_head_dim[i],
                feedforward_dim=feedforward_dim[i],
                memory_dim=memory_dim,
                dropout=dropout,
                cnn_module_kernel=cnn_module_kernel[i],
                causal=causal,
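memory_dim is plumbed from the Zipformer2 constructor down into every encoder layer. A minimal construction-and-forward sketch, mirroring _test_zipformer_main at the end of this diff (the sizes are the test's, not a recommended configuration):

import torch

# Sketch only; assumes Zipformer2 from this modified file is importable.
memory_dim = 100
encoder = Zipformer2(
    encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
    causal=False,
    chunk_size=(-1,),
    left_context_frames=(64,),
    memory_dim=memory_dim,   # enables cross-attention to an external memory
)

batch_size, seq_len = 5, 20
x = torch.randn(seq_len, batch_size, 64)
x_lens = torch.full((batch_size,), seq_len, dtype=torch.int64)
memory = torch.randn(101, batch_size, memory_dim)   # (memory_len, batch, memory_dim)

out, out_lens = encoder(x, x_lens, memory=memory)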
@@ -271,9 +276,12 @@ class Zipformer2(EncoderInterface):

    def forward(
-       self, x: torch.Tensor,
+       self,
+       x: torch.Tensor,
        x_lens: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        memory: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
@@ -285,6 +293,11 @@ class Zipformer2(EncoderInterface):
          src_key_padding_mask:
            The mask for padding, of shape (batch_size, seq_len); True means
            masked position. May be None.
          memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
          memory_key_padding_mask: optionally the mask for padding of memory input (for source-
            attention), of shape (batch_size, memory_len); True means
            masked position. May be None.

        Returns:
          Return a tuple containing 2 tensors:
            - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
@@ -298,6 +311,13 @@ class Zipformer2(EncoderInterface):

        attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)

        if self.training and memory is not None:
            batch_size = x.shape[1]
            # setting memory to zero should be equivalent to not using the
            # memory input at all, since the Attention module has no biases.
            memory_dropout_rate = 0.05
            memory = memory * (torch.rand(batch_size, 1) > memory_dropout_rate)

        for i, module in enumerate(self.encoders):
            ds = self.downsampling_factor[i]
            x = convert_num_channels(x, self.encoder_dim[i])
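The block above drops the entire memory for a random ~5% of utterances during training; because the attention projections have no biases, a zeroed memory contributes nothing, so this approximates training without the prompt. A standalone sketch of the same idea, with the mask explicitly created on the memory's device:

import torch

# Standalone sketch of per-utterance memory dropout (illustrative only).
memory_len, batch_size, memory_dim = 30, 4, 512
memory = torch.randn(memory_len, batch_size, memory_dim)

memory_dropout_rate = 0.05
# (batch_size, 1) broadcasts over (memory_len, batch_size, memory_dim),
# so each utterance keeps or drops its whole memory.
keep = torch.rand(batch_size, 1, device=memory.device) > memory_dropout_rate
memory = memory * keep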
@@ -308,6 +328,8 @@ class Zipformer2(EncoderInterface):
                src_key_padding_mask=(None if src_key_padding_mask is None
                                      else src_key_padding_mask[...,::ds]),
                attn_mask=attn_mask,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
            )
            outputs.append(x)

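Note that src_key_padding_mask is strided by each stack's downsampling factor, while memory and memory_key_padding_mask are passed through unchanged, since the memory length does not depend on the acoustic downsampling. A small illustration with assumed sizes:

import torch

# Illustration only: the frame mask is strided by the downsampling factor,
# the memory mask is not.
batch_size, seq_len, memory_len = 2, 16, 30
src_key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
memory_key_padding_mask = torch.zeros(batch_size, memory_len, dtype=torch.bool)

ds = 4
downsampled = src_key_padding_mask[..., ::ds]   # (batch_size, seq_len // ds) == (2, 4)
# memory_key_padding_mask stays (2, 30) for every encoder stack.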
@@ -416,6 +438,7 @@ class Zipformer2EncoderLayer(nn.Module):
        dropout: FloatLike = 0.1,
        cnn_module_kernel: int = 31,
        causal: bool = False,
        memory_dim: int = -1,
        attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
        conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
        const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
@@ -452,12 +475,26 @@ class Zipformer2EncoderLayer(nn.Module):
            dropout=0.0,
        )

-       self.self_attn1 = SelfAttention(embed_dim, num_heads,
+       self.self_attn1 = Attention(embed_dim, embed_dim, num_heads,
                                        value_head_dim)

-       self.self_attn2 = SelfAttention(embed_dim, num_heads,
+       self.self_attn2 = Attention(embed_dim, embed_dim, num_heads,
                                        value_head_dim)

        if memory_dim > 0:
            self.attn_weights = MultiheadAttentionWeights(
                memory_dim, embed_dim,
                num_heads=num_heads,
                head_dim=query_head_dim,
                dropout=0.0,
            )
            self.src_attn1 = Attention(memory_dim, embed_dim, num_heads,
                                       value_head_dim)
            self.src_attn2 = Attention(memory_dim, embed_dim, num_heads,
                                       value_head_dim)


        self.feed_forward1 = FeedforwardModule(embed_dim,
                                               (feedforward_dim * 3) // 4,
                                               dropout)
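When memory_dim > 0 the layer gains one MultiheadAttentionWeights module plus two value/output projections (src_attn1, src_attn2); the forward-pass hunks further below compute src_attn_weights once and reuse them for both. A self-contained toy of that pattern using plain nn.Linear stand-ins (the real modules use ScaledLinear, whitening, etc.):

import torch
import torch.nn as nn

embed_dim, memory_dim, num_heads, head_dim, value_head_dim = 192, 256, 4, 24, 12
seq_len, memory_len, batch_size = 10, 30, 2

src = torch.randn(seq_len, batch_size, embed_dim)
memory = torch.randn(memory_len, batch_size, memory_dim)

# weights are computed once per layer ...
q = nn.Linear(embed_dim, num_heads * head_dim)(src)
k = nn.Linear(memory_dim, num_heads * head_dim, bias=False)(memory)
q = q.reshape(seq_len, batch_size, num_heads, head_dim).permute(2, 1, 0, 3)
k = k.reshape(memory_len, batch_size, num_heads, head_dim).permute(2, 1, 3, 0)
weights = (torch.matmul(q, k) * head_dim ** -0.5).softmax(dim=-1)  # (heads, batch, seq, mem)

# ... and applied twice, each time with its own value/output projection
def apply(weights, memory, v_proj, out_proj):
    v = v_proj(memory).reshape(memory_len, batch_size, num_heads, value_head_dim)
    v = v.permute(2, 1, 0, 3)                      # (heads, batch, mem, value_head_dim)
    out = torch.matmul(weights, v)                 # (heads, batch, seq, value_head_dim)
    out = out.permute(2, 1, 0, 3).reshape(seq_len, batch_size, -1)
    return out_proj(out)

v1 = nn.Linear(memory_dim, num_heads * value_head_dim, bias=False)
o1 = nn.Linear(num_heads * value_head_dim, embed_dim, bias=False)
v2 = nn.Linear(memory_dim, num_heads * value_head_dim, bias=False)
o2 = nn.Linear(num_heads * value_head_dim, embed_dim, bias=False)

src = src + apply(weights, memory, v1, o1)   # plays the role of src_attn1
src = src + apply(weights, memory, v2, o2)   # plays the role of src_attn2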
@@ -579,6 +616,8 @@ class Zipformer2EncoderLayer(nn.Module):
        chunk_size: int = -1,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        memory: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Pass the input through the encoder layer.
@@ -610,9 +649,12 @@ class Zipformer2EncoderLayer(nn.Module):
            key_padding_mask=src_key_padding_mask,
        )

        if memory is not None and hasattr(self, 'attn_weights'):
            src_attn_weights = self.attn_weights(memory, src, memory_key_padding_mask)

        src = src + self.feed_forward1(src)

-       self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
+       attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)

        if True:
            selected_attn_weights = attn_weights[0:2]

@@ -630,12 +672,16 @@ class Zipformer2EncoderLayer(nn.Module):
            na = self.balancer_na(self.nonlin_attention(src,
                                                        selected_attn_weights[0:1]))

-           src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
+           src = src + (na if attn_dropout_mask is None else na * attn_dropout_mask)

        self_attn = self.self_attn1(
            src, attn_weights)

-       src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+       src = src + (self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask)

        if memory is not None and hasattr(self, 'attn_weights'):
            src = src + self.sequence_dropout(self.src_attn1(memory, src_attn_weights),
                                              attention_skip_rate)

        src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size,
                                                            src_key_padding_mask=src_key_padding_mask),
@@ -650,7 +696,11 @@ class Zipformer2EncoderLayer(nn.Module):
        self_attn = self.self_attn2(
            src, attn_weights)

-       src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+       src = src + (self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask)

        if memory is not None and hasattr(self, 'attn_weights'):
            src = src + self.sequence_dropout(self.src_attn2(memory, src_attn_weights),
                                              attention_skip_rate)

        src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size,
                                                            src_key_padding_mask=src_key_padding_mask),
@@ -721,6 +771,8 @@ class Zipformer2Encoder(nn.Module):
        feature_mask: Union[Tensor, float] = 1.0,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        memory: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

@@ -734,6 +786,10 @@ class Zipformer2Encoder(nn.Module):
                True means masked position. May be None.
            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                masked position. May be None.
            memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
            memory_key_padding_mask: optionally the mask for padding of memory input (for source-
                attention), of shape (batch_size, memory_len); True means
                masked position. May be None.

        Returns: a Tensor with the same shape as src.
        """
@@ -751,6 +807,8 @@ class Zipformer2Encoder(nn.Module):
                chunk_size=chunk_size,
                attn_mask=attn_mask,
                src_key_padding_mask=src_key_padding_mask,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
            )

        output = output * feature_mask
@@ -843,6 +901,8 @@ class DownsampledZipformer2Encoder(nn.Module):
        feature_mask: Union[Tensor, float] = 1.0,
        attn_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        memory: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        r"""Downsample, go through encoder, upsample.

@@ -855,6 +915,10 @@ class DownsampledZipformer2Encoder(nn.Module):
                True means masked position. May be None.
            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                masked position. May be None.
            memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
            memory_key_padding_mask: optionally the mask for padding of memory input (for source-
                attention), of shape (batch_size, memory_len); True means
                masked position. May be None.

        Returns: a Tensor with the same shape as src.
        """
@@ -870,6 +934,8 @@ class DownsampledZipformer2Encoder(nn.Module):
            feature_mask=feature_mask,
            attn_mask=attn_mask,
            src_key_padding_mask=src_key_padding_mask,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        src = self.upsample(src)
        # remove any extra frames that are not a multiple of downsample_factor
@@ -1185,7 +1251,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

        query_dim = query_head_dim * num_heads

        # self-attention
        q = x[...,0:query_dim]
        k = x[...,query_dim:2*query_dim]
        # p is the position-encoding query
@@ -1231,9 +1296,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        attn_scores = attn_scores + pos_scores

        if self.training and random.random() < 0.1:
-           # This is a harder way of limiting the attention scores to not be
+           # This is a way of limiting the attention scores to not be
            # too large. It incurs a penalty if any of them has an absolute
-           # value greater than 50.0. this should be outside the normal range
+           # value greater than 25.0. this should be outside the normal range
            # of the attention scores. We use this mechanism instead of, say,
            # something added to the loss function involving the entropy,
            # because once the entropy gets very small gradients through the
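The penalty referred to here comes from penalize_abs_values_gt in scaling.py, whose implementation is not part of this diff; the real helper returns the scores for further use. A rough stand-in that only illustrates the size of the penalty being described:

import torch

def abs_penalty(x: torch.Tensor, limit: float = 25.0, penalty: float = 1.0e-04) -> torch.Tensor:
    # Conceptual stand-in, not the scaling.py implementation: the scalar below is
    # zero while |x| <= limit and grows linearly once any element exceeds it.
    return penalty * (x.abs() - limit).clamp(min=0).sum()

scores = torch.tensor([3.0, -30.0, 10.0])
print(abs_penalty(scores))   # tensor(0.0005): only the -30.0 entry is penalized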
@@ -1295,29 +1360,31 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
            logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}")


-class SelfAttention(nn.Module):
+class Attention(nn.Module):
    """
    The simplest possible attention module. This one works with already-computed attention
    weights, e.g. as computed by RelPositionMultiheadAttentionWeights.

    Args:
-         embed_dim: the input and output embedding dimension
+         embed_dim_in: the input embedding dimension
+         embed_dim_out: the output embedding dimension (normally the same as input)
          num_heads: the number of attention heads
          value_head_dim: the value dimension per head
    """
    def __init__(
        self,
-       embed_dim: int,
+       embed_dim_in: int,
+       embed_dim_out: int,
        num_heads: int,
        value_head_dim: int,
    ) -> None:
        super().__init__()
-       self.in_proj = nn.Linear(embed_dim,
+       self.in_proj = nn.Linear(embed_dim_in,
                                 num_heads * value_head_dim,
-                                bias=True)
+                                bias=False)

        self.out_proj = ScaledLinear(num_heads * value_head_dim,
-                                    embed_dim, bias=True,
+                                    embed_dim_out, bias=False,
                                     initial_scale=0.05)

        self.whiten = Whiten(num_groups=1,
@@ -1334,35 +1401,182 @@ class SelfAttention(nn.Module):
        """
        Args:
          x: input tensor, of shape (seq_len, batch_size, embed_dim)
-         attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
-             with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
-             attn_weights.sum(dim=-1) == 1.
+         attn_weights: a tensor of shape (num_heads, batch_size, query_len, key_len),
+             Expect attn_weights.sum(dim=-1) == 1.
        Returns:
           a tensor with the same shape as x.
        """
-       (seq_len, batch_size, embed_dim) = x.shape
-       num_heads = attn_weights.shape[0]
-       assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+       (num_heads, batch_size, query_len, key_len) = attn_weights.shape

-       x = self.in_proj(x)     # (seq_len, batch_size, num_heads * value_head_dim)
-       x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
-       # now x: (num_heads, batch_size, seq_len, value_head_dim)
+       x = self.in_proj(x)     # (key_len, batch_size, num_heads * value_head_dim)
+       x = x.reshape(key_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+       # now x: (num_heads, batch_size, key_len, value_head_dim)
        value_head_dim = x.shape[-1]

        # todo: see whether there is benefit in overriding matmul
        x = torch.matmul(attn_weights, x)
-       # v: (num_heads, batch_size, seq_len, value_head_dim)
+       # v: (num_heads, batch_size, query_len, value_head_dim)

        x = x.permute(2, 1, 0, 3).contiguous().view(
-           seq_len, batch_size, num_heads * value_head_dim)
+           query_len, batch_size, num_heads * value_head_dim)

-       # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+       # returned value is of shape (query_len, batch_size, embed_dim), like the input.
        x = self.out_proj(x)
        x = self.whiten(x)

        return x


class MultiheadAttentionWeights(nn.Module):
    r"""Module that computes multi-head cross-attention weights.  Allows src and target
    to have different dims.

    Args:
          key_embed_dim: number of channels of the thing that we'll project to
              make the key (corresponds to source). e.g. 256
          query_embed_dim: number of channels of the thing that we'll project to
              make the query (corresponds to target). e.g. 256
          num_heads: number of heads to compute weights for, e.g. 8
          head_dim: dimension of the query and key, per head. e.g. 24.
          dropout: dropout probability for attn_output_weights. Default: 0.0.
    """

    def __init__(
            self,
            key_embed_dim: int,
            query_embed_dim: int,
            num_heads: int,
            head_dim: int,
            dropout: float = 0.0,
    ) -> None:
        super().__init__()
        self.key_embed_dim = key_embed_dim
        self.query_embed_dim = query_embed_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.dropout = dropout
        self.name = None  # will be overwritten in training code; for diagnostics.

        # the initial_scale is supposed to take over the "scaling" factor of
        # head_dim ** -0.5 that has been used in previous forms of attention,
        # dividing it between the query and key.  Note: this module is intended
        # to be used with the ScaledAdam optimizer; with most other optimizers,
        # it would be necessary to apply the scaling factor in the forward function.
        self.query_in_proj = ScaledLinear(query_embed_dim,
                                          head_dim * num_heads,
                                          bias=True,
                                          initial_scale=head_dim ** -0.25)

        # weights produced by this module are invariant to adding a constant to
        # the keys, so we don't need a bias for the keys.
        self.key_in_proj = ScaledLinear(key_embed_dim,
                                        head_dim * num_heads,
                                        bias=False,
                                        initial_scale=head_dim ** -0.25)

        self.whiten_keys = Whiten(num_groups=num_heads,
                                  whitening_limit=_whitening_schedule(3.0),
                                  prob=(0.025, 0.25),
                                  grad_scale=0.025)

    def forward(
        self,
        key: Tensor,
        query: Tensor,
        key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""
        Args:
            key: input of shape (key_len, batch_size, key_embed_dim)
            query: input of shape (query_len, batch_size, query_embed_dim)
            key_padding_mask: an optional bool tensor of shape (batch_size, key_len). Positions that
               are True in this mask will be ignored as sources in the attention weighting.
        Returns:
           a tensor of attention weights, of shape (num_heads, batch_size, query_len, key_len)
        """
        q = self.query_in_proj(query)
        k = self.key_in_proj(key)

        head_dim = self.head_dim
        num_heads = self.num_heads

        query_len, batch_size, _ = q.shape
        key_len, _batch_size, _ = k.shape
        assert _batch_size == batch_size

        k = self.whiten_keys(k)  # does nothing in the forward pass.

        q = q.reshape(query_len, batch_size, num_heads, head_dim)
        k = k.reshape(key_len, batch_size, num_heads, head_dim)

        # time1 refers to target, time2 refers to source.
        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)

        attn_scores = torch.matmul(q, k)

        if self.training and random.random() < 0.1:
            # This is a way of limiting the attention scores to not be
            # too large. It incurs a penalty if any of them has an absolute
            # value greater than 25.0. This should be outside the normal range
            # of the attention scores. We use this mechanism instead of, say,
            # something added to the loss function involving the entropy,
            # because once the entropy gets very small gradients through the
            # softmax can become very small, and we'd get zero derivatives. The
            # choice of 1.0e-04 as the scale on the penalty makes this
            # mechanism vulnerable to the absolute scale of the loss function,
            # but we view this as a failsafe to avoid "implausible" parameter
            # values rather than a regularization method that should be active
            # under normal circumstances.
            attn_scores = penalize_abs_values_gt(attn_scores,
                                                 limit=25.0,
                                                 penalty=1.0e-04,
                                                 name=self.name)

        assert attn_scores.shape == (num_heads, batch_size, query_len, key_len)

        if key_padding_mask is not None:
            assert key_padding_mask.shape == (batch_size, key_len), key_padding_mask.shape
            attn_scores = attn_scores.masked_fill(
                key_padding_mask.unsqueeze(1),
                -1000,
            )

        # We use our own version of softmax, defined in scaling.py, which should
        # save a little of the memory used in backprop if we are in
        # automatic mixed precision mode (amp / autocast), by only storing the
        # half-precision output for backprop purposes.
        attn_weights = softmax(attn_scores, dim=-1)

        if random.random() < 0.001:
            self._print_attn_entropy(attn_weights)

        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        return attn_weights

    def _print_attn_entropy(
            self,
            attn_weights: Tensor):
        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
        (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape

        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights = attn_weights.to(torch.float32)
                attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
                    dim=-1).mean(dim=(1,2))
                logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}")

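Putting the two new modules together: MultiheadAttentionWeights produces the (num_heads, batch, query_len, key_len) weights and Attention applies them to the memory. A usage sketch with illustrative sizes, assuming both classes are importable from the modified zipformer module:

import torch

embed_dim, memory_dim, num_heads, head_dim, value_head_dim = 384, 512, 4, 24, 12
seq_len, memory_len, batch_size = 10, 30, 2

weights_mod = MultiheadAttentionWeights(memory_dim, embed_dim,
                                        num_heads=num_heads,
                                        head_dim=head_dim,
                                        dropout=0.0)
src_attn = Attention(memory_dim, embed_dim, num_heads, value_head_dim)

src = torch.randn(seq_len, batch_size, embed_dim)          # queries (target side)
memory = torch.randn(memory_len, batch_size, memory_dim)   # keys/values (source side)
memory_key_padding_mask = torch.zeros(batch_size, memory_len, dtype=torch.bool)

w = weights_mod(memory, src, memory_key_padding_mask)  # (num_heads, batch, seq_len, memory_len)
out = src_attn(memory, w)                              # (seq_len, batch, embed_dim)
src = src + out                                        # residual, as in Zipformer2EncoderLayer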
class FeedforwardModule(nn.Module):
    """Feedforward module in Zipformer2 model.
    """

@@ -1650,12 +1864,14 @@ def _test_zipformer_main(causal: bool = False):
    batch_size = 5
    seq_len = 20
    # Just make sure the forward pass runs.
    memory_dim = 100

    c = Zipformer2(
        encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
        causal=causal,
        chunk_size=(4,) if causal else (-1,),
-       left_context_frames=(64,)
+       left_context_frames=(64,),
+       memory_dim=memory_dim,
    )
    batch_size = 5
    seq_len = 20
@@ -1663,6 +1879,7 @@ def _test_zipformer_main(causal: bool = False):
    f = c(
        torch.randn(seq_len, batch_size, 64),
        torch.full((batch_size,), seq_len, dtype=torch.int64),
        memory=torch.randn(101, batch_size, memory_dim),
    )
    f[0].sum().backward()
    c.eval()