mirror of
https://github.com/k2-fsa/icefall.git
Fix various bugs
This commit is contained in:
parent f740282a1a
commit 1b8be0744f
@@ -19,7 +19,6 @@
 import torch
 from torch import nn, Tensor
-from chunk_decoder import ChunkDecoder
 from zipformer import Zipformer2


@@ -28,7 +27,7 @@ class Zipformer2LM(nn.Module):
     def __init__(self,
                  encoder_embed: nn.Module,
                  encoder: Zipformer2,
-                 decoder: ChunkDecoder):
+                 decoder: nn.Module):
         super().__init__()
         self.encoder_embed = encoder_embed
         self.encoder = encoder  # does subsampling
@@ -47,18 +46,17 @@ class Zipformer2LM(nn.Module):
         """
         (batch_size, seq_len) = labels.shape

-        chunk_size = self.decoder.chunk_size
+        chunk_size = 1
         labels_shifted = labels.t()  # (time, batch)
-        labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:chunk_size]),
-                                    labels_shifted[:-chunk_size]),
+        labels_shifted = torch.cat((torch.zeros_like(labels_shifted[:1]),
+                                    labels_shifted[:-1]),
                                    dim=0)

         x = self.encoder_embed(labels_shifted)
         x_lens = torch.full((batch_size,), seq_len,
                             dtype=torch.long, device=labels.device)

         # x_lens is after subsampling.  Actually we don't need it.


         (x, x_lens) = self.encoder(x, x_lens)

         logprobs = self.decoder(labels, x)
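Aside: a minimal standalone sketch (not part of this commit; toy values assumed) of the label-shifting used above, where the encoder input is the label sequence delayed by one position so that frame t only conditions on labels earlier than t:

import torch

labels = torch.tensor([[5, 6, 7, 8]])      # (batch, seq_len)
shifted = labels.t()                       # (seq_len, batch)
shifted = torch.cat((torch.zeros_like(shifted[:1]), shifted[:-1]), dim=0)
print(shifted.t())                         # tensor([[0, 5, 6, 7]])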
@@ -76,11 +76,7 @@ class Subformer2(EncoderInterface):
        dropout (float): dropout rate
        warmup_batches (float): number of batches to warm up over; this controls
           dropout of encoder layers.
-       causal (bool): if True, support chunkwise causal convolution.  This should
-          not hurt WER as no modeling power is lost, but the convolution modules will be
-          slightly slower and use more memory.  Enables use of the chunk_size and
-          left_context_chunks options in forward(), which simulates streaming
-          decoding.
+       causal (bool): if True, use causal attention-mask.
        memory_dim: if supplied and >0, will be the dimension of the memory embeddings
           passed into the zipformer (e.g. this might be the output of another
           Subformer used to create embedding vectors.)
@@ -97,7 +93,6 @@ class Subformer2(EncoderInterface):
            num_heads: Union[int, Tuple[int]] = 8,
            feedforward_dim: Union[int, Tuple[int]] = 1536,
            memory_dim: int = -1,
-           pos_emb_dim: int = 192,
            pos_dim: int = 4,
            dropout: FloatLike = None,  # see code below for default
            warmup_batches: float = 4000.0,
@@ -129,6 +124,7 @@ class Subformer2(EncoderInterface):
        value_head_dim = _to_tuple(value_head_dim)
        num_heads = _to_tuple(num_heads)
        feedforward_dim = _to_tuple(feedforward_dim)
+       self.causal = causal

        for u,d in zip(encoder_unmasked_dim, encoder_dim):
            assert u <= d
@@ -156,7 +152,6 @@ class Subformer2(EncoderInterface):
            encoder = Subformer2Encoder(
                encoder_layer,
                num_encoder_layers[i],
-               pos_dim=pos_dim,
                dropout=dropout,
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
@@ -173,14 +168,15 @@ class Subformer2(EncoderInterface):

            encoders.append(encoder)

-       self.encoder_pos = CompactRelPositionalEncoding(pos_emb_dim, pos_dim, dropout_rate=0.15,
+       self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim,
+                                                       dropout_rate=0.15,
                                                        length_factor=1.0)

        self.encoders = nn.ModuleList(encoders)

-       self.downsample_output = SimpleDownsample(max(encoder_dim),
-                                                 downsample=output_downsampling_factor,
-                                                 dropout=dropout)
+       #self.downsample_output = SimpleDownsample(max(encoder_dim),
+       #                                          downsample=output_downsampling_factor,
+       #                                          dropout=dropout)

    def get_feature_masks(
        self,
@@ -273,7 +269,7 @@ class Subformer2(EncoderInterface):
        outputs = []
        feature_masks = self.get_feature_masks(x)

-       attn_offset = self._get_attn_offset(x)
+       attn_offset = self._get_attn_offset(x, src_key_padding_mask)

        if self.training and memory is not None:
            batch_size = x.shape[1]
@@ -286,15 +282,11 @@ class Subformer2(EncoderInterface):
        pos_emb = self.encoder_pos(x)

        for i, module in enumerate(self.encoders):
-           ds = self.downsampling_factor[i]
            x = convert_num_channels(x, self.encoder_dim[i])

            x = module(x,
                       pos_emb,
-                      chunk_size=chunk_size,
                       feature_mask=feature_masks[i],
-                      src_key_padding_mask=(None if src_key_padding_mask is None
-                                            else src_key_padding_mask[...,::ds]),
                       attn_offset=attn_offset,
                       memory=memory,
                       memory_key_padding_mask=memory_key_padding_mask,
@@ -321,37 +313,37 @@ class Subformer2(EncoderInterface):
        # from different pieces of 'outputs', taking each dimension from the
        # most recent output that has it present.
        x = get_full_dim_output()
-       x = self.downsample_output(x)
+       #x = self.downsample_output(x)

        d = self.output_downsampling_factor
        lengths = (x_lens + d - 1) // d

        return x, lengths

-   def _get_attn_offset(self, x: Tensor) -> Optional[Tensor]:
+   def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
        """
-       Return attention offset of shape (1, seq_len, seq_len), interpreted as (tgt_seq_len,
+       Return attention offset of shape (1 or batch_size, seq_len, seq_len), interpreted as (1 or batch_size, tgt_seq_len,
        src_seq_len); this reflects masking, if causal == True, otherwise will be all zeros.

        Args:
           x: embeddings after self.encoder_embed(), of shape (seq_len, batch_size, embed_dim).
-          chunk_size: chunk size, must divide
+          src_key_padding_mask: optional key-padding mask of shape (batch_size, seq_len) with True in masked positions.
        """
-       if not self.causal:
-           return None
-
-       seq_len = x.shape[0]
-
-       # t is frame index, shape (seq_len,)
-       t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
-       src_c = c
-       tgt_c = c.unsqueeze(-1)
-
-       attn_mask = (src_c > tgt_c)
+       seq_len, batch_size, _num_channels = x.shape

        ans = torch.zeros(1, seq_len, seq_len, device=x.device)

-       ans.masked_fill(attn_mask, float('-inf'))
+       if self.causal:
+           # t is frame index, shape (seq_len,)
+           t = torch.arange(seq_len, dtype=torch.int32, device=x.device)
+           src_t = t
+           tgt_t = t.unsqueeze(-1)
+           attn_mask = (src_t > tgt_t)
+           ans.masked_fill(attn_mask, float('-inf'))

+       if src_key_padding_mask is not None:
+           ans = ans * src_key_padding_mask.unsqueeze(1).logical_not()
+           # now ans: (batch_size, seq_len, seq_len).

        return ans

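Aside: a self-contained sketch (not taken from the commit; the function name and the masked_fill-based padding handling are illustrative, the commit's own version multiplies by logical_not instead) of building an additive attention offset that is -inf above the diagonal when causal and -inf at padded key positions:

import torch

def attn_offset_sketch(x, src_key_padding_mask=None, causal=True):
    # x: (seq_len, batch_size, embed_dim); mask: (batch_size, seq_len), True at padded frames
    seq_len, batch_size, _ = x.shape
    ans = torch.zeros(1, seq_len, seq_len, device=x.device)
    if causal:
        t = torch.arange(seq_len, device=x.device)
        ans = ans.masked_fill(t > t.unsqueeze(-1), float('-inf'))   # src frame later than tgt frame
    if src_key_padding_mask is not None:
        ans = ans.expand(batch_size, -1, -1).masked_fill(
            src_key_padding_mask.unsqueeze(1), float('-inf'))       # (batch, tgt, src)
    return ans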
@@ -384,11 +376,10 @@ class Subformer2EncoderLayer(nn.Module):
    def __init__(
            self,
            embed_dim: int,
-           pos_dim: int,
            num_heads: int,
            query_head_dim: int,
+           pos_dim: int,
            value_head_dim: int,
-           pos_dim: int,
            feedforward_dim: int,
            dropout: FloatLike = 0.1,
            causal: bool = False,
@@ -431,14 +422,15 @@ class Subformer2EncoderLayer(nn.Module):


        self.self_attn1 = Attention(embed_dim, embed_dim, num_heads,
-                                   value_head_dim)
+                                   value_head_dim)

        self.self_attn2 = Attention(embed_dim, embed_dim, num_heads,
                                    value_head_dim)

        if memory_dim > 0:
            self.attn_weights = MultiheadAttentionWeights(
-               memory_dim, embed_dim,
+               memory_dim,
+               embed_dim,
                num_heads=num_heads,
                head_dim=query_head_dim,
                dropout=0.0,
@@ -559,7 +551,6 @@ class Subformer2EncoderLayer(nn.Module):
        self,
        src: Tensor,
        pos_emb: Tensor,
-       chunk_size: int = -1,
        attn_offset: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        memory: Optional[Tensor] = None,
@@ -570,7 +561,6 @@ class Subformer2EncoderLayer(nn.Module):
        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
            pos_emb: (1, 2*seq_len-1, pos_emb_dim) or (batch_size, 2*seq_len-1, pos_emb_dim)
-           chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
            feature_mask: something that broadcasts with src, that we'll multiply `src`
               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
            attn_offset: the attention offset, of shape broadcasting with (batch_size, seq_len, seq_len),
@@ -591,7 +581,6 @@ class Subformer2EncoderLayer(nn.Module):
            src,
            pos_emb=pos_emb,
            attn_offset=attn_offset,
-           key_padding_mask=src_key_padding_mask,
        )

        if memory is not None and hasattr(self, 'attn_weights'):
@@ -662,7 +651,6 @@ class Subformer2Encoder(nn.Module):
    Args:
        encoder_layer: an instance of the Subformer2EncoderLayer() class (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
-       pos_dim: the dimension for the relative positional encoding

    Examples::
        >>> encoder_layer = Subformer2EncoderLayer(embed_dim=512, nhead=8)
@@ -674,7 +662,6 @@ class Subformer2Encoder(nn.Module):
        self,
        encoder_layer: nn.Module,
        num_layers: int,
-       pos_dim: int,
        dropout: float,
        warmup_begin: float,
        warmup_end: float,
@@ -701,10 +688,9 @@ class Subformer2Encoder(nn.Module):
    def forward(
        self,
        src: Tensor,
-       chunk_size: int = -1,
+       pos_emb: Tensor,
        feature_mask: Union[Tensor, float] = 1.0,
        attn_offset: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        memory: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
@@ -712,14 +698,13 @@ class Subformer2Encoder(nn.Module):

        Args:
            src: the sequence to the encoder (required): shape (seq_len, batch_size, embedding_dim).
-           chunk_size: the number of frames per chunk, of >= 0; if -1, no chunking.
+           pos_emb: positional embedding tensor, of shape (batch_size, seq_len, seq_len, pos_dim),
+              e.g. pos_dim=4.
            feature_mask: something that broadcasts with src, that we'll multiply `src`
               by at every layer: if a Tensor, likely of shape (seq_len, batch_size, embedding_dim)
            attn_offset: the attention offset (does masking and related tasks), of shape
               broadcasting with (batch_size, seq_len, seq_len),
               interpreted as (batch_size, tgt_seq_len, src_seq_len).
            src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
               masked position.  May be None.
            memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
            memory_key_padding_mask: optionally the mask for padding of memory input (for source-
               attention), of shape (batch_size, memory_len); True means
@@ -727,7 +712,6 @@ class Subformer2Encoder(nn.Module):

        Returns: a Tensor with the same shape as src.
        """
-       pos_emb = self.encoder_pos(src)
        output = src

        rnd_seed = src.numel() + random.randint(0, 1000)
@@ -738,9 +722,7 @@ class Subformer2Encoder(nn.Module):
                output = mod(
                    output,
                    pos_emb,
-                   chunk_size=chunk_size,
                    attn_offset=attn_offset,
-                   src_key_padding_mask=src_key_padding_mask,
                    memory=memory,
                    memory_key_padding_mask=memory_key_padding_mask,
                )
@@ -827,12 +809,13 @@ class LearnedDownsamplingModule(nn.Module):
                 embed_dim: int,
                 downsampling_factor: int,
                 intermediate_rate: FloatLike = 0.2):
        super().__init__()
        self.to_scores = nn.Linear(embed_dim, 1, bias=False)
        # score_balancer is just to keep the magnitudes of the scores in
        # a fixed range and keep them balanced around zero, to stop
        # these drifting around.
        self.score_balancer = Balancer(1, channel_dim=-1,
-                                      min_positive=0.4, max_positive=0.6
+                                      min_positive=0.4, max_positive=0.6,
+                                      min_abs=1.0, max_abs=1.2,
                                       prob=0.025)

@@ -856,14 +839,14 @@ class LearnedDownsamplingModule(nn.Module):
            corresponding to the kept frames; these will be between 0 and 1, but
            mostly exactly 1.
        """
-       (seq_len, batch_size, _)
+       (seq_len, batch_size, _) = x.shape
        scores = self.to_scores(x)  # (seq_len, batch_size, 1)
        scores = self.score_balancer(scores)

        scores = scores.squeeze(-1).t()  # (batch_size, seq_len)

-       # indexes, sscores: (batch_size, seq_len)
-       indexes, sscores = scores.sort(dim=-1, descending=True)
+       # sscores, indexes: (batch_size, seq_len)
+       sscores, indexes = scores.sort(dim=-1, descending=True)

        d = self.downsampling_factor
        seq_len_reduced = (seq_len + d - 1) // d
@@ -883,10 +866,10 @@ class LearnedDownsamplingModule(nn.Module):
        collar = max(1, int(seq_len_reduced * 0.5 * self.intermediate_rate))

        # right_avg: shape (batch_size,), this is to be mapped to 0.0
-       right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1)
+       right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)

        # left_avg: shape (batch_size,), this is to be mapped to 1.0
-       left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1)
+       left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)

        # the + 0.001 is to avoid possible division by zero in case of ties.
        weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)
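Aside: a toy sketch (values invented) of why the keepdim=True fix matters here: the collar averages must stay shaped (batch_size, 1) so they broadcast against the (batch_size, seq_len) sorted scores when computing the weights:

import torch

sscores = torch.randn(3, 10).sort(dim=-1, descending=True).values               # (batch, seq_len)
left, right, collar = 2, 5, 1
right_avg = sscores[:, right-collar:right+collar+1].mean(dim=-1, keepdim=True)  # (batch, 1), mapped to 0.0
left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)     # (batch, 1), mapped to 1.0
weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)                 # broadcasts to (batch, seq_len)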
@@ -901,11 +884,11 @@ class LearnedDownsamplingModule(nn.Module):
        indexes, reorder = indexes.sort(dim=-1)
        weights = torch.gather(weights, dim=-1, index=reorder)

-       x_downsampled = downsample(indexes, x)
+       x_downsampled = self.downsample(x, indexes)
        return indexes, weights, x_downsampled


-   def downsample(x: Tensor, indexes: Tensor) -> Tensor:
+   def downsample(self, x: Tensor, indexes: Tensor) -> Tensor:
        """
        Downsamples x via indexing with the indexes obtained from the
        forward() function.
@@ -917,19 +900,19 @@ class LearnedDownsamplingModule(nn.Module):
        Returns:
           x_downsampled, of shape (seq_len_reduced, batch_size, num_channels)
        """
-       indexes = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
-       # indexes now: (seq_len_reduced, batch_size, num_channels)
-       ans = torch.gather(x, dim=0, index=indexes)
+       indexes_expanded = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])
+       # indexe_expanded: (seq_len_reduced, batch_size, num_channels)
+       ans = torch.gather(x, dim=0, index=indexes_expanded)

-       if __name__ == __main__:
+       if __name__ == '__main__':
            # temp, for testing
-           x_reconstructed = upsample(x, ans, indexes)
+           x_reconstructed = self.upsample(x, ans, indexes)
            assert torch.allclose(x, x_reconstructed)

        return ans


-   def downsample_pos_emb(pos_emb: Tensor, indexes: Tensor) -> Tensor:
+   def downsample_pos_emb(self, pos_emb: Tensor, indexes: Tensor) -> Tensor:
        """
        Downsample positional embedding tensor with the provided indexes.
        Args:
@@ -958,7 +941,8 @@ class LearnedDownsamplingModule(nn.Module):
        return pos_emb


-   def downsample_attn_offset(attn_offset: Tensor,
+   def downsample_attn_offset(self,
+                              attn_offset: Tensor,
                               indexes: Tensor,
                               weights: Tensor,
                               eps: float = 1.0e-05) -> Tensor:
@@ -979,18 +963,17 @@ class LearnedDownsamplingModule(nn.Module):
        assert len(attn_offset.shape) == 3  # (1, seq_len, seq_len) or (batch_size, seq_len, seq_len)
        attn_offset = attn_offset.expand(batch_size, seq_len, seq_len)


-       attn_offset = attn_offset.gather(dim=1, src=indices.unsqueeze(-1).expand(
+       attn_offset = attn_offset.gather(dim=1, index=indexes.unsqueeze(-1).expand(
            batch_size, seq_len_reduced, seq_len))
-       attn_offset = attn_offset.gather(dim=2, src=indices.unsqueeze(1).expand(
+       attn_offset = attn_offset.gather(dim=2, index=indexes.unsqueeze(1).expand(
            batch_size, seq_len_reduced, seq_len_reduced))
        # unsqueeze at position 1 so the extra cost relates to the source position.
        attn_offset = attn_offset + weights.clamp(min=eps).log().unsqueeze(1)

-       return attn_offst
+       return attn_offset


-   def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
+   def upsample(self, x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
        """
        Upsamples, reversing the downsample() operation and filling in
        any not-chosen frames with their original value before downsampling
@@ -1013,14 +996,14 @@ class LearnedDownsamplingModule(nn.Module):

        not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool,
                              device=x.device)
-       not_kept.scatter_(src=False, dim=1, index=indexes)
+       not_kept.scatter_(dim=1, index=indexes, value=False)

        indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
        # indexes now: (seq_len_reduced, batch_size, num_channels)

        ans = torch.zeros_like(x_orig)

-       ans.scatter_(x, dim=0, index=indexes)
+       ans.scatter_(dim=0, index=indexes, src=x)

        # add in x_orig in the frames that were not originally kept.
        return ans + x_orig * not_kept.t().unsqueeze(-1)
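Aside: a compact sketch (not the module itself; the per-frame weights are ignored and shapes are assumed from the docstrings above) of the gather/scatter pair behind downsample() and upsample():

import torch

def downsample_sketch(x, indexes):
    # x: (seq_len, batch, channels); indexes: (batch, seq_len_reduced), sorted per row
    idx = indexes.t().unsqueeze(-1).expand(-1, -1, x.shape[-1])      # (seq_len_reduced, batch, channels)
    return torch.gather(x, dim=0, index=idx)

def upsample_sketch(x_orig, x_ds, indexes):
    # scatter the kept frames back, then take the not-kept frames from x_orig
    idx = indexes.t().unsqueeze(-1).expand(-1, -1, x_orig.shape[-1])
    ans = torch.zeros_like(x_orig).scatter_(dim=0, index=idx, src=x_ds)
    kept = torch.zeros(x_orig.shape[1], x_orig.shape[0], dtype=torch.bool, device=x_orig.device)
    kept.scatter_(dim=1, index=indexes, value=True)                  # (batch, seq_len)
    return ans + x_orig * kept.logical_not().t().unsqueeze(-1)

x = torch.randn(6, 2, 3)
indexes = torch.tensor([[0, 2, 4], [1, 3, 5]])
assert upsample_sketch(x, downsample_sketch(x, indexes), indexes).allclose(x)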
@@ -1051,7 +1034,6 @@ class DownsampledSubformer2Encoder(nn.Module):
        pos_emb: Tensor,
        feature_mask: Union[Tensor, float] = 1.0,
        attn_offset: Optional[Tensor] = None,
-       src_key_padding_mask: Optional[Tensor] = None,
        memory: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
@@ -1065,8 +1047,6 @@ class DownsampledSubformer2Encoder(nn.Module):
            attn_offset: the attention offset, added to scores for attention of shape
               (batch_size, seq_len, seq_len) or (seq_len, seq_len),
               interpreted as (batch_size, tgt_seq_len, src_seq_len) or (tgt_seq_len, src_seq_len).
-           src_key_padding_mask: the mask for padding, of shape (batch_size, seq_len); True means
-              masked position.  May be None.
            memory: optionally, the memory embeddings of shape (memory_len, batch_size, memory_dim)
            memory_key_padding_mask: optionally the mask for padding of memory input (for source-
               attention), of shape (batch_size, memory_len); True means
@@ -1079,30 +1059,24 @@ class DownsampledSubformer2Encoder(nn.Module):

        pos_emb = self.downsampler.downsample_pos_emb(pos_emb, indexes)

-       attn_offset = self.downsample.downsample_attn_offset(attn_offset,
-                                                            indexes,
-                                                            weights.clamp(min=1.0e-05))
+       attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
+                                                             indexes,
+                                                             weights)

        src = self.encoder(
            src,
-           os_emb,
+           pos_emb,
            feature_mask=feature_mask,
            attn_offset=attn_offset,
-           src_key_padding_mask=src_key_padding_mask,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
        )
-       src = self.upsample(src)
-       # remove any extra frames that are not a multiple of downsample_factor
-       src = src[:src_orig.shape[0]]
+       src = self.downsampler.upsample(src_orig, src, indexes)

        return self.out_combiner(src_orig, src)



class CompactRelPositionalEncoding(torch.nn.Module):
    """
    Relative positional encoding module.  This version is "compact" meaning it is able to encode
@@ -1123,6 +1097,7 @@ class CompactRelPositionalEncoding(torch.nn.Module):

    Args:
        embed_dim: Temporary embedding dimension used inside this module
+       pos_dim: Smaller positional-encoding dim used after a projecction.
        dropout_rate: Dropout rate.
        max_len: Maximum input length: just a heuristic for initialization.
        length_factor: a heuristic scale (should be >= 1.0) which, if larger, gives
@@ -1130,11 +1105,12 @@ class CompactRelPositionalEncoding(torch.nn.Module):
        pos_dim: dimension at the output of this module.
    """
    def __init__(
-       self, embed_dim: int,
+       self,
+       embed_dim: int,
+       pos_dim: int,
        dropout_rate: FloatLike,
        max_len: int = 1000,
        length_factor: float = 1.0,
-       pos_dim: int = 4,
    ) -> None:
        """Construct a CompactRelPositionalEncoding object."""
        super(CompactRelPositionalEncoding, self).__init__()
@@ -1211,16 +1187,13 @@ class CompactRelPositionalEncoding(torch.nn.Module):
            x (torch.Tensor): Input tensor (time, batch, num_channels_in)

        Returns:
-           positional embedding, of shape (1, 2*time-1, pos_dim).
+           positional embedding, of shape (batch_size, 2*time-1, pos_dim).

        """
        self.extend_pe(x)
        seq_len = x.size(0)
        pos_emb = self.pe[
-           self.pe.size(0) // 2
-           - seq_len,
-           + 1 : self.pe.size(0) // 2  # noqa E203
-           + seq_len,
+           self.pe.size(0) // 2 - seq_len + 1 : self.pe.size(0) // 2 + seq_len,
            :
        ]
        pos_emb = pos_emb.unsqueeze(0)
@@ -1230,12 +1203,20 @@ class CompactRelPositionalEncoding(torch.nn.Module):
        # currenly pos_emb: (1, 2*seq_len-1, pos_dim)
        pos_dim = pos_emb.shape[-1]
        batch_size = x.size(1)
-       (_, seq_stride, channel_stride) = pos_emb.stride()
        # it doesn't really matter which one we make positive and which negative here, it
        # would just flip the meaning of the embedding.


+       # expand the '1' dimension to seq_len; this introduces a dimension that
+       # 'does nothing', just creates copies, as a workaround for lack of torch support
+       # for negative strides.
+       pos_emb = pos_emb.expand(seq_len, 2*seq_len-1, pos_dim).contiguous()
+
+       (useless_stride, seq_stride, channel_stride) = pos_emb.stride()
+
        pos_emb = pos_emb.as_strided((batch_size, seq_len, seq_len, pos_dim),
-                                    (0, -seq_stride, seq_stride, channel_stride),
-                                    storage_offset=seq_stride * (seqs_len - 1))
+                                    (0, useless_stride-seq_stride, seq_stride, channel_stride),
+                                    storage_offset=seq_stride * (seq_len - 1))

        return pos_emb  # (batch_size, seq_len, seq_len, pos_dim)
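Aside: a tiny numeric sketch (toy size; the sign convention may be flipped relative to the code above, which its own comment notes only flips the meaning of the embedding) of the expand + as_strided trick that turns a length 2*seq_len-1 relative-position axis into a (seq_len, seq_len) matrix without copying it row by row:

import torch

S = 4
rel = torch.arange(-(S - 1), S)                      # relative offsets -(S-1)..(S-1)
mat = rel.unsqueeze(0).expand(S, 2 * S - 1).contiguous()
row_stride, col_stride = mat.stride()
abs_view = mat.as_strided((S, S), (row_stride - col_stride, col_stride),
                          storage_offset=col_stride * (S - 1))
# abs_view[i, j] == rel[j - i + S - 1] == j - i
assert torch.equal(abs_view, torch.arange(S).unsqueeze(0) - torch.arange(S).unsqueeze(1))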
@@ -1326,8 +1307,6 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        x: Tensor,
        pos_emb: Tensor,
        attn_offset: Optional[Tensor] = None,
-       pos_emb: Tensor,
-       quadratic_pos_weight: Tensor,
    ) -> Tensor:
        r"""
        Args:
@@ -1368,35 +1347,23 @@ class RelPositionMultiheadAttentionWeights(nn.Module):


        q = q.reshape(seq_len, batch_size, num_heads, query_head_dim)
-       p = p.reshape(seq_len, batch_size, num_heads, pos_head_dim)
+       p = p.reshape(seq_len, batch_size, num_heads, pos_dim)
        k = k.reshape(seq_len, batch_size, num_heads, query_head_dim)

-       # time1 refers to target, time2 refers to source.
-       q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
-       p = p.permute(2, 1, 0, 3)  # (head, batch, time1, pos_head_dim)
-       k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
+       q = q.permute(2, 1, 0, 3)  # (head, batch, tgt_seq_len, query_head_dim)
+       k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, src_seq_len)

        # attn_scores: (num_heads, batch_size, tgt_seq_len, src_esq_len)
        attn_scores = torch.matmul(q, k)

        if not self.training or random.random() >= float(self.pos_emb_skip_rate):
-           pos_emb = self.linear_pos(pos_emb)
-           seq_len2 = 2 * seq_len - 1
-           pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(2, 0, 3, 1)
-           # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
-
-           # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
-           # [where seq_len2 represents relative position.]
-           pos_scores = torch.matmul(p, pos_emb)
-           # the following .as_strided() expression converts the last axis of pos_scores from relative
-           # to absolute position.  I don't know whether I might have got the time-offsets backwards or
-           # not, but let this code define which way round it is supposed to be.
-           pos_scores = pos_scores.as_strided((num_heads, batch_size, seq_len, seq_len),
-                                              (pos_scores.stride(0),
-                                               pos_scores.stride(1),
-                                               pos_scores.stride(2)-pos_scores.stride(3),
-                                               pos_scores.stride(3)),
-                                              storage_offset=pos_scores.stride(3) * (seq_len - 1))
+           # pos_emb: (batch_size, tgt_seq_len, src_seq_len, pos_dim)
+           p = p.permute(1, 0, 3, 2)  # (batch_size, tgt_seq_len, pos_dim, num_heads)

+           pos_scores = torch.matmul(pos_emb, p)
+           # pos_scores: (batch_size, tgt_seq_len, src_seq_len, num_heads)
+           pos_scores = pos_scores.permute(3, 0, 1, 2)
+           # pos_scores: (num_heads, batch_size, tgt_seq_len, src_seq_len)
            attn_scores = attn_scores + pos_scores

        if self.training and random.random() < 0.1:
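Aside: a sketch (shapes taken from the comments above, data invented) checking that the matmul/permute combination in the new code is equivalent to a single einsum over the pos_dim axis:

import torch

B, T, H, P = 2, 5, 4, 6
pos_emb = torch.randn(B, T, T, P)        # (batch, tgt_seq_len, src_seq_len, pos_dim)
p = torch.randn(T, B, H, P)              # (seq_len, batch, num_heads, pos_dim)
ref = torch.matmul(pos_emb, p.permute(1, 0, 3, 2)).permute(3, 0, 1, 2)
alt = torch.einsum('btsp,tbhp->hbts', pos_emb, p)
assert torch.allclose(ref, alt, atol=1e-5)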
@@ -1417,23 +1384,12 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
                                    penalty=1.0e-04,
                                    name=self.name)

+       # attn_offset includes key-padding mask and attention-mask, plus any weights
+       # from the subsampling.
        attn_scores = attn_scores + attn_offset

        assert attn_scores.shape == (num_heads, batch_size, seq_len, seq_len)

-       if attn_mask is not None:
-           assert attn_mask.dtype == torch.bool
-           # use -1000 to avoid nan's where attn_mask and key_padding_mask make
-           # all scores zero.  It's important that this be large enough that exp(-1000)
-           # is exactly zero, for reasons related to const_attention_rate, it
-           # compares the final weights with zero.
-           attn_scores = attn_scores.masked_fill(attn_mask, -1000)
-
-       if key_padding_mask is not None:
-           assert key_padding_mask.shape == (batch_size, seq_len), key_padding_mask.shape
-           attn_scores = attn_scores.masked_fill(
-               key_padding_mask.unsqueeze(1),
-               -1000,
-           )
-
        # We use our own version of softmax, defined in scaling.py, which should
        # save a little of the memory used in backprop by, if we are in
        # automatic mixed precision mode (amp / autocast), by only storing the
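Aside: a small sketch (invented data) of why an additive offset can replace the boolean masks removed in this hunk: adding -inf at disallowed positions before the softmax yields the same attention weights as masked_fill:

import torch

scores = torch.randn(2, 2, 4, 4)                          # (head, batch, tgt, src)
key_padding = torch.tensor([[False, False, True, True],
                            [False, True, False, True]])  # (batch, src), True = padded
offset = torch.zeros(2, 4, 4).masked_fill(key_padding.unsqueeze(1), float('-inf'))
a = (scores + offset).softmax(dim=-1)
b = scores.masked_fill(key_padding.unsqueeze(0).unsqueeze(2), float('-inf')).softmax(dim=-1)
assert torch.allclose(a, b)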
@@ -1617,9 +1573,9 @@ class MultiheadAttentionWeights(nn.Module):
        q = q.reshape(query_len, batch_size, num_heads, head_dim)
        k = k.reshape(key_len, batch_size, num_heads, head_dim)

-       # time1 refers to target, time2 refers to source.
-       q = q.permute(2, 1, 0, 3)  # (head, batch, time1, query_head_dim)
-       k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, time2)
+       # tgt_seq_len refers to target, src_seq_len refers to source.
+       q = q.permute(2, 1, 0, 3)  # (head, batch, tgt_seq_len, query_head_dim)
+       k = k.permute(2, 1, 3, 0)  # (head, batch, d_k, src_seq_len)

        attn_scores = torch.matmul(q, k)

@@ -1842,8 +1798,6 @@ def _test_zipformer_main(causal: bool = False):
    c = Subformer2(
        encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
        causal=causal,
-       chunk_size=(4,) if causal else (-1,),
-       left_context_frames=(64,),
        memory_dim=memory_dim,
    )
    batch_size = 5
@@ -63,7 +63,7 @@ from lm_datamodule import LmDataset, LmDataloader
from zipformer import Zipformer2
from scaling import ScheduledFloat
from lhotse.utils import fix_random_seed
-from chunk_decoder import ChunkDecoder
+from decoder import Decoder
from model import Zipformer2LM
from optim import Eden, ScaledAdam
from torch import Tensor
@@ -176,13 +176,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Positional-encoding dimension per head in encoder stacks: a single int or comma-separated list."
    )

-   parser.add_argument(
-       "--pos-dim",
-       type=int,
-       default="48",
-       help="Positional-encoding embedding dimension"
-   )
-
    parser.add_argument(
        "--encoder-unmasked-dim",
        type=str,
@@ -505,9 +498,9 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:


def get_encoder_model(params: AttributeDict) -> nn.Module:
-   chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
+   #chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
    encoder = Zipformer2(
-       output_downsampling_factor=chunk_size,
+       #output_downsampling_factor=chunk_size,
        downsampling_factor=_to_int_tuple(params.downsampling_factor),
        num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
        encoder_dim=_to_int_tuple(params.encoder_dim),
@@ -515,10 +508,8 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
        query_head_dim=_to_int_tuple(params.query_head_dim),
-       pos_head_dim=_to_int_tuple(params.pos_head_dim),
        value_head_dim=_to_int_tuple(params.value_head_dim),
-       pos_dim=params.pos_dim,
        num_heads=_to_int_tuple(params.num_heads),
        feedforward_dim=_to_int_tuple(params.feedforward_dim),
        cnn_module_kernel=_to_int_tuple(params.cnn_module_kernel),
        dropout=ScheduledFloat((0.0, 0.3), (20000.0, 0.1)),
        warmup_batches=4000.0,
        causal=True,
@@ -529,13 +520,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:


def get_decoder_model(params: AttributeDict) -> nn.Module:
-   chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
-   decoder = ChunkDecoder(
+   decoder = DecoderDecoder(
        embed_dim=max(_to_int_tuple(params.encoder_dim)),
-       chunk_size=chunk_size,
        vocab_size=256,  # bytes
        hidden_size=params.decoder_hidden_size,
        num_layers=params.decoder_num_layers,
    )
    return decoder