mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

commit 1b8be0744f
parent f740282a1a

    Fix various bugs
@@ -19,7 +19,6 @@

 import torch
 from torch import nn, Tensor
-from chunk_decoder import ChunkDecoder
 from zipformer import Zipformer2


@@ -28,7 +27,7 @@ class Zipformer2LM(nn.Module):
     def __init__(self,
                  encoder_embed: nn.Module,
                  encoder: Zipformer2,
-                 decoder: ChunkDecoder):
+                 decoder: nn.Module):
         super().__init__()
         self.encoder_embed = encoder_embed
         self.encoder = encoder  # does subsampling
@@ -47,18 +46,17 @@ class Zipformer2LM(nn.Module):
         """
         (batch_size, seq_len) = labels.shape

-        chunk_size = self.decoder.chunk_size
+        chunk_size = 1
         labels_shifted = labels.t()  # (time, batch)
-        labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:chunk_size]),
-                                    labels_shifted[:-chunk_size]),
+        labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:1]),
+                                    labels_shifted[:-1]),
                                    dim=0)

         x = self.encoder_embed(labels_shifted)
         x_lens = torch.full((batch_size,), seq_len,
                             dtype=torch.long, device=labels.device)

         # x_lens is after subsampling.  Actually we don't need it.


         (x, x_lens) = self.encoder(x, x_lens)

         logprobs = self.decoder(labels, x)
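Note: the shift above is ordinary next-token teacher forcing: the input at position t is the label from position t-1, with one frame of zeros prepended so that position 0 sees no label. A minimal standalone sketch of the same shift (tensor names here are illustrative, not taken from the repo):

    import torch

    labels = torch.tensor([[5, 6, 7, 8],
                           [9, 10, 11, 12]])              # (batch, seq_len)
    labels_t = labels.t()                                 # (seq_len, batch)
    shifted = torch.cat((torch.zeros_like(labels_t[:1]),  # one frame of zeros at the start
                         labels_t[:-1]), dim=0)           # drop the last frame
    # shifted[t] == labels_t[t - 1] for t >= 1, so position t only sees earlier labels.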
@@ -76,11 +76,7 @@ class Subformer2(EncoderInterface):
        dropout (float): dropout rate
        warmup_batches (float): number of batches to warm up over; this controls
          dropout of encoder layers.
-       causal (bool): if True, support chunkwise causal convolution.  This should
-          not hurt WER as no modeling power is lost, but the convolution modules will be
-          slightly slower and use more memory.  Enables use of the chunk_size and
-          left_context_chunks options in forward(), which simulates streaming
-          decoding.
+       causal (bool): if True, use causal attention-mask.
        memory_dim: if supplied and >0, will be the dimension of the memory embeddings
          passed into the zipformer (e.g. this might be the output of another
          Subformer used to create embedding vectors.)
@@ -97,7 +93,6 @@ class Subformer2(EncoderInterface):
             num_heads: Union[int, Tuple[int]] = 8,
             feedforward_dim: Union[int, Tuple[int]] = 1536,
             memory_dim: int = -1,
-            pos_emb_dim: int = 192,
             pos_dim: int = 4,
             dropout: FloatLike = None,  # see code below for default
             warmup_batches: float = 4000.0,
@@ -129,6 +124,7 @@ class Subformer2(EncoderInterface):
         value_head_dim = _to_tuple(value_head_dim)
         num_heads = _to_tuple(num_heads)
         feedforward_dim = _to_tuple(feedforward_dim)
+        self.causal = causal

         for u,d in zip(encoder_unmasked_dim, encoder_dim):
             assert u <= d
@@ -156,7 +152,6 @@ class Subformer2(EncoderInterface):
             encoder = Subformer2Encoder(
                 encoder_layer,
                 num_encoder_layers[i],
-                pos_dim=pos_dim,
                 dropout=dropout,
                 warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                 warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
@@ -173,14 +168,15 @@ class Subformer2(EncoderInterface):

             encoders.append(encoder)

-        self.encoder_pos = CompactRelPositionalEncoding(pos_emb_dim, pos_dim, dropout_rate=0.15,
+        self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim,
+                                                        dropout_rate=0.15,
                                                         length_factor=1.0)

         self.encoders = nn.ModuleList(encoders)

-        self.downsample_output = SimpleDownsample(max(encoder_dim),
-                                                  downsample=output_downsampling_factor,
-                                                  dropout=dropout)
+        #self.downsample_output = SimpleDownsample(max(encoder_dim),
+        #                                          downsample=output_downsampling_factor,
+        #                                          dropout=dropout)

     def get_feature_masks(
         self,
@@ -273,7 +269,7 @@ class Subformer2(EncoderInterface):
         outputs = []
         feature_masks = self.get_feature_masks(x)

-        attn_offset = self._get_attn_offset(x)
+        attn_offset = self._get_attn_offset(x, src_key_padding_mask)

         if self.training and memory is not None:
             batch_size = x.shape[1]
@@ -286,15 +282,11 @@ class Subformer2(EncoderInterface):
         pos_emb = self.encoder_pos(x)

         for i, module in enumerate(self.encoders):
-            ds = self.downsampling_factor[i]
             x = convert_num_channels(x, self.encoder_dim[i])

             x = module(x,
                        pos_emb,
-                       chunk_size=chunk_size,
                        feature_mask=feature_masks[i],
-                       src_key_padding_mask=(None if src_key_padding_mask is None
-                                             else src_key_padding_mask[...,::ds]),
                        attn_offset=attn_offset,
                        memory=memory,
                        memory_key_padding_mask=memory_key_padding_mask,
@@ -321,37 +313,37 @@ class Subformer2(EncoderInterface):
         # from different pieces of 'outputs', taking each dimension from the
         # most recent output that has it present.
         x = get_full_dim_output()
-        x = self.downsample_output(x)
+        #x = self.downsample_output(x)

         d = self.output_downsampling_factor
         lengths = (x_lens + d - 1) // d

         return x, lengths

-    def _get_attn_offset(self, x: Tensor) -> Optional[Tensor]:
+    def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
         """
-        Return attention offset of shape (1, seq_len, seq_len), interpreted as (tgt_seq_len,
+        Return attention offset of shape (1 or batch_size, seq_len, seq_len), interpreted as (1 or batch_size, tgt_seq_len,
         src_seq_len); this reflects masking, if causal == True, otherwise will be all zeros.

         Args:
           x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
-          chunk_size: chunk size, must divide
+          src_key_padding_mask: optional key-padding mask of shape (batch_size, seq_len) with True in masked positions.
         """
-        if not self.causal:
-            return None
+        seq_len, batch_size, _num_channels = x.shape

-        seq_len = x.shape[0]

-        # t is frame index, shape (seq_len,)
-        t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
-        src_c = c
-        tgt_c = c.unsqueeze(-1)

-        attn_mask = (src_c > tgt_c)

         ans = torch.zeros(1, seq_len, seq_len, device=x.device)

-        ans.masked_fill(attn_mask, float('-inf'))
+        if self.causal:
+            # t is frame index, shape (seq_len,)
+            t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
+            src_t = t
+            tgt_t = t.unsqueeze(-1)
+            attn_mask = (src_t > tgt_t)
+            ans.masked_fill(attn_mask, float('-inf'))

+        if src_key_padding_mask is not None:
+            ans = ans * src_key_padding_mask.unsqueeze(1).logical_not()
+            # now ans: (batch_size, seq_len, seq_len).

         return ans

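Note: the "attention offset" built here is an additive float mask: zero where attention is allowed, -inf where it is not, added to the attention scores before the softmax. A minimal sketch of the causal part only (an illustration under that reading, not the module's exact code; the key-padding handling from the diff is not reproduced):

    import torch

    seq_len = 5
    t = torch.arange(seq_len)
    attn_mask = t > t.unsqueeze(-1)                 # (tgt, src): True where the key is in the future
    offset = torch.zeros(1, seq_len, seq_len)
    offset.masked_fill_(attn_mask, float('-inf'))   # in-place fill; broadcasts over the batch dim

    scores = torch.randn(1, seq_len, seq_len)
    probs = (scores + offset).softmax(dim=-1)
    assert torch.all(probs.triu(1) == 0)            # future positions get exactly zero weight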
@@ -384,11 +376,10 @@ class Subformer2EncoderLayer(nn.Module):
     def __init__(
             self,
             embed_dim: int,
-            pos_dim: int,
             num_heads: int,
             query_head_dim: int,
-            pos_dim: int,
             value_head_dim: int,
+            pos_dim: int,
             feedforward_dim: int,
             dropout: FloatLike = 0.1,
             causal: bool = False,
@@ -431,14 +422,15 @@ class Subformer2EncoderLayer(nn.Module):


         self.self_attn1 = Attention(embed_dim, embed_dim, num_heads,
                                     value_head_dim)

         self.self_attn2 = Attention(embed_dim, embed_dim, num_heads,
                                     value_head_dim)

         if memory_dim > 0:
             self.attn_weights = MultiheadAttentionWeights(
-                memory_dim, embed_dim,
+                memory_dim,
+                embed_dim,
                 num_heads=num_heads,
                 head_dim=query_head_dim,
                 dropout=0.0,
@@ -559,7 +551,6 @@ class Subformer2EncoderLayer(nn.Module):
         self,
         src: Tensor,
         pos_emb: Tensor,
-        chunk_size: int = -1,
         attn_offset: Optional[Tensor] = None,
         src_key_padding_mask: Optional[Tensor] = None,
         memory: Optional[Tensor] = None,
@@ -570,7 +561,6 @@ class Subformer2EncoderLayer(nn.Module):
         Args:
             src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
             pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
-            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
             feature_mask: something that broadcasts with src, that we'll multiply `src`
                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
             attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len),
@@ -591,7 +581,6 @@ class Subformer2EncoderLayer(nn.Module):
             src,
             pos_emb=pos_emb,
             attn_offset=attn_offset,
-            key_padding_mask=src_key_padding_mask,
         )

         if memory is not None and hasattr(self, 'attn_weights'):
@@ -662,7 +651,6 @@ class Subformer2Encoder(nn.Module):
     Args:
         encoder_layer: an instance of the Subformer2EncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
-        pos_dim: the dimension for the relative positional encoding

     Examples::
         >>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
@@ -674,7 +662,6 @@ class Subformer2Encoder(nn.Module):
         self,
         encoder_layer: nn.Module,
         num_layers: int,
-        pos_dim: int,
         dropout: float,
         warmup_begin: float,
         warmup_end: float,
@@ -701,10 +688,9 @@ class Subformer2Encoder(nn.Module):
     def forward(
         self,
         src: Tensor,
-        chunk_size: int = -1,
+        pos_emb: Tensor,
         feature_mask: Union[Tensor, float] = 1.0,
         attn_offset: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
         memory: Optional[Tensor] = None,
         memory_key_padding_mask: Optional[Tensor] = None,
     ) -> Tensor:
@@ -712,14 +698,13 @@ class Subformer2Encoder(nn.Module):

         Args:
             src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-            chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
+            pos_emb: positional embedding tensor, of shape (batch_size, seq_len, seq_len, pos_dim),
+               e.g. pos_dim=4.
             feature_mask: something that broadcasts with src, that we'll multiply `src`
                by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
             attn_offset: the attention offset (does masking and related tasks), of shape
                broadcasting with (batch_size, seq_len, seq_len),
               interpreted as (batch_size, tgt_seq_len, src_seq_len).
-            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
-               masked position.  May be None.
             memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
             memory_key_padding_mask: optionally the mask for padding of memory input (for source-
                attention), of shape (batch_size, memory_len); True means
@@ -727,7 +712,6 @@ class Subformer2Encoder(nn.Module):

         Returns: a Tensor with the same shape as src.
         """
-        pos_emb = self.encoder_pos(src)
         output = src

         rnd_seed = src.numel() + random.randint(0, 1000)
@@ -738,9 +722,7 @@ class Subformer2Encoder(nn.Module):
             output = mod(
                 output,
                 pos_emb,
-                chunk_size=chunk_size,
                 attn_offset=attn_offset,
-                src_key_padding_mask=src_key_padding_mask,
                 memory=memory,
                 memory_key_padding_mask=memory_key_padding_mask,
             )
@@ -827,12 +809,13 @@ class LearnedDownsamplingModule(nn.Module):
                  embed_dim: int,
                  downsampling_factor: int,
                  intermediate_rate: FloatLike = 0.2):
+        super().__init__()
         self.to_scores = nn.Linear(embed_dim, 1, bias=False)
         # score_balancer is just to keep the magnitudes of the scores in
         # a fixed range and keep them balanced around zero, to stop
         # these drifting around.
         self.score_balancer = Balancer(1, channel_dim=-1,
-                                       min_positive=0.4, max_positive=0.6
+                                       min_positive=0.4, max_positive=0.6,
                                        min_abs=1.0, max_abs=1.2,
                                        prob=0.025)

@@ -856,14 +839,14 @@ class LearnedDownsamplingModule(nn.Module):
           corresponding to the kept frames; these will be between 0 and 1, but
           mostly exactly 1.
         """
-        (seq_len, batch_size, _)
+        (seq_len, batch_size, _) = x.shape
         scores = self.to_scores(x)  # (seq_len, batch_size, 1)
         scores = self.score_balancer(scores)

         scores = scores.squeeze(-1).t()  # (batch_size, seq_len)

-        # indexes, sscores: (batch_size, seq_len)
-        indexes, sscores = scores.sort(dim=-1, descending=True)
+        # sscores, indexes: (batch_size, seq_len)
+        sscores, indexes = scores.sort(dim=-1, descending=True)

         d = self.downsampling_factor
         seq_len_reduced = (seq_len + d - 1) // d
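Note: the swap above matters because torch.sort returns a (values, indices) pair, in that order:

    import torch

    scores = torch.tensor([[0.2, 0.9, 0.5]])
    sscores, indexes = scores.sort(dim=-1, descending=True)
    # sscores: tensor([[0.9000, 0.5000, 0.2000]])  -- the sorted values
    # indexes: tensor([[1, 2, 0]])                 -- where those values sat in `scores`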
@@ -883,10 +866,10 @@ class LearnedDownsamplingModule(nn.Module):
         collar = max(1, int(seq_len_reduced * 0.5 * self.intermediate_rate))

         # right_avg: shape (batch_size,), this is to be mapped to 0.0
-        right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1)
+        right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)

         # left_avg: shape (batch_size,), this is to be mapped to 1.0
-        left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1)
+        left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)

         # the + 0.001 is to avoid possible division by zero in case of ties.
         weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)
@@ -901,11 +884,11 @@ class LearnedDownsamplingModule(nn.Module):
         indexes, reorder = indexes.sort(dim=-1)
         weights = torch.gather(weights, dim=-1, index=reorder)

-        x_downsampled = downsample(indexes, x)
+        x_downsampled = self.downsample(x, indexes)
         return indexes, weights, x_downsampled


-    def downsample(x: Tensor, indexes: Tensor) -> Tensor:
+    def downsample(self, x: Tensor, indexes: Tensor) -> Tensor:
         """
         Downsamples x via indexing with the indexes obtained from the
         forward() function.
@@ -917,19 +900,19 @@ class LearnedDownsamplingModule(nn.Module):
         Returns:
            x_downsampled, of shape (seq_len_reduced, batch_size, num_channels)
         """
-        indexes = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
-        # indexes now: (seq_len_reduced, batch_size, num_channels)
-        ans = torch.gather(x, dim=0, index=indexes)
+        indexes_expanded = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
+        # indexe_expanded: (seq_len_reduced, batch_size, num_channels)
+        ans = torch.gather(x, dim=0, index=indexes_expanded)

-        if __name__ == __main__:
+        if __name__ == '__main__':
             # temp, for testing
-            x_reconstructed = upsample(x, ans, indexes)
+            x_reconstructed = self.upsample(x, ans, indexes)
             assert torch.allclose(x, x_reconstructed)

         return ans


-    def downsample_pos_emb(pos_emb: Tensor, indexes: Tensor) -> Tensor:
+    def downsample_pos_emb(self, pos_emb: Tensor, indexes: Tensor) -> Tensor:
         """
         Downsample positional embedding tensor with the provided indexes.
         Args:
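Note: the downsample step is a plain torch.gather along the time axis; the per-batch kept-frame indexes are transposed to (seq_len_reduced, batch) and expanded over channels so each gathered element names a whole frame. A small self-contained sketch with made-up shapes:

    import torch

    seq_len, batch, channels = 6, 2, 3
    x = torch.randn(seq_len, batch, channels)
    indexes = torch.tensor([[0, 2, 5],
                            [1, 3, 4]])                        # (batch, seq_len_reduced)
    idx = indexes.t().unsqueeze(-1).expand(-1, -1, channels)   # (seq_len_reduced, batch, channels)
    x_down = torch.gather(x, dim=0, index=idx)                 # x_down[i, b] = x[indexes[b, i], b]
    assert torch.equal(x_down[1, 0], x[2, 0])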
@@ -958,7 +941,8 @@ class LearnedDownsamplingModule(nn.Module):
         return pos_emb


-    def downsample_attn_offset(attn_offset: Tensor,
+    def downsample_attn_offset(self,
+                               attn_offset: Tensor,
                                indexes: Tensor,
                                weights: Tensor,
                                eps: float = 1.0e-05) -> Tensor:
@@ -979,18 +963,17 @@ class LearnedDownsamplingModule(nn.Module):
         assert len(attn_offset.shape) == 3  # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
         attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)

-        attn_offset = attn_offset.gather(dim=1, src=indices.unsqueeze(-1).expand(
+        attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
             batch_size, seq_len_reduced, seq_len))
-        attn_offset = attn_offset.gather(dim=2, src=indices.unsqueeze(1).expand(
+        attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
             batch_size, seq_len_reduced, seq_len_reduced))
         # unsqueeze at position 1 so the extra cost relates to the source position.
         attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1)

-        return attn_offst
+        return attn_offset


-    def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
+    def upsample(self, x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
         """
         Upsamples, reversing the downsample() operation and filling in
         any not-chosen frames with their original value before downsampling
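Note: adding log(weights) to the pre-softmax offset is the same as scaling the post-softmax attention weights by `weights` (up to renormalization), since softmax(s + log w) is proportional to w * exp(s). A tiny numeric check of that identity (illustrative only):

    import torch

    s = torch.randn(5)
    w = torch.tensor([1.0, 1.0, 0.5, 0.25, 1.0]).clamp(min=1.0e-05)

    a = torch.softmax(s + w.log(), dim=-1)
    b = w * s.exp()
    b = b / b.sum()
    assert torch.allclose(a, b, atol=1.0e-06)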
@@ -1013,14 +996,14 @@ class LearnedDownsamplingModule(nn.Module):

         not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool,
                               device=x.device)
-        not_kept.scatter_(src=False, dim=1, index=indexes)
+        not_kept.scatter_(dim=1, index=indexes, value=False)

         indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
         # indexes now: (seq_len_reduced, batch_size, num_channels)

         ans = torch.zeros_like(x_orig)

-        ans.scatter_(x, dim=0, index=indexes)
+        ans.scatter_(dim=0, index=indexes, src=x)

         # add in x_orig in the frames that were not originally kept.
         return ans + x_orig * not_kept.t().unsqueeze(-1)
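Note: the upsample path mirrors the gather-based downsample: scatter_(dim, index, src) writes the kept frames back to their original time positions in a zero tensor, and scatter_(dim, index, value) marks which positions were kept. A minimal round-trip sketch (shapes are made up):

    import torch

    seq_len, batch, channels = 6, 2, 3
    x_orig = torch.randn(seq_len, batch, channels)
    indexes = torch.tensor([[0, 2, 5],
                            [1, 3, 4]])                    # (batch, seq_len_reduced), kept frames

    not_kept = torch.ones(batch, seq_len, dtype=torch.bool)
    not_kept.scatter_(dim=1, index=indexes, value=False)   # False at the kept positions

    idx = indexes.t().unsqueeze(-1).expand(-1, batch, channels)
    x_down = torch.gather(x_orig, dim=0, index=idx)        # the "downsampled" frames

    ans = torch.zeros_like(x_orig)
    ans.scatter_(dim=0, index=idx, src=x_down)             # put kept frames back where they came from
    ans = ans + x_orig * not_kept.t().unsqueeze(-1)        # fill dropped frames with their originals
    assert torch.allclose(ans, x_orig)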
@@ -1051,7 +1034,6 @@ class DownsampledSubformer2Encoder(nn.Module):
         pos_emb: Tensor,
         feature_mask: Union[Tensor, float] = 1.0,
         attn_offset: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
         memory: Optional[Tensor] = None,
         memory_key_padding_mask: Optional[Tensor] = None,
     ) -> Tuple[Tensor, Tensor]:
@@ -1065,8 +1047,6 @@ class DownsampledSubformer2Encoder(nn.Module):
             attn_offset: the attention offset, added to scores for attention of shape
                (batch_size, seq_len, seq_len) or (seq_len, seq_len),
                interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
-               masked position.  May be None.
             memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
             memory_key_padding_mask: optionally the mask for padding of memory input (for source-
                attention), of shape (batch_size, memory_len); True means
@@ -1079,30 +1059,24 @@ class DownsampledSubformer2Encoder(nn.Module):

         pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)

-        attn_offset = self.downsample.downsample_attn_offset(attn_offset,
+        attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
                                                               indexes,
-                                                             weights.clamp(min=1.0e-05))
+                                                              weights)


         src = self.encoder(
             src,
-            os_emb,
+            pos_emb,
             feature_mask=feature_mask,
             attn_offset=attn_offset,
-            src_key_padding_mask=src_key_padding_mask,
             memory=memory,
             memory_key_padding_mask=memory_key_padding_mask,
         )
-        src = self.upsample(src)
-        # remove any extra frames that are not a multiple of downsample_factor
-        src = src[:src_orig.shape[0]]
+        src = self.downsampler.upsample(src_orig, src, indexes)

         return self.out_combiner(src_orig, src)





 class CompactRelPositionalEncoding(torch.nn.Module):
     """
     Relative positional encoding module.  This version is "compact" meaning it is able to encode
@@ -1123,6 +1097,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):

     Args:
         embed_dim: Temporary embedding dimension used inside this module
+        pos_dim: Smaller positional-encoding dim used after a projecction.
         dropout_rate: Dropout rate.
         max_len: Maximum input length: just a heuristic for initialization.
         length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
@@ -1130,11 +1105,12 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         pos_dim: dimension at the output of this module.
     """
     def __init__(
-        self, embed_dim: int,
+        self,
+        embed_dim: int,
+        pos_dim: int,
         dropout_rate: FloatLike,
         max_len: int = 1000,
         length_factor: float = 1.0,
-        pos_dim: int = 4,
     ) -> None:
         """Construct a CompactRelPositionalEncoding object."""
         super(CompactRelPositionalEncoding, self).__init__()
@@ -1211,16 +1187,13 @@ class CompactRelPositionalEncoding(torch.nn.Module):
             x (torch.Tensor): Input tensor (time, batch, num_channels_in)

         Returns:
-            positional embedding, of shape (1, 2*time-1, pos_dim).
+            positional embedding, of shape (batch_size, 2*time-1, pos_dim).

         """
         self.extend_pe(x)
         seq_len = x.size(0)
         pos_emb = self.pe[
-            self.pe.size(0) // 2
-            - seq_len
-            + 1 : self.pe.size(0) // 2  # noqa E203
-            + seq_len,
+            self.pe.size(0) // 2 - seq_len + 1 : self.pe.size(0) // 2 + seq_len,
             :
         ]
         pos_emb = pos_emb.unsqueeze(0)
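Note: the reworked slice takes a symmetric window of 2*seq_len-1 rows around the centre of self.pe, i.e. relative offsets -(seq_len-1)..(seq_len-1). A toy check of the index arithmetic (assuming a table laid out with offset 0 at the centre row):

    import torch

    pe_len, seq_len = 21, 5
    pe = torch.arange(pe_len).unsqueeze(-1)      # stand-in table, one column per row index

    centre = pe.size(0) // 2
    window = pe[centre - seq_len + 1 : centre + seq_len]
    assert window.shape[0] == 2 * seq_len - 1
    assert window[seq_len - 1, 0] == centre      # middle row of the window is relative offset 0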
@@ -1230,12 +1203,20 @@ class CompactRelPositionalEncoding(torch.nn.Module):
         # currenly pos_emb: (1, 2*seq_len-1, pos_dim)
         pos_dim = pos_emb.shape[-1]
         batch_size = x.size(1)
-        (_, seq_stride, channel_stride) = pos_emb.stride()
         # it doesn't really matter which one we make positive and which negative here, it
         # would just flip the meaning of the embedding.


+        # expand the '1' dimension to seq_len; this introduces a dimension that
+        # 'does nothing', just creates copies, as a workaround for lack of torch support
+        # for negative strides.
+        pos_emb = pos_emb.expand(seq_len, 2*seq_len-1, pos_dim).contiguous()

+        (useless_stride, seq_stride, channel_stride) = pos_emb.stride()

         pos_emb = pos_emb.as_strided((batch_size, seq_len, seq_len, pos_dim),
-                                     (0, -seq_stride, seq_stride, channel_stride),
-                                     storage_offset=seq_stride * (seqs_len - 1))
+                                     (0, useless_stride-seq_stride, seq_stride, channel_stride),
+                                     storage_offset=seq_stride * (seq_len - 1))

         return pos_emb  # (batch_size, seq_len, seq_len, pos_dim)

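Note: the expand/as_strided pair above is stride arithmetic: starting from a (2*seq_len-1, pos_dim) table indexed by relative offset, it builds a view whose [t, s] entry is the row for offset s - t. A standalone toy check of the same trick (my own shapes and names, not the module's code):

    import torch

    S, D = 4, 3
    rel = torch.arange(2 * S - 1, dtype=torch.float32).unsqueeze(-1).expand(2 * S - 1, D)
    # rel[i] stands for relative offset i - (S - 1), i.e. rows run from -(S-1) to +(S-1).

    big = rel.unsqueeze(0).expand(S, 2 * S - 1, D).contiguous()  # copies along a new leading dim
    u, s, c = big.stride()

    # out[t, j] lands on big[t, j - t + (S - 1)]:
    # offset = (S-1)*s + t*(u - s) + j*s = t*u + (j - t + S - 1)*s
    out = big.as_strided((S, S, D), (u - s, s, c), storage_offset=s * (S - 1))

    for t in range(S):
        for j in range(S):
            assert torch.equal(out[t, j], rel[j - t + S - 1])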
@@ -1326,8 +1307,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         x: Tensor,
         pos_emb: Tensor,
         attn_offset: Optional[Tensor] = None,
-        pos_emb: Tensor,
-        quadratic_pos_weight: Tensor,
     ) -> Tensor:
         r"""
         Args:
@@ -1368,35 +1347,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module):


         q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
-        p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+        p = p.reshape(seq_len, batch_size, num_heads, pos_dim)
         k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)

-        # time1 refers to target, time2 refers to source.
-        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
-        p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
-        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
+        q = q.permute(2, 1, 0, 3)  # (head, batch, tgt_seq_len, query_head_dim)
+        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, src_seq_len)

+        # attn_scores: (num_heads, batch_size, tgt_seq_len, src_esq_len)
         attn_scores = torch.matmul(q, k)

         if not self.training or random.random() >= float(self.pos_emb_skip_rate):
-            pos_emb = self.linear_pos(pos_emb)
-            seq_len2 = 2 * seq_len - 1
-            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
-            # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+            # pos_emb: (batch_size, tgt_seq_len, src_seq_len, pos_dim)
+            p = p.permute(1, 0, 3, 2)  # (batch_size, tgt_seq_len, pos_dim, num_heads)

-            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
-            # [where seq_len2 represents relative position.]
-            pos_scores = torch.matmul(p, pos_emb)
-            # the following .as_strided() expression converts the last axis of pos_scores from relative
-            # to absolute position.  I don't know whether I might have got the time-offsets backwards or
-            # not, but let this code define which way round it is supposed to be.
-            pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
-                                               (pos_scores.stride(0),
-                                                pos_scores.stride(1),
-                                                pos_scores.stride(2)-pos_scores.stride(3),
-                                                pos_scores.stride(3)),
-                                               storage_offset=pos_scores.stride(3) * (seq_len - 1))

+            pos_scores = torch.matmul(pos_emb, p)
+            # pos_scores: (batch_size, tgt_seq_len, src_seq_len, num_heads)
+            pos_scores = pos_scores.permute(3, 0, 1, 2)
+            # pos_scores: (num_heads, batch_size, tgt_seq_len, src_seq_len)
             attn_scores = attn_scores + pos_scores

         if self.training and random.random() < 0.1:
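Note on shapes in the new positional-score path: with pos_emb of shape (batch, tgt, src, pos_dim) and p permuted to (batch, tgt, pos_dim, num_heads), one batched matmul gives (batch, tgt, src, num_heads), which is then permuted to line up with attn_scores. A quick einsum cross-check with illustrative sizes:

    import torch

    B, T, Dp, H = 2, 5, 4, 3
    pos_emb = torch.randn(B, T, T, Dp)           # (batch, tgt_seq_len, src_seq_len, pos_dim)
    p = torch.randn(B, T, Dp, H)                 # (batch, tgt_seq_len, pos_dim, num_heads)

    pos_scores = torch.matmul(pos_emb, p)        # (batch, tgt, src, num_heads)
    pos_scores = pos_scores.permute(3, 0, 1, 2)  # (num_heads, batch, tgt, src)

    ref = torch.einsum('btsd,btdh->hbts', pos_emb, p)
    assert torch.allclose(pos_scores, ref, atol=1.0e-06)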
@@ -1417,23 +1384,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                                                  penalty=1.0e-04,
                                                  name=self.name)

+        # attn_offset includes key-padding mask and attention-mask, plus any weights
+        # from the subsampling.
+        attn_scores = attn_scores + attn_offset

         assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)

-        if attn_mask is not None:
-            assert attn_mask.dtype == torch.bool
-            # use -1000 to avoid nan's where attn_mask and key_padding_mask make
-            # all scores zero.  It's important that this be large enough that exp(-1000)
-            # is exactly zero, for reasons related to const_attention_rate, it
-            # compares the final weights with zero.
-            attn_scores = attn_scores.masked_fill(attn_mask, -1000)

-        if key_padding_mask is not None:
-            assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
-            attn_scores = attn_scores.masked_fill(
-                key_padding_mask.unsqueeze(1),
-                -1000,
-            )

         # We use our own version of softmax, defined in scaling.py, which should
         # save a little of the memory used in backprop by, if we are in
         # automatic mixed precision mode (amp / autocast), by only storing the
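Note: the deleted masked_fill branches and the new additive attn_offset do the same job, since adding a large negative number to a score before the softmax drives that weight to (effectively) zero; that is what lets one precomputed offset carry the causal mask, the key-padding mask and the downsampling weights together. A tiny check of the equivalence (toy numbers, not the module's code):

    import torch

    scores = torch.randn(1, 4, 4)
    pad = torch.tensor([[False, False, True, True]])     # last two keys are padding

    masked = scores.masked_fill(pad.unsqueeze(1), -1000).softmax(dim=-1)

    offset = torch.zeros(1, 1, 4).masked_fill(pad.unsqueeze(1), -1000)
    added = (scores + offset).softmax(dim=-1)            # exp(-1000) underflows to exactly zero

    assert torch.allclose(masked, added)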
@@ -1617,9 +1573,9 @@ class MultiheadAttentionWeights(nn.Module):
         q = q.reshape(query_len, batch_size, num_heads, head_dim)
         k = k.reshape(key_len, batch_size, num_heads, head_dim)

-        # time1 refers to target, time2 refers to source.
-        q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
-        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
+        # tgt_seq_len refers to target, src_seq_len refers to source.
+        q = q.permute(2, 1, 0, 3)  # (head, batch, tgt_seq_len, query_head_dim)
+        k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, src_seq_len)

         attn_scores = torch.matmul(q, k)

@@ -1842,8 +1798,6 @@ def _test_zipformer_main(causal: bool = False):
     c = Subformer2(
         encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
         causal=causal,
-        chunk_size=(4,) if causal else (-1,),
-        left_context_frames=(64,),
         memory_dim=memory_dim,
     )
     batch_size = 5
@@ -63,7 +63,7 @@ from lm_datamodule import LmDataset, LmDataloader
 from zipformer import Zipformer2
 from scaling import ScheduledFloat
 from lhotse.utils import fix_random_seed
-from chunk_decoder import ChunkDecoder
+from decoder import Decoder
 from model import Zipformer2LM
 from optim import Eden, ScaledAdam
 from torch import Tensor
@@ -176,13 +176,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
     )

-    parser.add_argument(
-        "--pos-dim",
-        type=int,
-        default="48",
-        help="Positional-encoding embedding dimension"
-    )
-
     parser.add_argument(
         "--encoder-unmasked-dim",
         type=str,
@@ -505,9 +498,9 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:


 def get_encoder_model(params: AttributeDict) -> nn.Module:
-    chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
+    #chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
     encoder = Zipformer2(
-        output_downsampling_factor=chunk_size,
+        #output_downsampling_factor=chunk_size,
         downsampling_factor=_to_int_tuple(params.downsampling_factor),
         num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
         encoder_dim=_to_int_tuple(params.encoder_dim),
@@ -515,10 +508,8 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         query_head_dim=_to_int_tuple(params.query_head_dim),
         pos_head_dim=_to_int_tuple(params.pos_head_dim),
         value_head_dim=_to_int_tuple(params.value_head_dim),
-        pos_dim=params.pos_dim,
         num_heads=_to_int_tuple(params.num_heads),
         feedforward_dim=_to_int_tuple(params.feedforward_dim),
-        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
         dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
         warmup_batches=4000.0,
         causal=True,
@@ -529,13 +520,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:


 def get_decoder_model(params: AttributeDict) -> nn.Module:
-    chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
-    decoder = ChunkDecoder(
+    decoder = DecoderDecoder(
         embed_dim=max(_to_int_tuple(params.encoder_dim)),
-        chunk_size=chunk_size,
         vocab_size=256,  # bytes
-        hidden_size=params.decoder_hidden_size,
-        num_layers=params.decoder_num_layers,
     )
     return decoder
