mirror of https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00

Add memory to model

This commit is contained in:
  parent 6f5c4688ef
  commit fa696e919b
@@ -26,15 +26,16 @@ from icefall.utils import add_sos, make_pad_mask
 from scaling import penalize_abs_values_gt, ScaledLinear


-class Transducer(nn.Module):
+class PromptedTransducer(nn.Module):
     """It implements https://arxiv.org/pdf/1211.3711.pdf
     "Sequence Transduction with Recurrent Neural Networks"
     """

     def __init__(
         self,
         encoder_embed: nn.Module,
         encoder: EncoderInterface,
+        text_embed: nn.Module,
+        text_encoder: EncoderInterface,
         decoder: nn.Module,
         joiner: nn.Module,
         encoder_dim: int,
@@ -68,6 +69,8 @@ class Transducer(nn.Module):

         self.encoder_embed = encoder_embed
         self.encoder = encoder
+        self.text_embed = text_embed
+        self.text_encoder = text_encoder
         self.decoder = decoder
         self.joiner = joiner

@@ -86,6 +89,9 @@ class Transducer(nn.Module):
         self,
         x: torch.Tensor,
         x_lens: torch.Tensor,
+        text: torch.Tensor,
+        style_lens: torch.Tensor,
+        text_lens: torch.Tensor,
         y: k2.RaggedTensor,
         prune_range: int = 5,
         am_scale: float = 0.0,
@@ -98,6 +104,21 @@ class Transducer(nn.Module):
           x_lens:
             A 1-D tensor of shape (N,). It contains the number of frames in `x`
             before padding.
+          x_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in `x`
+            before padding.
+          text:
+            A 2-D tensor of integer dtype containing prompt text, of shape (N, T).
+            It is expected to contain the style prompt (first) and then the content
+            prompt.
+          style_lens:
+            A 1-D tensor of shape (N,), containing the number of elements (bytes)
+            within each row of `text` that correspond to the style prompt (these
+            are expected to come first).
+          text_lens:
+            A 1-D tensor of shape (N,). It contains the number of elements (bytes)
+            in `text` before padding, which will include the lengths of the
+            style plus the content prompt.
           y:
             A ragged tensor with 2 axes [utt][label]. It contains labels of each
             utterance.
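Note (not part of the commit): the docstring above fixes the expected layout of `text`, `style_lens` and `text_lens`. The sketch below shows one way such a batch could be assembled from UTF-8 bytes (the helper name `make_prompt_batch`, the pad id 0 and the prompt strings are made up for illustration).

import torch

def make_prompt_batch(style_prompts, content_prompts, pad_id=0):
    """Encode style + content prompts as UTF-8 bytes, padded to a (N, T) tensor."""
    rows = []
    style_lens = []
    for style, content in zip(style_prompts, content_prompts):
        style_bytes = list(style.encode("utf-8"))
        content_bytes = list(content.encode("utf-8"))
        style_lens.append(len(style_bytes))
        rows.append(style_bytes + content_bytes)  # style prompt comes first
    text_lens = torch.tensor([len(r) for r in rows], dtype=torch.int64)
    T = int(text_lens.max())
    text = torch.full((len(rows), T), pad_id, dtype=torch.int64)
    for i, r in enumerate(rows):
        text[i, : len(r)] = torch.tensor(r, dtype=torch.int64)
    return text, torch.tensor(style_lens, dtype=torch.int64), text_lens

text, style_lens, text_lens = make_prompt_batch(
    ["first style prompt ", "second style prompt "],
    ["first content prompt", "second content prompt"],
)
print(text.shape, style_lens, text_lens)  # (2, T), (2,), (2,)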
@@ -125,14 +146,25 @@ class Transducer(nn.Module):

         assert x.size(0) == x_lens.size(0) == y.dim0

-        # logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
         x, x_lens = self.encoder_embed(x, x_lens)
-        # logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")

         src_key_padding_mask = make_pad_mask(x_lens)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

-        encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask)
+        text = text.t()  # now (T, N)
+        text = self.text_embed(text)  # now (T, N, C)
+        text_key_padding_mask = make_pad_mask(text_lens)
+
+        memory, text_lens = self.text_encoder(text, text_lens,
+                                              text_key_padding_mask)
+
+        memory = self._add_style_indicator(memory, style_lens)
+
+        memory_key_padding_mask = make_pad_mask(text_lens)
+
+        encoder_out, x_lens = self.encoder(x, x_lens, src_key_padding_mask,
+                                           memory=memory,
+                                           memory_key_padding_mask=memory_key_padding_mask)
         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

         assert torch.all(x_lens > 0)
@@ -217,3 +249,24 @@ class Transducer(nn.Module):
         )

         return (simple_loss, pruned_loss)
+
+
+    def _add_style_indicator(self, memory: Tensor, style_lens: Tensor):
+        """
+        Adds to `memory` an indicator that is 0.1 for positions that correspond to
+        the `style prompt` and 0 elsewhere.  The scale can be fixed because the
+        scale of the memory vector can adjust to compensate (within limits set
+        by the balancers).
+
+        Args:
+           memory: (memory_len, batch_size, embed_dim)
+           style_lens: (batch_size,), a vector of lengths of the style prompt.
+        """
+
+        (memory_len, batch_size, embed_dim) = memory.shape
+
+
+        indicator = torch.arange(memory_len, device=memory.device).unsqueeze(-1) < style_lens
+        indicator = indicator.to(memory.dtype).unsqueeze(-1)
+
+        return memory + indicator
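Note (not part of the commit): a tiny standalone check of the broadcasting used by `_add_style_indicator` above, with toy sizes and random memory values.

import torch

memory_len, batch_size, embed_dim = 6, 2, 4
memory = torch.randn(memory_len, batch_size, embed_dim)
style_lens = torch.tensor([3, 1])  # first 3 / first 1 positions are the style prompt

# (memory_len, 1) < (batch_size,) broadcasts to (memory_len, batch_size)
indicator = torch.arange(memory_len).unsqueeze(-1) < style_lens
indicator = indicator.to(memory.dtype).unsqueeze(-1)  # (memory_len, batch_size, 1)
out = memory + indicator
print(indicator[:, 0, 0])  # tensor([1., 1., 1., 0., 0., 0.])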
@@ -71,7 +71,9 @@ from lhotse.utils import fix_random_seed
 from model import Transducer
 from optim import Eden, ScaledAdam
 from torch import Tensor
+from torch import nn
 from torch.cuda.amp import GradScaler
+
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.utils.tensorboard import SummaryWriter

@@ -159,6 +161,14 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Embedding dimension in encoder stacks: a single int or comma-separated list."
     )

+    parser.add_argument(
+        "--text-encoder-dim",
+        type=str,
+        default="256,256,384,512",
+        help="Embedding dimension in text encoder stacks: a comma-separated list of 4 elements, "
+        "otherwise you should change other configs in the code."
+    )
+
     parser.add_argument(
         "--query-head-dim",
         type=str,
@@ -547,6 +557,32 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
     return encoder_embed


+def get_text_embed(params: AttributeDict) -> nn.Module:
+    return nn.Embedding(
+        num_embeddings=256,  # we encode the text as UTF-8 bytes
+        embedding_dim=_to_int_tuple(params.text_encoder_dim)[0],
+    )
+
+def get_text_encoder(params: AttributeDict) -> nn.Module:
+    return Zipformer2(
+        output_downsampling_factor=8,
+        downsampling_factor=(1,2,4,8),
+        num_encoder_layers=(2,4,6,6),
+        encoder_dim=_to_int_tuple(params.text_encoder_dim),
+        encoder_unmasked_dim=(196,196,256,256),
+        query_head_dim=(32,32,32,32),
+        pos_head_dim=(4,4,4,4),
+        value_head_dim=(12,12,12,12),
+        pos_dim=48,
+        num_heads=(4,4,4,8),
+        feedforward_dim=(384,512,768,1024),  # could increase this if there is enough data
+        cnn_module_kernel=(31,31,15,15),
+        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
+        warmup_batches=4000.0,
+        causal=False,
+    )
+
+
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     encoder = Zipformer2(
         output_downsampling_factor=2,
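Note (not part of the commit): `--text-encoder-dim` is a comma-separated string; the sketch below (with a local helper `to_int_tuple` standing in for `_to_int_tuple`) shows which of its elements the new helpers consume.

def to_int_tuple(s: str):
    return tuple(int(x) for x in s.split(","))

text_encoder_dim = "256,256,384,512"        # the new --text-encoder-dim default
dims = to_int_tuple(text_encoder_dim)
embedding_dim = dims[0]   # used by get_text_embed() for nn.Embedding
memory_dim = dims[-1]     # passed to the main encoder as memory_dim (next hunk)
print(dims, embedding_dim, memory_dim)      # (256, 256, 384, 512) 256 512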
@@ -566,6 +602,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         causal=params.causal,
         chunk_size=_to_int_tuple(params.chunk_size),
         left_context_frames=_to_int_tuple(params.left_context_frames),
+        memory_dim=_to_int_tuple(params.text_encoder_dim)[-1],
     )
     return encoder

@@ -593,12 +630,16 @@ def get_joiner_model(params: AttributeDict) -> nn.Module:
 def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder_embed = get_encoder_embed(params)
     encoder = get_encoder_model(params)
+    text_embed = get_text_embed(params)
+    text_encoder = get_text_encoder(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

     model = Transducer(
         encoder_embed=encoder_embed,
         encoder=encoder,
+        text_embed=text_embed,
+        text_encoder=text_encoder,
         decoder=decoder,
         joiner=joiner,
         encoder_dim=int(max(params.encoder_dim.split(','))),
@@ -89,6 +89,9 @@ class Zipformer2(EncoderInterface):
            context chunks for causal training; will be rounded to a number of
            chunks.  Must not be less than cnn_module_kernel (after factoring in
            rounding and downsampling); an error will be thrown if this is violated.
+       memory_dim: if supplied and >0, will be the dimension of the memory embeddings
+           passed into the zipformer (e.g. this might be the output of another
+           Zipformer used to create embedding vectors).
     """
     def __init__(
             self,
@@ -103,6 +106,7 @@ class Zipformer2(EncoderInterface):
             num_heads: Union[int, Tuple[int]] = 8,
             feedforward_dim: Union[int, Tuple[int]] = 1536,
             cnn_module_kernel: Union[int, Tuple[int]] = 31,
+            memory_dim: int = -1,
             pos_dim: int = 192,
             dropout: FloatLike = None,  # see code below for default
             warmup_batches: float = 4000.0,
@@ -160,6 +164,7 @@ class Zipformer2(EncoderInterface):
                 pos_head_dim=pos_head_dim[i],
                 value_head_dim=value_head_dim[i],
                 feedforward_dim=feedforward_dim[i],
+                memory_dim=memory_dim,
                 dropout=dropout,
                 cnn_module_kernel=cnn_module_kernel[i],
                 causal=causal,
@@ -271,9 +276,12 @@ class Zipformer2(EncoderInterface):


     def forward(
-        self, x: torch.Tensor,
+        self,
+        x: torch.Tensor,
         x_lens: torch.Tensor,
         src_key_padding_mask: Optional[torch.Tensor] = None,
+        memory: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -284,7 +292,12 @@ class Zipformer2(EncoderInterface):
             `x` before padding.
           src_key_padding_mask:
             The mask for padding, of shape (batch_size, seq_len); True means
             masked position. May be None.
+          memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
+          memory_key_padding_mask: optionally the mask for padding of memory input (for source-
+             attention), of shape (batch_size, memory_len); True means
+             masked position. May be None.
+
         Returns:
           Return a tuple containing 2 tensors:
             - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
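Note (not part of the commit): both memory masks above follow the same convention as `src_key_padding_mask` (True marks a padded/masked position); the model code builds `memory_key_padding_mask` with `make_pad_mask(text_lens)`. A small local stand-in, for illustration only:

import torch

def pad_mask(lens: torch.Tensor) -> torch.Tensor:
    """True where a position is padding, matching the convention described above."""
    max_len = int(lens.max())
    return torch.arange(max_len).unsqueeze(0) >= lens.unsqueeze(1)  # (batch, max_len)

text_lens = torch.tensor([5, 3])
memory_key_padding_mask = pad_mask(text_lens)
print(memory_key_padding_mask)
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True]])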
@@ -298,6 +311,13 @@ class Zipformer2(EncoderInterface):

         attn_mask = self._get_attn_mask(x, chunk_size, left_context_chunks)

+        if self.training and memory is not None:
+            batch_size = x.shape[1]
+            # setting memory to zero should be equivalent to not using the
+            # memory input at all, since the Attention module has no biases.
+            memory_dropout_rate = 0.05
+            memory = memory * (torch.rand(batch_size, 1) > memory_dropout_rate)
+
         for i, module in enumerate(self.encoders):
             ds = self.downsampling_factor[i]
             x = convert_num_channels(x, self.encoder_dim[i])
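Note (not part of the commit): a toy check of the per-utterance memory dropout applied above. A (batch_size, 1) mask broadcasts over (memory_len, batch_size, memory_dim), so each utterance either keeps its whole memory or has it zeroed out.

import torch

memory_len, batch_size, memory_dim = 7, 4, 3
memory = torch.randn(memory_len, batch_size, memory_dim)
memory_dropout_rate = 0.05
keep = torch.rand(batch_size, 1) > memory_dropout_rate  # False zeroes that utterance's memory
memory = memory * keep
print(keep.squeeze(1))  # e.g. tensor([True, True, False, True])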
@@ -308,6 +328,8 @@ class Zipformer2(EncoderInterface):
                        src_key_padding_mask=(None if src_key_padding_mask is None
                                              else src_key_padding_mask[...,::ds]),
                        attn_mask=attn_mask,
+                       memory=memory,
+                       memory_key_padding_mask=memory_key_padding_mask,
             )
             outputs.append(x)

@@ -416,6 +438,7 @@ class Zipformer2EncoderLayer(nn.Module):
             dropout: FloatLike = 0.1,
             cnn_module_kernel: int = 31,
             causal: bool = False,
+            memory_dim: int = -1,
             attention_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
             conv_skip_rate: FloatLike = ScheduledFloat((0.0, 0.2), (4000.0, 0.05), (16000, 0.0), default=0),
             const_attention_rate: FloatLike = ScheduledFloat((0.0, 0.25), (4000.0, 0.025), default=0),
@@ -452,11 +475,25 @@ class Zipformer2EncoderLayer(nn.Module):
                                                   dropout=0.0,
         )

-        self.self_attn1 = SelfAttention(embed_dim, num_heads,
+        self.self_attn1 = Attention(embed_dim, embed_dim, num_heads,
                                         value_head_dim)

-        self.self_attn2 = SelfAttention(embed_dim, num_heads,
+        self.self_attn2 = Attention(embed_dim, embed_dim, num_heads,
                                         value_head_dim)

+        if memory_dim > 0:
+            self.attn_weights = MultiheadAttentionWeights(
+                memory_dim, embed_dim,
+                num_heads=num_heads,
+                head_dim=query_head_dim,
+                dropout=0.0,
+            )
+            self.src_attn1 = Attention(memory_dim, embed_dim, num_heads,
+                                       value_head_dim)
+            self.src_attn2 = Attention(memory_dim, embed_dim, num_heads,
+                                       value_head_dim)
+
+
         self.feed_forward1 = FeedforwardModule(embed_dim,
                                                (feedforward_dim * 3) // 4,
@@ -579,6 +616,8 @@ class Zipformer2EncoderLayer(nn.Module):
         chunk_size: int = -1,
         attn_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        memory: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         """
         Pass the input through the encoder layer.
@@ -608,11 +647,14 @@ class Zipformer2EncoderLayer(nn.Module):
             pos_emb=pos_emb,
             attn_mask=attn_mask,
             key_padding_mask=src_key_padding_mask,
         )

+        if memory is not None and hasattr(self, 'attn_weights'):
+            src_attn_weights = self.attn_weights(memory, src, memory_key_padding_mask)
+
         src = src + self.feed_forward1(src)

-        self_attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)
+        attn_dropout_mask = self.get_sequence_dropout_mask(src, attention_skip_rate)

         if True:
             selected_attn_weights = attn_weights[0:2]
@@ -630,12 +672,16 @@ class Zipformer2EncoderLayer(nn.Module):
             na = self.balancer_na(self.nonlin_attention(src,
                                                         selected_attn_weights[0:1]))

-            src = src + (na if self_attn_dropout_mask is None else na * self_attn_dropout_mask)
+            src = src + (na if attn_dropout_mask is None else na * attn_dropout_mask)

         self_attn = self.self_attn1(
             src, attn_weights)

-        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+        src = src + (self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask)

+        if memory is not None and hasattr(self, 'attn_weights'):
+            src = src + self.sequence_dropout(self.src_attn1(memory, src_attn_weights),
+                                              attention_skip_rate)
+
         src = src + self.sequence_dropout(self.conv_module1(src, chunk_size=chunk_size,
                                                             src_key_padding_mask=src_key_padding_mask),
@@ -650,7 +696,11 @@ class Zipformer2EncoderLayer(nn.Module):
         self_attn = self.self_attn2(
             src, attn_weights)

-        src = src + (self_attn if self_attn_dropout_mask is None else self_attn * self_attn_dropout_mask)
+        src = src + (self_attn if attn_dropout_mask is None else self_attn * attn_dropout_mask)

+        if memory is not None and hasattr(self, 'attn_weights'):
+            src = src + self.sequence_dropout(self.src_attn2(memory, src_attn_weights),
+                                              attention_skip_rate)
+
         src = src + self.sequence_dropout(self.conv_module2(src, chunk_size=chunk_size,
                                                             src_key_padding_mask=src_key_padding_mask),
@@ -673,9 +723,9 @@ class Zipformer2Encoder(nn.Module):
     r"""Zipformer2Encoder is a stack of N encoder layers

     Args:
         encoder_layer: an instance of the Zipformer2EncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
         pos_dim: the dimension for the relative positional encoding

     Examples::
         >>> encoder_layer = Zipformer2EncoderLayer(embed_dim=512, nhead=8)
@@ -721,6 +771,8 @@ class Zipformer2Encoder(nn.Module):
         feature_mask: Union[Tensor, float] = 1.0,
         attn_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        memory: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
         r"""Pass the input through the encoder layers in turn.

@@ -734,6 +786,10 @@ class Zipformer2Encoder(nn.Module):
                 True means masked position. May be None.
             src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                 masked position. May be None.
+            memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
+            memory_key_padding_mask: optionally the mask for padding of memory input (for source-
+                attention), of shape (batch_size, memory_len); True means
+                masked position. May be None.

         Returns: a Tensor with the same shape as src.
         """
@@ -751,6 +807,8 @@ class Zipformer2Encoder(nn.Module):
                 chunk_size=chunk_size,
                 attn_mask=attn_mask,
                 src_key_padding_mask=src_key_padding_mask,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
             )

         output = output * feature_mask
@@ -843,6 +901,8 @@ class DownsampledZipformer2Encoder(nn.Module):
         feature_mask: Union[Tensor, float] = 1.0,
         attn_mask: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
+        memory: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
         r"""Downsample, go through encoder, upsample.

@@ -855,6 +915,10 @@ class DownsampledZipformer2Encoder(nn.Module):
                 True means masked position. May be None.
             src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
                 masked position. May be None.
+            memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
+            memory_key_padding_mask: optionally the mask for padding of memory input (for source-
+                attention), of shape (batch_size, memory_len); True means
+                masked position. May be None.

         Returns: a Tensor with the same shape as src.
         """
@@ -870,6 +934,8 @@ class DownsampledZipformer2Encoder(nn.Module):
             feature_mask=feature_mask,
             attn_mask=attn_mask,
             src_key_padding_mask=src_key_padding_mask,
+            memory=memory,
+            memory_key_padding_mask=memory_key_padding_mask,
         )
         src = self.upsample(src)
         # remove any extra frames that are not a multiple of downsample_factor
@@ -1185,7 +1251,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):

         query_dim = query_head_dim * num_heads

-        # self-attention
         q = x[...,0:query_dim]
         k = x[...,query_dim:2*query_dim]
         # p is the position-encoding query
@@ -1231,9 +1296,9 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         attn_scores = attn_scores + pos_scores

         if self.training and random.random() < 0.1:
-            # This is a harder way of limiting the attention scores to not be
+            # This is a way of limiting the attention scores to not be
             # too large.  It incurs a penalty if any of them has an absolute
-            # value greater than 50.0. this should be outside the normal range
+            # value greater than 25.0. this should be outside the normal range
             # of the attention scores.  We use this mechanism instead of, say,
             # something added to the loss function involving the entropy,
             # because once the entropy gets very small gradients through the
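Note (not part of the commit): `penalize_abs_values_gt` is imported from scaling.py and is wired into the backward pass; the forward-only arithmetic the comment describes is roughly the sketch below (local names only; the exact form of the real penalty may differ).

import torch

attn_scores = torch.randn(4, 2, 10, 10) * 20.0  # toy scores, some beyond the limit
limit, penalty_scale = 25.0, 1.0e-04
excess = (attn_scores.abs() - limit).clamp(min=0)
penalty = penalty_scale * excess.sum()
print(float(penalty))  # > 0 only if some |score| exceeds 25.0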
@@ -1295,29 +1360,31 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                 logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}")


-class SelfAttention(nn.Module):
+class Attention(nn.Module):
     """
     The simplest possible attention module.  This one works with already-computed attention
     weights, e.g. as computed by RelPositionMultiheadAttentionWeights.

     Args:
-          embed_dim: the input and output embedding dimension
+          embed_dim_in: the input embedding dimension
+          embed_dim_out: the output embedding dimension (normally the same as input)
           num_heads: the number of attention heads
           value_head_dim: the value dimension per head
     """
     def __init__(
             self,
-            embed_dim: int,
+            embed_dim_in: int,
+            embed_dim_out: int,
             num_heads: int,
             value_head_dim: int,
     ) -> None:
         super().__init__()
-        self.in_proj = nn.Linear(embed_dim,
+        self.in_proj = nn.Linear(embed_dim_in,
                                  num_heads * value_head_dim,
-                                 bias=True)
+                                 bias=False)

         self.out_proj = ScaledLinear(num_heads * value_head_dim,
-                                     embed_dim, bias=True,
+                                     embed_dim_out, bias=False,
                                      initial_scale=0.05)

         self.whiten = Whiten(num_groups=1,
@@ -1334,35 +1401,182 @@ class SelfAttention(nn.Module):
         """
         Args:
           x: input tensor, of shape (seq_len, batch_size, embed_dim)
-          attn_weights: a tensor of shape (num_heads, batch_size, seq_len, seq_len),
-              with seq_len being interpreted as (tgt_seq_len, src_seq_len). Expect
-              attn_weights.sum(dim=-1) == 1.
+          attn_weights: a tensor of shape (num_heads, batch_size, query_len, key_len),
+              Expect attn_weights.sum(dim=-1) == 1.
         Returns:
           a tensor with the same shape as x.
         """
-        (seq_len, batch_size, embed_dim) = x.shape
-        num_heads = attn_weights.shape[0]
-        assert attn_weights.shape == (num_heads, batch_size, seq_len, seq_len)
+        (num_heads, batch_size, query_len, key_len) = attn_weights.shape

-        x = self.in_proj(x)     # (seq_len, batch_size, num_heads * value_head_dim)
-        x = x.reshape(seq_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
-        # now x: (num_heads, batch_size, seq_len, value_head_dim)
+        x = self.in_proj(x)     # (key_len, batch_size, num_heads * value_head_dim)
+        x = x.reshape(key_len, batch_size, num_heads, -1).permute(2, 1, 0, 3)
+        # now x: (num_heads, batch_size, key_len, value_head_dim)
         value_head_dim = x.shape[-1]

         # todo: see whether there is benefit in overriding matmul
         x = torch.matmul(attn_weights, x)
-        # v: (num_heads, batch_size, seq_len, value_head_dim)
+        # v: (num_heads, batch_size, query_len, value_head_dim)

         x = x.permute(2, 1, 0, 3).contiguous().view(
-            seq_len, batch_size, num_heads * value_head_dim)
+            query_len, batch_size, num_heads * value_head_dim)

-        # returned value is of shape (seq_len, batch_size, embed_dim), like the input.
+        # returned value is of shape (query_len, batch_size, embed_dim), like the input.
         x = self.out_proj(x)
         x = self.whiten(x)

         return x


+class MultiheadAttentionWeights(nn.Module):
+    r"""Module that computes multi-head cross-attention weights.  Allows src and target
+    to have different dims.
+
+    Args:
+          key_embed_dim: number of channels of the thing that we'll project to
+              make the key (corresponds to source). e.g. 256
+          query_embed_dim: number of channels of the thing that we'll project to
+              make the query (corresponds to target). e.g. 256
+          num_heads:  number of heads to compute weights for, e.g. 8
+          head_dim: dimension of the query and key, per head.  e.g. 24.
+          dropout: dropout probability for attn_output_weights. Default: 0.0.
+    """
+
+    def __init__(
+            self,
+            key_embed_dim: int,
+            query_embed_dim: int,
+            num_heads: int,
+            head_dim: int,
+            dropout: float = 0.0,
+
+    ) -> None:
+        super().__init__()
+        self.key_embed_dim = key_embed_dim
+        self.query_embed_dim = query_embed_dim
+        self.num_heads = num_heads
+        self.head_dim = head_dim
+        self.dropout = dropout
+        self.name = None  # will be overwritten in training code; for diagnostics.
+
+
+        # the initial_scale is supposed to take over the "scaling" factor of
+        # head_dim ** -0.5 that has been used in previous forms of attention,
+        # dividing it between the query and key.  Note: this module is intended
+        # to be used with the ScaledAdam optimizer; with most other optimizers,
+        # it would be necessary to apply the scaling factor in the forward function.
+        self.query_in_proj = ScaledLinear(query_embed_dim,
+                                          head_dim * num_heads,
+                                          bias=True,
+                                          initial_scale=head_dim ** -0.25)
+
+        # weights produced by this module are invariant to adding a constant to
+        # the keys, so we don't need a bias for the keys.
+        self.key_in_proj = ScaledLinear(key_embed_dim,
+                                        head_dim * num_heads,
+                                        bias=False,
+                                        initial_scale=head_dim ** -0.25)
+
+        self.whiten_keys = Whiten(num_groups=num_heads,
+                                  whitening_limit=_whitening_schedule(3.0),
+                                  prob=(0.025, 0.25),
+                                  grad_scale=0.025)
+
+
+
+    def forward(
+            self,
+            key: Tensor,
+            query: Tensor,
+            key_padding_mask: Optional[Tensor] = None,
+    ) -> Tensor:
+        r"""
+        Args:
+            key: input of shape (key_len, batch_size, key_embed_dim)
+            query: input of shape (query_len, batch_size, query_embed_dim)
+            key_padding_mask: an optional bool tensor of shape (batch_size, key_len).  Positions that
+               are True in this mask will be ignored as sources in the attention weighting.
+        Returns:
+            a tensor of attention weights, of shape (num_heads, batch_size, query_len, key_len)
+        """
+        q = self.query_in_proj(query)
+        k = self.key_in_proj(key)
+
+        head_dim = self.head_dim
+        num_heads = self.num_heads
+
+        query_len, batch_size, _ = q.shape
+        key_len, _batch_size, _ = k.shape
+        assert _batch_size == batch_size
+
+        k = self.whiten_keys(k)  # does nothing in the forward pass.
+
+        q = q.reshape(query_len, batch_size, num_heads, head_dim)
+        k = k.reshape(key_len, batch_size, num_heads, head_dim)
+
+        # time1 refers to target, time2 refers to source.
+        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
+        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
+
+        attn_scores = torch.matmul(q, k)
+
+        if self.training and random.random() < 0.1:
+            # This is a way of limiting the attention scores to not be
+            # too large.  It incurs a penalty if any of them has an absolute
+            # value greater than 25.0. this should be outside the normal range
+            # of the attention scores.  We use this mechanism instead of, say,
+            # something added to the loss function involving the entropy,
+            # because once the entropy gets very small gradients through the
+            # softmax can become very small, and we'd get zero derivatives.  The
+            # choice of 1.0e-04 as the scale on the penalty makes this
+            # mechanism vulnerable to the absolute scale of the loss function,
+            # but we view this as a failsafe to avoid "implausible" parameter
+            # values rather than a regularization method that should be active
+            # under normal circumstances.
+            attn_scores = penalize_abs_values_gt(attn_scores,
+                                                 limit=25.0,
+                                                 penalty=1.0e-04,
+                                                 name=self.name)
+
+        assert attn_scores.shape == (num_heads, batch_size, query_len, key_len)
+
+        if key_padding_mask is not None:
+            assert key_padding_mask.shape == (batch_size, key_len), key_padding_mask.shape
+            attn_scores = attn_scores.masked_fill(
+                key_padding_mask.unsqueeze(1),
+                -1000,
+            )
+
+        # We use our own version of softmax, defined in scaling.py, which should
+        # save a little of the memory used in backprop, if we are in
+        # automatic mixed precision mode (amp / autocast), by only storing the
+        # half-precision output for backprop purposes.
+        attn_weights = softmax(attn_scores, dim=-1)
+
+        if random.random() < 0.001:
+            self._print_attn_entropy(attn_weights)
+
+        attn_weights = nn.functional.dropout(
+            attn_weights, p=self.dropout, training=self.training
+        )
+
+        return attn_weights
+
+
+    def _print_attn_entropy(
+            self,
+            attn_weights: Tensor):
+        # attn_weights: (num_heads, batch_size, seq_len, seq_len)
+        (num_heads, batch_size, seq_len, seq_len) = attn_weights.shape
+
+        with torch.no_grad():
+            with torch.cuda.amp.autocast(enabled=False):
+                attn_weights = attn_weights.to(torch.float32)
+                attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum(
+                    dim=-1).mean(dim=(1,2))
+                logging.info(f"name={self.name}, attn_weights_entropy = {attn_weights_entropy}")
+
+
+
 class FeedforwardModule(nn.Module):
     """Feedforward module in Zipformer2 model.
     """
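Note (not part of the commit): a shape-level mock of the split between weight computation (MultiheadAttentionWeights) and weight application (Attention) defined above, using plain PyTorch stand-ins for the projection weights; `out_proj`, key whitening and the score penalty are omitted.

import torch

num_heads, head_dim, value_head_dim = 4, 24, 12
key_len, query_len, batch_size = 9, 5, 2
key_embed_dim, query_embed_dim = 256, 384

key = torch.randn(key_len, batch_size, key_embed_dim)        # e.g. the text-encoder memory
query = torch.randn(query_len, batch_size, query_embed_dim)  # e.g. the acoustic stream

# Weights: project both sides to (head, batch, len, head_dim), then softmax over
# the key axis, giving (num_heads, batch_size, query_len, key_len) as documented.
q_w = torch.randn(query_embed_dim, num_heads * head_dim)  # stand-in for query_in_proj
k_w = torch.randn(key_embed_dim, num_heads * head_dim)    # stand-in for key_in_proj
qh = (query @ q_w).reshape(query_len, batch_size, num_heads, head_dim).permute(2, 1, 0, 3)
kh = (key @ k_w).reshape(key_len, batch_size, num_heads, head_dim).permute(2, 1, 3, 0)
attn_weights = torch.softmax(torch.matmul(qh, kh), dim=-1)

# Application: project values from the key side and combine them with the weights,
# following the shape bookkeeping in Attention.forward() above.
v_w = torch.randn(key_embed_dim, num_heads * value_head_dim)  # stand-in for in_proj
vh = (key @ v_w).reshape(key_len, batch_size, num_heads, value_head_dim).permute(2, 1, 0, 3)
out = torch.matmul(attn_weights, vh)              # (head, batch, query_len, value_head_dim)
out = out.permute(2, 1, 0, 3).reshape(query_len, batch_size, num_heads * value_head_dim)
print(attn_weights.shape, out.shape)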
@@ -1650,12 +1864,14 @@ def _test_zipformer_main(causal: bool = False):
     batch_size = 5
     seq_len = 20
     # Just make sure the forward pass runs.
+    memory_dim = 100

     c = Zipformer2(
         encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
         causal=causal,
         chunk_size=(4,) if causal else (-1,),
-        left_context_frames=(64,)
+        left_context_frames=(64,),
+        memory_dim=memory_dim,
     )
     batch_size = 5
     seq_len = 20
@@ -1663,6 +1879,7 @@ def _test_zipformer_main(causal: bool = False):
     f = c(
         torch.randn(seq_len, batch_size, 64),
         torch.full((batch_size,), seq_len, dtype=torch.int64),
+        memory=torch.randn(101, batch_size, memory_dim),
     )
     f[0].sum().backward()
     c.eval()