mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
More partial work
This commit is contained in:
parent
e51a2c9170
commit
8483ca2e8f
@ -57,6 +57,10 @@ class Subformer(EncoderInterface):
|
||||
the whole stack to downsample.)
|
||||
encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
|
||||
encoder stack (i.e. one per "S" in structure).
|
||||
encoder_chunk_sizes (Tuple[Tuple[int]]): A tuple containing either one tuple or
|
||||
one tuple per encoder stack. Each element tuple is a list of the chunk sizes
|
||||
that we use during training, e.g. (128, 1024); we go through these round-robin
|
||||
in successive layers.
|
||||
downsampling_factor (Tuple[int]): downsampling factor for each downsampling
|
||||
operation (each open-parenthesis).
|
||||
num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
|
||||
@ -85,7 +89,7 @@ class Subformer(EncoderInterface):
|
||||
structure: str = "S(S)S",
|
||||
encoder_dim: Tuple[int, ...] = (384, 512, 384),
|
||||
downsampling_factor: Tuple[int, ...] = (2,),
|
||||
encoder_chunk_sizes: Tuple[Tuple[int, ...]] = (128,),
|
||||
encoder_chunk_sizes: Tuple[Tuple[int, ...]] = ((128,1024),),
|
||||
num_encoder_layers: Union[int, Tuple[int, ...]] = (4,),
|
||||
query_head_dim: Tuple[int, ...] = (24,),
|
||||
value_head_dim: Tuple[int, ...] = (12,),
|
||||
@ -120,7 +124,7 @@ class Subformer(EncoderInterface):
|
||||
return x
|
||||
|
||||
self.encoder_dim = encoder_dim
|
||||
encoder_chunk_size = _to_tuple(encoder_chunk_size)
|
||||
encoder_chunk_sizes = _to_tuple(encoder_chunk_sizes)
|
||||
num_encoder_layers = _to_tuple(num_encoder_layers)
|
||||
query_head_dim = _to_tuple(query_head_dim)
|
||||
value_head_dim = _to_tuple(value_head_dim)
|
||||
@ -136,7 +140,17 @@ class Subformer(EncoderInterface):
|
||||
# each one will be SubformerEncoder or DownsampledSubformerEncoder
|
||||
encoders = []
|
||||
downsamplers = []
|
||||
bypass = []
|
||||
bypasses = []
|
||||
|
||||
layer_indexes = []
|
||||
|
||||
cur_max_dim = encoder_dim[0]
|
||||
|
||||
downsampling_factors_list = []
|
||||
def cur_downsampling_factor():
|
||||
c = 1
|
||||
for d in downsampling_factors_list: c *= d
|
||||
return c
|
||||
|
||||
for s in structure:
|
||||
if s == 'S':
|
||||
@ -152,61 +166,45 @@ class Subformer(EncoderInterface):
|
||||
dropout=dropout,
|
||||
causal=causal,
|
||||
)
|
||||
|
||||
cur_max_dim = max(cur_max_dim, encoder_dim[i])
|
||||
encoder = SubformerEncoder(
|
||||
encoder_layer,
|
||||
num_encoder_layers[i],
|
||||
embed_dim=cur_max_dim,
|
||||
dropout=dropout,
|
||||
chunk_size=encoder_chunk_size[i],
|
||||
chunk_sizes=encoder_chunk_sizes[i],
|
||||
warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
|
||||
warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
|
||||
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
|
||||
final_layerdrop_rate=0.035 * (cur_downsampling_factor() ** 0.5),
|
||||
)
|
||||
layer_indexes.append(len(encoders))
|
||||
encoders.append(encoder)
|
||||
|
||||
pass
|
||||
elif s =='(':
|
||||
pass
|
||||
i = len(downsamplers)
|
||||
downsampler = LearnedDownsamplingModule(cur_max_dim,
|
||||
downsampling_factor[i])
|
||||
downsampling_factors_list.append(downsampling_factor[i])
|
||||
layer_indexes.append(len(downsamplers))
|
||||
downsamplers.append(downsampler)
|
||||
else:
|
||||
assert s == ')'
|
||||
bypass = BypassModule(cur_max_dim, straight_through_rate=0.0)
|
||||
layer_indexes.append(len(bypasses))
|
||||
bypasses.append(bypass)
|
||||
downsampling_factors_list.pop()
|
||||
|
||||
logging.info(f"cur_downsampling_factor={cur_downsampling_factor()}")
|
||||
|
||||
num_encoders = len(encoder_dim)
|
||||
assert num_encoders % 2 == 1
|
||||
downsampling_factor = [ 1 ]
|
||||
while len(downsampling_factor) < num_encoders:
|
||||
downsampling_factor = [ 1 ] + [ d * 2 for d in downsampling_factor ] + [ 1 ]
|
||||
|
||||
for i in range(num_encoders):
|
||||
|
||||
|
||||
mid = len(encoders) // 2
|
||||
encoder = DownsampledSubformerEncoder(
|
||||
[ encoders[mid] ],
|
||||
input_num_channels=encoder_dim[mid-1],
|
||||
downsample=2
|
||||
)
|
||||
for i in range(1, mid+1):
|
||||
this_list = [ encoders[mid-i],
|
||||
encoder,
|
||||
encoders[mid+i] ]
|
||||
encoder = DownsampledSubformerEncoder(
|
||||
this_list,
|
||||
input_num_channels=encoder_dim[max(0, mid-i-1)],
|
||||
downsample=2 if i != mid else 1
|
||||
)
|
||||
|
||||
self.encoder = encoder
|
||||
self.layer_indexes = layer_indexes
|
||||
self.structure = structure
|
||||
self.encoders = nn.ModuleList(encoders)
|
||||
self.downsamplers = nn.ModuleList(downsamplers)
|
||||
self.bypasses = nn.ModuleList(bypasses)
|
||||
|
||||
self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim,
|
||||
dropout_rate=0.15,
|
||||
length_factor=1.0)
|
||||
|
||||
#self.downsample_output = SimpleDownsample(max(encoder_dim),
|
||||
# downsample=output_downsampling_factor,
|
||||
# dropout=dropout)
|
||||
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -239,7 +237,6 @@ class Subformer(EncoderInterface):
|
||||
"""
|
||||
outputs = []
|
||||
|
||||
attn_offset = self._get_attn_offset(x, src_key_padding_mask)
|
||||
|
||||
if self.training and memory is not None:
|
||||
batch_size = x.shape[1]
|
||||
@ -249,14 +246,42 @@ class Subformer(EncoderInterface):
|
||||
memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
|
||||
memory_dropout_rate)
|
||||
|
||||
pos_emb = self.encoder_pos(x)
|
||||
attn_offsets = [ self._get_attn_offset(x, src_key_padding_mask) ]
|
||||
pos_embs = [ self.encoder_pos(x) ]
|
||||
downsample_info = []
|
||||
|
||||
for s, i in zip(self.structure, self.layer_indexes):
|
||||
if s == 'S':
|
||||
encoder = self.encoders[i] # one encoder stack
|
||||
x = encoder(x,
|
||||
pos_embs[-1],
|
||||
attn_offset=attn_offsets[-1],
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask)
|
||||
# x will have the maximum dimension up till now, even if
|
||||
# `encoder` uses lower dim in its layers.
|
||||
elif s == '(':
|
||||
downsampler = self.downsamplers[i]
|
||||
|
||||
indexes, weights, x_new = downsampler(x)
|
||||
downsample_info.append((indexes, weights, x))
|
||||
x = x_new
|
||||
|
||||
pos_embs.append(downsampler.downsample_pos_emb(pos_embs[-1], indexes))
|
||||
|
||||
attn_offsets.append(downsampler.downsample_attn_offset(attn_offsets[-1],
|
||||
indexes,
|
||||
weights))
|
||||
|
||||
else:
|
||||
assert s == ')' # upsample
|
||||
indexes, weights, x_orig = downsample_info.pop()
|
||||
_attn_offset = attn_offsets.pop()
|
||||
_pos_emb = pos_embs.pop()
|
||||
x_orig = convert_num_channels(x_orig, x.shape[-1])
|
||||
|
||||
x = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights)
|
||||
|
||||
x = self.encoder(x,
|
||||
pos_emb,
|
||||
attn_offset=attn_offset,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
)
|
||||
|
||||
# d = self.output_downsampling_factor
|
||||
# lengths = (x_lens + d - 1) // d
|
||||
@ -575,6 +600,9 @@ class SubformerEncoder(nn.Module):
|
||||
Args:
|
||||
encoder_layer: an instance of the SubformerEncoderLayer() class (required).
|
||||
num_layers: the number of sub-encoder-layers in the encoder (required).
|
||||
embed_dim: the embedding dimension to use for the bypass (may exceed the
|
||||
dimension of encoder_layer, as it may not operate on the full
|
||||
dimension).
|
||||
|
||||
Examples::
|
||||
>>> encoder_layer = SubformerEncoderLayer(embed_dim=512, nhead=8)
|
||||
@ -586,6 +614,7 @@ class SubformerEncoder(nn.Module):
|
||||
self,
|
||||
encoder_layer: nn.Module,
|
||||
num_layers: int,
|
||||
embed_dim: int,
|
||||
dropout: float,
|
||||
warmup_begin: float,
|
||||
warmup_end: float,
|
||||
@ -602,7 +631,7 @@ class SubformerEncoder(nn.Module):
|
||||
)
|
||||
self.num_layers = num_layers
|
||||
|
||||
self.bypass = BypassModule(self.embed_dim())
|
||||
self.bypass = BypassModule(embed_dim)
|
||||
|
||||
assert 0 <= warmup_begin <= warmup_end
|
||||
|
||||
@ -616,7 +645,7 @@ class SubformerEncoder(nn.Module):
|
||||
cur_begin = cur_end
|
||||
|
||||
def embed_dim(self):
|
||||
return self.layers[0].embed_dim
|
||||
return self.bypass.embed_dim()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -644,14 +673,7 @@ class SubformerEncoder(nn.Module):
|
||||
|
||||
Returns: a Tensor with the same shape as src.
|
||||
"""
|
||||
src = convert_num_channels(src, self.embed_dim())
|
||||
output = src
|
||||
|
||||
rnd_seed = src.numel() + random.randint(0, 1000)
|
||||
|
||||
#if feature_mask is not None:
|
||||
# output = output * feature_mask
|
||||
|
||||
output = convert_num_channels(src, self.layers[0].embed_dim)
|
||||
|
||||
chunk_sizes, chunk_indexes = self._get_chunk_sizes(src)
|
||||
b = src.shape[1] # batch_size
|
||||
@ -678,6 +700,9 @@ class SubformerEncoder(nn.Module):
|
||||
|
||||
output = self._to_chunk_size(output, src.shape[0])
|
||||
|
||||
output = convert_num_channels(output, self.bypass.embed_dim())
|
||||
src = convert_num_channels(src, self.bypass.embed_dim())
|
||||
|
||||
return self.bypass(src, output)
|
||||
|
||||
def _get_chunk_sizes(self, src: Tensor) -> Tuple[List[int], List[int]]:
|
||||
@ -784,6 +809,8 @@ class BypassModule(nn.Module):
|
||||
self.scale_min = copy.deepcopy(scale_min)
|
||||
self.scale_max = copy.deepcopy(scale_max)
|
||||
|
||||
def embed_dim(self):
|
||||
return self.bypass_scale.numel()
|
||||
|
||||
def _get_bypass_scale(self, batch_size: int):
|
||||
# returns bypass-scale of shape (num_channels,),
|
||||
@ -840,7 +867,7 @@ class LearnedDownsamplingModule(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.to_scores = nn.Linear(embed_dim, 1, bias=False)
|
||||
self.to_scores.lr_factor = 0.5
|
||||
self.to_scores.lr_scale = 0.5
|
||||
# score_balancer is just to keep the magnitudes of the scores in
|
||||
# a fixed range and keep them balanced around zero, to stop
|
||||
# these drifting around.
|
||||
@ -1028,7 +1055,9 @@ class LearnedDownsamplingModule(nn.Module):
|
||||
return attn_offset
|
||||
|
||||
|
||||
def upsample(self, x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
|
||||
@staticmethod
|
||||
def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor,
|
||||
weights: Optional[Tensor] = None) -> Tensor:
|
||||
"""
|
||||
Upsamples, reversing the downsample() operation and filling in
|
||||
any not-chosen frames with their original value before downsampling
|
||||
@ -1038,30 +1067,40 @@ class LearnedDownsamplingModule(nn.Module):
|
||||
x_orig: (seq_len, batch_size, num_channels)
|
||||
x: (seq_len_reduced, batch_size, num_channels)
|
||||
indexes: (batch_size, seq_len_reduced), contains original frame indexes
|
||||
weights: optional tensor
|
||||
|
||||
Downsamples x via indexing with the indexes obtained from the
|
||||
forward() function.
|
||||
|
||||
Args:
|
||||
x: tensor of shape (seq_len, batch_size, indexes)
|
||||
x: tensor of shape (seq_len, batch_size, indexes)
|
||||
weights: a tensor of shape (batch_size, seq_len_reduced) containing weights between
|
||||
0 and 1, where 1 means fully use this x value and 0 means use x_orig
|
||||
indexes: integer indexes of shape (batch_size, seq_len_reduced), with elements
|
||||
0 <= indexes < seq_len.
|
||||
"""
|
||||
(seq_len, batch_size, num_channels) = x_orig.shape
|
||||
|
||||
not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool,
|
||||
device=x.device)
|
||||
not_kept.scatter_(dim=1, index=indexes, value=False)
|
||||
x_weight = 1.0 if weights is None else weights.t().unsqueeze(-1)
|
||||
# x_weight: (seq_len_reduced, batch_size, 1) if a tensor
|
||||
|
||||
orig_x_weight = torch.ones(batch_size, seq_len,
|
||||
device=x.device, dtype=x.dtype)
|
||||
if weights is None:
|
||||
orig_x_weight.scatter_(dim=1, index=indexes, value=0.)
|
||||
else:
|
||||
orig_x_weight.scatter_(dim=1, index=indexes,
|
||||
src=(1. - weights).to(x.dtype))
|
||||
|
||||
indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
|
||||
# indexes now: (seq_len_reduced, batch_size, num_channels)
|
||||
|
||||
ans = torch.zeros_like(x_orig)
|
||||
|
||||
ans.scatter_(dim=0, index=indexes, src=x)
|
||||
ans.scatter_(dim=0, index=indexes, src=(x * x_weight))
|
||||
|
||||
# add in x_orig in the frames that were not originally kept.
|
||||
return ans + x_orig * not_kept.t().unsqueeze(-1)
|
||||
return ans + x_orig * orig_x_weight.t().unsqueeze(-1)
|
||||
|
||||
|
||||
class DownsampledSubformerEncoder(nn.Module):
|
||||
@ -1151,7 +1190,7 @@ class DownsampledSubformerEncoder(nn.Module):
|
||||
src_orig = convert_num_channels(src_orig, src.shape[-1])
|
||||
|
||||
if hasattr(self, 'downsampler'):
|
||||
src = self.downsampler.upsample(src_orig, src, indexes)
|
||||
src = self.downsampler.upsample(src_orig, src, indexes, weights)
|
||||
|
||||
return self.out_combiner(src_orig, src)
|
||||
|
||||
|
||||
@ -118,6 +118,7 @@ def set_batch_count(
|
||||
|
||||
|
||||
def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
|
||||
parser.add_argument(
|
||||
"--num-encoder-layers",
|
||||
type=str,
|
||||
@ -147,13 +148,21 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-chunk-size",
|
||||
"--encoder-chunk-sizes",
|
||||
type=str,
|
||||
default="128",
|
||||
default="128,1024",
|
||||
help="Base chunk size for attention in encoder stacks; alternate layers will use this value or "
|
||||
"double this value."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-structure",
|
||||
type=str,
|
||||
default="S(S(S(S)S)S)S",
|
||||
help="Structure of encoder, determines order of encoder stacks and (downsampling/upsampling) "
|
||||
"operations."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--query-head-dim",
|
||||
type=str,
|
||||
@ -421,9 +430,10 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
|
||||
def get_encoder_model(params: AttributeDict) -> nn.Module:
|
||||
#chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
|
||||
encoder = Subformer(
|
||||
structure=params.encoder_structure,
|
||||
num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
|
||||
encoder_dim=_to_int_tuple(params.encoder_dim),
|
||||
encoder_chunk_size=_to_int_tuple(params.encoder_chunk_size),
|
||||
encoder_chunk_sizes=(_to_int_tuple(params.encoder_chunk_sizes),),
|
||||
query_head_dim=_to_int_tuple(params.query_head_dim),
|
||||
pos_dim=int(params.pos_dim),
|
||||
value_head_dim=_to_int_tuple(params.value_head_dim),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user