Code cleanup

This commit is contained in:
Daniel Povey 2023-05-15 22:01:19 +08:00
parent 671e9ee5bd
commit 0a76215fd7
2 changed files with 92 additions and 161 deletions

View File

@ -58,9 +58,6 @@ class Subformer(EncoderInterface):
encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
encoder stack.
num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
the encoder stacks for purposes of per-frame dropout (recommend 256 for
now).
query_head_dim (int or Tuple[int]): dimension of query and key per attention
head: per stack, if a tuple..
value_head_dim (int or Tuple[int]): dimension of value in each attention head
@ -83,11 +80,8 @@ class Subformer(EncoderInterface):
"""
def __init__(
self,
output_downsampling_factor: int = 2,
downsampling_factor: Tuple[int] = (2, 4),
encoder_dim: Union[int, Tuple[int]] = 384,
encoder_dim: Union[int, Tuple[int]] = (384, 512, 384),
num_encoder_layers: Union[int, Tuple[int]] = 4,
encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
query_head_dim: Union[int, Tuple[int]] = 24,
value_head_dim: Union[int, Tuple[int]] = 12,
num_heads: Union[int, Tuple[int]] = 8,
@ -106,19 +100,16 @@ class Subformer(EncoderInterface):
def _to_tuple(x):
""" Converts a single int or a 1-tuple of an int to a tuple with the same length
as downsampling_factor"""
as encoder_dim"""
if isinstance(x, int):
x = (x,)
if len(x) == 1:
x = x * len(downsampling_factor)
x = x * len(encoder_dim)
else:
assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
assert len(x) == len(encoder_dim) and isinstance(x[0], int)
return x
self.output_downsampling_factor = output_downsampling_factor # int
self.downsampling_factor = downsampling_factor # tuple
self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple
self.encoder_dim = encoder_dim
num_encoder_layers = _to_tuple(num_encoder_layers)
query_head_dim = _to_tuple(query_head_dim)
value_head_dim = _to_tuple(value_head_dim)
@ -126,13 +117,17 @@ class Subformer(EncoderInterface):
feedforward_dim = _to_tuple(feedforward_dim)
self.causal = causal
for u,d in zip(encoder_unmasked_dim, encoder_dim):
assert u <= d
# each one will be SubformerEncoder or DownsampledSubformerEncoder
encoders = []
num_encoders = len(downsampling_factor)
num_encoders = len(encoder_dim)
assert num_encoders % 2 == 1
downsampling_factor = [ 1 ]
while len(downsampling_factor) < num_encoders:
downsampling_factor = [ 1 ] + [ d * 2 for d in downsampling_factor ] + [ 1 ]
for i in range(num_encoders):
encoder_layer = SubformerEncoderLayer(
@ -158,83 +153,36 @@ class Subformer(EncoderInterface):
final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
)
if downsampling_factor[i] != 1:
encoders.append(encoder)
mid = len(encoders) // 2
encoder = DownsampledSubformerEncoder(
encoder,
dim=encoder_dim[i],
downsample=downsampling_factor[i],
dropout=dropout,
[ encoders[mid] ],
input_num_channels=encoder_dim[mid],
downsample=2
)
encoders.append(encoder)
encoder = encoders[mid]
for i in range(mid-1, -1, -1):
this_list = [ encoders[mid-i],
encoder,
encoders[mid+i] ]
encoder = DownsampledSubformerEncoder(
this_list,
input_num_channels=encoder_dim[max(0, mid-2)],
downsample=2,
)
self.encoder = encoder
self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim,
dropout_rate=0.15,
length_factor=1.0)
self.encoders = nn.ModuleList(encoders)
#self.downsample_output = SimpleDownsample(max(encoder_dim),
# downsample=output_downsampling_factor,
# dropout=dropout)
def get_feature_masks(
self,
x: torch.Tensor) -> List[Union[float, Tensor]]:
"""
In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
randomized feature masks, one per encoder.
On e.g. 15% of frames, these masks will zero out all enocder dims larger than
some supplied number, e.g. >256, so in effect on those frames we are using
a smaller encoer dim.
We generate the random masks at this level because we want the 2 masks to 'agree'
all the way up the encoder stack. This will mean that the 1st mask will have
mask values repeated self.zipformer_subsampling_factor times.
Args:
x: the embeddings (needed for the shape and dtype and device), of shape
(1, batch_size, encoder_dims0)
"""
num_encoders = len(self.encoder_dim)
if not self.training:
return [ 1.0 ] * num_encoders
(num_frames0, batch_size, _encoder_dims0) = x.shape
assert self.encoder_dim[0] == _encoder_dims0
feature_mask_dropout_prob = 0.125
# mask1 shape: (1, batch_size, 1)
mask1 = (torch.rand(1, batch_size, 1,
device=x.device) >
feature_mask_dropout_prob).to(x.dtype)
# mask2 has additional sequences masked, about twice the number.
mask2 = torch.logical_and(mask1,
(torch.rand(1, batch_size, 1,
device=x.device) >
feature_mask_dropout_prob).to(x.dtype))
# dim: (1, batch_size, 2)
mask = torch.cat((mask1, mask2), dim=-1)
feature_masks = []
for i in range(num_encoders):
channels = self.encoder_dim[i]
feature_mask = torch.ones(1, batch_size, channels,
dtype=x.dtype, device=x.device)
u1 = self.encoder_unmasked_dim[i]
u2 = u1 + (channels - u1) // 2
feature_mask[:, :, u1:u2] *= mask[..., 0:1]
feature_mask[:, :, u2:] *= mask[..., 1:2]
feature_masks.append(feature_mask)
return feature_masks
def forward(
@ -267,7 +215,6 @@ class Subformer(EncoderInterface):
of frames in `embeddings` before padding.
"""
outputs = []
feature_masks = self.get_feature_masks(x)
attn_offset = self._get_attn_offset(x, src_key_padding_mask)
@ -281,44 +228,17 @@ class Subformer(EncoderInterface):
pos_emb = self.encoder_pos(x)
for i, module in enumerate(self.encoders):
x = convert_num_channels(x, self.encoder_dim[i])
x = module(x,
x = self.encoder(x,
pos_emb,
feature_mask=feature_masks[i],
attn_offset=attn_offset,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
outputs.append(x)
def get_full_dim_output():
num_encoders = len(self.encoder_dim)
assert len(outputs) == num_encoders
output_dim = max(self.encoder_dim)
output_pieces = [ outputs[-1] ]
cur_dim = self.encoder_dim[-1]
for i in range(num_encoders - 2, -1, -1):
d = self.encoder_dim[i]
if d > cur_dim:
this_output = outputs[i]
output_pieces.append(this_output[..., cur_dim:d])
cur_dim = d
assert cur_dim == output_dim
return torch.cat(output_pieces, dim=-1)
# d = self.output_downsampling_factor
# lengths = (x_lens + d - 1) // d
# if the last output has the largest dimension, x will be unchanged,
# it will be the same as outputs[-1]. Otherwise it will be concatenated
# from different pieces of 'outputs', taking each dimension from the
# most recent output that has it present.
x = get_full_dim_output()
#x = self.downsample_output(x)
d = self.output_downsampling_factor
lengths = (x_lens + d - 1) // d
return x, lengths
return x, x_lens
def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
"""
@ -507,7 +427,6 @@ class SubformerEncoderLayer(nn.Module):
min_abs=0.1, max_abs=4.0,
)
def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
return None
@ -650,11 +569,14 @@ class SubformerEncoder(nn.Module):
final_layerdrop_rate: float = 0.05,
) -> None:
super().__init__()
self.layers = nn.ModuleList(
[copy.deepcopy(encoder_layer) for i in range(num_layers)]
)
self.num_layers = num_layers
self.bypass = BypassModule(self.embed_dim())
assert 0 <= warmup_begin <= warmup_end
delta = (1. / num_layers) * (warmup_end - warmup_begin)
@ -666,11 +588,14 @@ class SubformerEncoder(nn.Module):
default=0.0)
cur_begin = cur_end
def embed_dim(self):
return self.layers[0].embed_dim
def forward(
self,
src: Tensor,
pos_emb: Tensor,
feature_mask: Union[Tensor, float] = 1.0,
feature_mask: Optional[Tensor] = None,
attn_offset: Optional[Tensor] = None,
memory: Optional[Tensor] = None,
memory_key_padding_mask: Optional[Tensor] = None,
@ -693,10 +618,12 @@ class SubformerEncoder(nn.Module):
Returns: a Tensor with the same shape as src.
"""
src = convert_num_channels(src, self.embed_dim())
output = src
rnd_seed = src.numel() + random.randint(0, 1000)
if feature_mask is not None:
output = output * feature_mask
for i, mod in enumerate(self.layers):
@ -708,9 +635,10 @@ class SubformerEncoder(nn.Module):
memory_key_padding_mask=memory_key_padding_mask,
)
if feature_mask is not None:
output = output * feature_mask
return output
return self.bypass(src, output)
class BypassModule(nn.Module):
@ -1014,17 +942,17 @@ class DownsampledSubformerEncoder(nn.Module):
with the origin input, so that the output has the same shape as the input.
"""
def __init__(self,
encoder: nn.Module,
dim: int,
downsample: int,
dropout: FloatLike):
encoders: List[nn.Module],
input_num_channels: int,
downsample: int):
super(DownsampledSubformerEncoder, self).__init__()
self.downsample_factor = downsample
self.downsampler = LearnedDownsamplingModule(dim,
self.downsampler = LearnedDownsamplingModule(input_num_channels,
downsample)
self.encoder = encoder
self.encoders = nn.ModuleList(encoders)
self.out_combiner = BypassModule(dim, straight_through_rate=0.025)
self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders),
straight_through_rate=0.0)
def forward(self,
src: Tensor,
@ -1059,15 +987,35 @@ class DownsampledSubformerEncoder(nn.Module):
attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
indexes,
weights)
outputs = [ src ]
src = self.encoder(
for encoder in self.encoders:
src = encoder(
src,
pos_emb,
feature_mask=feature_mask,
attn_offset=attn_offset,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
)
outputs.append(src)
def get_full_dim_output():
num_encoders = len(outputs)
output_dim = max(o.shape[-1] for o in outputs)
output_pieces = [ outputs[-1] ]
cur_dim = outputs[-1].shape[-1]
for i in range(num_encoders - 2, -1, -1):
d = outputs[i].shape[-1]
if d > cur_dim:
this_output = outputs[i]
output_pieces.append(this_output[..., cur_dim:d])
cur_dim = d
assert cur_dim == output_dim
return torch.cat(output_pieces, dim=-1)
src = get_full_dim_output()
src_orig = convert_num_channels(src_orig, src.shape[-1])
src = self.downsampler.upsample(src_orig, src, indexes)
return self.out_combiner(src_orig, src)
@ -1793,7 +1741,8 @@ def _test_zipformer_main(causal: bool = False):
memory_dim = 100
c = Subformer(
encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
encoder_dim=(64, 96, 64),
num_heads=(4, 4, 8),
causal=causal,
memory_dim=memory_dim,
)

View File

@ -125,15 +125,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Number of subformer encoder layers per stack, comma separated.",
)
parser.add_argument(
"--downsampling-factor",
type=str,
default="1,2,4,8,4,2,1",
help="Downsampling factor for each stack of encoder layers.",
)
parser.add_argument(
"--feedforward-dim",
type=str,
@ -176,13 +167,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="Positional-encoding dimension in encoder stacks: a single int or comma-separated list."
)
parser.add_argument(
"--encoder-unmasked-dim",
type=str,
default="192,192,256,256,256,192,192",
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
"A single int or comma-separated list. Must be <= each corresponding encoder_dim."
)
def get_parser():
@ -478,11 +463,8 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
def get_encoder_model(params: AttributeDict) -> nn.Module:
#chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
encoder = Subformer(
#output_downsampling_factor=chunk_size,
downsampling_factor=_to_int_tuple(params.downsampling_factor),
num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
encoder_dim=_to_int_tuple(params.encoder_dim),
encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
query_head_dim=_to_int_tuple(params.query_head_dim),
pos_dim=int(params.pos_dim),
value_head_dim=_to_int_tuple(params.value_head_dim),