More partial work

Daniel Povey 2023-05-24 16:04:05 +08:00
parent e51a2c9170
commit 8483ca2e8f
2 changed files with 119 additions and 70 deletions

View File

@@ -57,6 +57,10 @@ class Subformer(EncoderInterface):
the whole stack to downsample.)
encoder_dim (Tuple[int]): embedding dimension of each encoder stack
(i.e. one per "S" in structure).
encoder_chunk_sizes (Tuple[Tuple[int]]): A tuple containing either one tuple or
one tuple per encoder stack. Each element is a tuple of the chunk sizes
used during training, e.g. (128, 1024); successive layers cycle through
them round-robin (see the sketch below).
downsampling_factor (Tuple[int]): downsampling factor for each downsampling
operation (each open-parenthesis).
num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
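
For intuition, here is a minimal sketch (not from this commit; the helper name is hypothetical) of the round-robin schedule described above:

    # Hypothetical illustration of the round-robin chunk-size schedule:
    # with chunk_sizes=(128, 1024), layers 0, 2, 4, ... use 128 and
    # layers 1, 3, 5, ... use 1024.
    from typing import Tuple

    def layer_chunk_size(chunk_sizes: Tuple[int, ...], layer_index: int) -> int:
        return chunk_sizes[layer_index % len(chunk_sizes)]
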
@@ -85,7 +89,7 @@ class Subformer(EncoderInterface):
structure: str = "S(S)S",
encoder_dim: Tuple[int, ...] = (384, 512, 384),
downsampling_factor: Tuple[int, ...] = (2,),
encoder_chunk_sizes: Tuple[Tuple[int, ...]] = (128,),
encoder_chunk_sizes: Tuple[Tuple[int, ...]] = ((128,1024),),
num_encoder_layers: Union[int, Tuple[int, ...]] = (4,),
query_head_dim: Tuple[int, ...] = (24,),
value_head_dim: Tuple[int, ...] = (12,),
@@ -120,7 +124,7 @@ class Subformer(EncoderInterface):
return x
self.encoder_dim = encoder_dim
encoder_chunk_size = _to_tuple(encoder_chunk_size)
encoder_chunk_sizes = _to_tuple(encoder_chunk_sizes)
num_encoder_layers = _to_tuple(num_encoder_layers)
query_head_dim = _to_tuple(query_head_dim)
value_head_dim = _to_tuple(value_head_dim)
@@ -136,7 +140,17 @@ class Subformer(EncoderInterface):
# each one will be SubformerEncoder or DownsampledSubformerEncoder
encoders = []
downsamplers = []
bypass = []
bypasses = []
layer_indexes = []  # for each char of structure, an index into encoders, downsamplers or bypasses
cur_max_dim = encoder_dim[0]
downsampling_factors_list = []
def cur_downsampling_factor():
    # Product of the factors of all currently-open downsampling
    # operations, i.e. the total downsampling at this point.
    c = 1
    for d in downsampling_factors_list:
        c *= d
    return c
for s in structure:
if s == 'S':
@@ -152,61 +166,45 @@ class Subformer(EncoderInterface):
dropout=dropout,
causal=causal,
)
cur_max_dim = max(cur_max_dim, encoder_dim[i])
encoder = SubformerEncoder(
encoder_layer,
num_encoder_layers[i],
embed_dim=cur_max_dim,
dropout=dropout,
chunk_size=encoder_chunk_size[i],
chunk_sizes=encoder_chunk_sizes[i],
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
final_layerdrop_rate=0.035 * (cur_downsampling_factor() ** 0.5),
)
layer_indexes.append(len(encoders))
encoders.append(encoder)
pass
elif s == '(':
pass
i = len(downsamplers)
downsampler = LearnedDownsamplingModule(cur_max_dim,
downsampling_factor[i])
downsampling_factors_list.append(downsampling_factor[i])
layer_indexes.append(len(downsamplers))
downsamplers.append(downsampler)
else:
assert s == ')'
bypass = BypassModule(cur_max_dim, straight_through_rate=0.0)
layer_indexes.append(len(bypasses))
bypasses.append(bypass)
downsampling_factors_list.pop()
logging.info(f"cur_downsampling_factor={cur_downsampling_factor()}")
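# For example, with structure "S(S)S" the loop above produces
# layer_indexes == [0, 0, 1, 0, 2]: each entry indexes into
# encoders, downsamplers or bypasses, according to whether the
# corresponding character of structure is 'S', '(' or ')'.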
num_encoders = len(encoder_dim)
assert num_encoders % 2 == 1
downsampling_factor = [ 1 ]
while len(downsampling_factor) < num_encoders:
downsampling_factor = [ 1 ] + [ d * 2 for d in downsampling_factor ] + [ 1 ]
for i in range(num_encoders):
mid = len(encoders) // 2
encoder = DownsampledSubformerEncoder(
[ encoders[mid] ],
input_num_channels=encoder_dim[mid-1],
downsample=2
)
for i in range(1, mid+1):
this_list = [ encoders[mid-i],
encoder,
encoders[mid+i] ]
encoder = DownsampledSubformerEncoder(
this_list,
input_num_channels=encoder_dim[max(0, mid-i-1)],
downsample=2 if i != mid else 1
)
self.encoder = encoder
self.layer_indexes = layer_indexes
self.structure = structure
self.encoders = nn.ModuleList(encoders)
self.downsamplers = nn.ModuleList(downsamplers)
self.bypasses = nn.ModuleList(bypasses)
self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim,
dropout_rate=0.15,
length_factor=1.0)
#self.downsample_output = SimpleDownsample(max(encoder_dim),
# downsample=output_downsampling_factor,
# dropout=dropout)
def forward(
self,
@@ -239,7 +237,6 @@ class Subformer(EncoderInterface):
"""
outputs = []
attn_offset = self._get_attn_offset(x, src_key_padding_mask)
if self.training and memory is not None:
batch_size = x.shape[1]
@@ -249,14 +246,42 @@ class Subformer(EncoderInterface):
memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
memory_dropout_rate)
pos_emb = self.encoder_pos(x)
attn_offsets = [ self._get_attn_offset(x, src_key_padding_mask) ]
pos_embs = [ self.encoder_pos(x) ]
downsample_info = []
for s, i in zip(self.structure, self.layer_indexes):
if s == 'S':
encoder = self.encoders[i] # one encoder stack
x = encoder(x,
pos_embs[-1],
attn_offset=attn_offsets[-1],
memory=memory,
memory_key_padding_mask=memory_key_padding_mask)
# x will have the maximum dimension up till now, even if
# `encoder` uses lower dim in its layers.
elif s == '(':
downsampler = self.downsamplers[i]
indexes, weights, x_new = downsampler(x)
downsample_info.append((indexes, weights, x))
x = x_new
pos_embs.append(downsampler.downsample_pos_emb(pos_embs[-1], indexes))
attn_offsets.append(downsampler.downsample_attn_offset(attn_offsets[-1],
indexes,
weights))
else:
assert s == ')' # upsample
indexes, weights, x_orig = downsample_info.pop()
_attn_offset = attn_offsets.pop()
_pos_emb = pos_embs.pop()
x_orig = convert_num_channels(x_orig, x.shape[-1])
x = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights)
x = self.encoder(x,
pos_emb,
attn_offset=attn_offset,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
# d = self.output_downsampling_factor
# lengths = (x_lens + d - 1) // d
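
Distilled to its essentials, the traversal above follows a simple stack discipline. Here is a standalone sketch (illustrative only; the per-step operations are passed in as hypothetical callables):

    # Illustrative sketch of the structure-string walk in forward();
    # encode/down/up are stand-ins for the real encoder, downsampler
    # and upsampling operations.
    def walk_structure(structure, x, encode, down, up):
        stack = []  # state saved at each '(' for the matching ')'
        for s in structure:
            if s == 'S':
                x = encode(x)       # run one encoder stack
            elif s == '(':
                stack.append(x)     # save pre-downsampling activations
                x = down(x)         # shorten the sequence
            else:
                assert s == ')'
                x = up(stack.pop(), x)  # upsample, merging with saved state
        assert not stack, "unbalanced '(' in structure"
        return x
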
@@ -575,6 +600,9 @@ class SubformerEncoder(nn.Module):
Args:
encoder_layer: an instance of the SubformerEncoderLayer() class (required).
num_layers: the number of sub-encoder-layers in the encoder (required).
embed_dim: the embedding dimension to use for the bypass (this may exceed
the dimension of encoder_layer, since the layers need not operate
on the full dimension).
Examples::
>>> encoder_layer = SubformerEncoderLayer(embed_dim=512, nhead=8)
@@ -586,6 +614,7 @@ class SubformerEncoder(nn.Module):
self,
encoder_layer: nn.Module,
num_layers: int,
embed_dim: int,
dropout: float,
warmup_begin: float,
warmup_end: float,
@@ -602,7 +631,7 @@ class SubformerEncoder(nn.Module):
)
self.num_layers = num_layers
self.bypass = BypassModule(self.embed_dim())
self.bypass = BypassModule(embed_dim)
assert 0 <= warmup_begin <= warmup_end
@@ -616,7 +645,7 @@ class SubformerEncoder(nn.Module):
cur_begin = cur_end
def embed_dim(self):
return self.layers[0].embed_dim
return self.bypass.embed_dim()
def forward(
self,
@@ -644,14 +673,7 @@ class SubformerEncoder(nn.Module):
Returns: a Tensor with the same shape as src.
"""
src = convert_num_channels(src, self.embed_dim())
output = src
rnd_seed = src.numel() + random.randint(0, 1000)
#if feature_mask is not None:
# output = output * feature_mask
output = convert_num_channels(src, self.layers[0].embed_dim)
chunk_sizes, chunk_indexes = self._get_chunk_sizes(src)
b = src.shape[1] # batch_size
@@ -678,6 +700,9 @@ class SubformerEncoder(nn.Module):
output = self._to_chunk_size(output, src.shape[0])
output = convert_num_channels(output, self.bypass.embed_dim())
src = convert_num_channels(src, self.bypass.embed_dim())
return self.bypass(src, output)
def _get_chunk_sizes(self, src: Tensor) -> Tuple[List[int], List[int]]:
@@ -784,6 +809,8 @@ class BypassModule(nn.Module):
self.scale_min = copy.deepcopy(scale_min)
self.scale_max = copy.deepcopy(scale_max)
def embed_dim(self):
return self.bypass_scale.numel()
def _get_bypass_scale(self, batch_size: int):
# returns bypass-scale of shape (num_channels,),
@@ -840,7 +867,7 @@ class LearnedDownsamplingModule(nn.Module):
super().__init__()
self.to_scores = nn.Linear(embed_dim, 1, bias=False)
self.to_scores.lr_factor = 0.5
self.to_scores.lr_scale = 0.5
# score_balancer is just to keep the magnitudes of the scores in
# a fixed range and keep them balanced around zero, to stop
# these drifting around.
@@ -1028,7 +1055,9 @@ class LearnedDownsamplingModule(nn.Module):
return attn_offset
def upsample(self, x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
@staticmethod
def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor,
weights: Optional[Tensor] = None) -> Tensor:
"""
Upsamples, reversing the downsample() operation and filling in
any not-chosen frames with their original values from before downsampling.
@@ -1038,30 +1067,40 @@ class LearnedDownsamplingModule(nn.Module):
x_orig: (seq_len, batch_size, num_channels)
x: (seq_len_reduced, batch_size, num_channels)
indexes: integer indexes of shape (batch_size, seq_len_reduced), with elements
0 <= indexes < seq_len, containing the original frame indexes
weights: an optional tensor of shape (batch_size, seq_len_reduced) containing
weights between 0 and 1, where 1 means fully use this x value and 0 means
use x_orig; if None, this behaves as if all weights were 1
"""
(seq_len, batch_size, num_channels) = x_orig.shape
not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool,
device=x.device)
not_kept.scatter_(dim=1, index=indexes, value=False)
x_weight = 1.0 if weights is None else weights.t().unsqueeze(-1)
# x_weight: (seq_len_reduced, batch_size, 1) if a tensor
orig_x_weight = torch.ones(batch_size, seq_len,
device=x.device, dtype=x.dtype)
if weights is None:
orig_x_weight.scatter_(dim=1, index=indexes, value=0.)
else:
orig_x_weight.scatter_(dim=1, index=indexes,
src=(1. - weights).to(x.dtype))
indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
# indexes now: (seq_len_reduced, batch_size, num_channels)
ans = torch.zeros_like(x_orig)
ans.scatter_(dim=0, index=indexes, src=x)
ans.scatter_(dim=0, index=indexes, src=(x * x_weight))
# add in x_orig in the frames that were not originally kept.
return ans + x_orig * not_kept.t().unsqueeze(-1)
return ans + x_orig * orig_x_weight.t().unsqueeze(-1)
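
A tiny usage sketch consistent with the new static signature, checking the weighted blend by hand (all values made up for illustration):

    import torch

    # seq_len=4, batch_size=1, num_channels=1; frames 0 and 2 were kept.
    x_orig = torch.arange(4.).reshape(4, 1, 1)   # original frames 0., 1., 2., 3.
    x = torch.tensor([[[10.]], [[30.]]])         # processed kept frames
    indexes = torch.tensor([[0, 2]])             # (batch_size, seq_len_reduced)
    weights = torch.tensor([[1.0, 0.5]])         # keep frame 0 fully, blend frame 2

    out = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights)
    # out.squeeze() is [10., 1., 16., 3.]: frame 0 -> 1.0*10, frame 1 -> original
    # 1.0, frame 2 -> 0.5*30 + 0.5*2 = 16.0, frame 3 -> original 3.0.
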
class DownsampledSubformerEncoder(nn.Module):
@@ -1151,7 +1190,7 @@ class DownsampledSubformerEncoder(nn.Module):
src_orig = convert_num_channels(src_orig, src.shape[-1])
if hasattr(self, 'downsampler'):
src = self.downsampler.upsample(src_orig, src, indexes)
src = self.downsampler.upsample(src_orig, src, indexes, weights)
return self.out_combiner(src_orig, src)

View File

@@ -118,6 +118,7 @@ def set_batch_count(
def add_model_arguments(parser: argparse.ArgumentParser):
parser.add_argument(
"--num-encoder-layers",
type=str,
@@ -147,13 +148,21 @@ def add_model_arguments(parser: argparse.ArgumentParser):
)
parser.add_argument(
"--encoder-chunk-size",
"--encoder-chunk-sizes",
type=str,
default="128",
default="128,1024",
help="Base chunk size for attention in encoder stacks; alternate layers will use this value or "
"double this value."
)
parser.add_argument(
"--encoder-structure",
type=str,
default="S(S(S(S)S)S)S",
help="Structure of encoder, determines order of encoder stacks and (downsampling/upsampling) "
"operations."
)
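
As a sanity check on such structure strings, one might verify balance and stack count along these lines (a hypothetical helper, not part of this patch):

    def check_structure(structure: str, num_stacks: int) -> None:
        # Parentheses must balance, and the number of 'S' characters must
        # match the number of per-stack options (encoder_dim, etc.).
        depth = 0
        for s in structure:
            assert s in "S()", f"unexpected character {s!r} in structure"
            if s == '(':
                depth += 1
            elif s == ')':
                depth -= 1
                assert depth >= 0, "unmatched ')' in structure"
        assert depth == 0, "unclosed '(' in structure"
        assert structure.count('S') == num_stacks

    check_structure("S(S(S(S)S)S)S", num_stacks=7)  # the new default structure
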
parser.add_argument(
"--query-head-dim",
type=str,
@@ -421,9 +430,10 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
def get_encoder_model(params: AttributeDict) -> nn.Module:
#chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
encoder = Subformer(
structure=params.encoder_structure,
num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
encoder_dim=_to_int_tuple(params.encoder_dim),
encoder_chunk_size=_to_int_tuple(params.encoder_chunk_size),
encoder_chunk_sizes=(_to_int_tuple(params.encoder_chunk_sizes),),
query_head_dim=_to_int_tuple(params.query_head_dim),
pos_dim=int(params.pos_dim),
value_head_dim=_to_int_tuple(params.value_head_dim),