Code cleanup

Daniel Povey 2023-05-15 22:01:19 +08:00
parent 671e9ee5bd
commit 0a76215fd7
2 changed files with 92 additions and 161 deletions

View File

@@ -58,9 +58,6 @@ class Subformer(EncoderInterface):
         encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
             encoder stack.
         num_encoder_layers (int or Tuple[int])): number of encoder layers for each stack
-        encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
-            the encoder stacks for purposes of per-frame dropout (recommend 256 for
-            now).
         query_head_dim (int or Tuple[int]): dimension of query and key per attention
             head: per stack, if a tuple..
         value_head_dim (int or Tuple[int]): dimension of value in each attention head
@@ -83,11 +80,8 @@ class Subformer(EncoderInterface):
     """
     def __init__(
         self,
-        output_downsampling_factor: int = 2,
-        downsampling_factor: Tuple[int] = (2, 4),
-        encoder_dim: Union[int, Tuple[int]] = 384,
+        encoder_dim: Union[int, Tuple[int]] = (384, 512, 384),
         num_encoder_layers: Union[int, Tuple[int]] = 4,
-        encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
         query_head_dim: Union[int, Tuple[int]] = 24,
         value_head_dim: Union[int, Tuple[int]] = 12,
         num_heads: Union[int, Tuple[int]] = 8,
@@ -106,19 +100,16 @@ class Subformer(EncoderInterface):
         def _to_tuple(x):
             """ Converts a single int or a 1-tuple of an int to a tuple with the same length
-            as downsampling_factor"""
+            as encoder_dim"""
             if isinstance(x, int):
                 x = (x,)
             if len(x) == 1:
-                x = x * len(downsampling_factor)
+                x = x * len(encoder_dim)
             else:
-                assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
+                assert len(x) == len(encoder_dim) and isinstance(x[0], int)
             return x

-        self.output_downsampling_factor = output_downsampling_factor # int
-        self.downsampling_factor = downsampling_factor # tuple
-        self.encoder_dim = encoder_dim = _to_tuple(encoder_dim) # tuple
-        self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim) # tuple
+        self.encoder_dim = encoder_dim
         num_encoder_layers = _to_tuple(num_encoder_layers)
         query_head_dim = _to_tuple(query_head_dim)
         value_head_dim = _to_tuple(value_head_dim)
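Note on the hunk above: per-stack options are still broadcast with `_to_tuple`, but the broadcast length is now taken from `encoder_dim` rather than the removed `downsampling_factor`. A standalone sketch of that broadcasting (plain Python, not the committed helper):

```python
# Sketch: broadcast an int or 1-tuple to one value per encoder stack,
# keyed on the length of encoder_dim.
encoder_dim = (384, 512, 384)

def to_tuple(x):
    if isinstance(x, int):
        x = (x,)
    if len(x) == 1:
        x = x * len(encoder_dim)
    else:
        assert len(x) == len(encoder_dim) and isinstance(x[0], int)
    return x

print(to_tuple(24))         # (24, 24, 24)
print(to_tuple((4, 4, 8)))  # (4, 4, 8)
```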
@@ -126,13 +117,17 @@ class Subformer(EncoderInterface):
         feedforward_dim = _to_tuple(feedforward_dim)
         self.causal = causal

-        for u,d in zip(encoder_unmasked_dim, encoder_dim):
-            assert u <= d
-
         # each one will be SubformerEncoder or DownsampledSubformerEncoder
         encoders = []

-        num_encoders = len(downsampling_factor)
+        num_encoders = len(encoder_dim)
+        assert num_encoders % 2 == 1
+        downsampling_factor = [ 1 ]
+        while len(downsampling_factor) < num_encoders:
+            downsampling_factor = [ 1 ] + [ d * 2 for d in downsampling_factor ] + [ 1 ]
+
         for i in range(num_encoders):
             encoder_layer = SubformerEncoderLayer(
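The per-stack downsampling factors are no longer a constructor argument; they are derived from the number of stacks, which must be odd, and form a symmetric pyramid that doubles toward the middle stack. Running the loop from the hunk above on its own shows the pattern (for 7 stacks it reproduces the old `--downsampling-factor` default of `1,2,4,8,4,2,1`):

```python
# The downsampling_factor construction from the hunk above, run standalone.
for num_encoders in (1, 3, 5, 7):
    downsampling_factor = [1]
    while len(downsampling_factor) < num_encoders:
        downsampling_factor = [1] + [d * 2 for d in downsampling_factor] + [1]
    print(num_encoders, downsampling_factor)
# 1 [1]
# 3 [1, 2, 1]
# 5 [1, 2, 4, 2, 1]
# 7 [1, 2, 4, 8, 4, 2, 1]
```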
@@ -158,83 +153,36 @@ class Subformer(EncoderInterface):
                 final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
             )
-            if downsampling_factor[i] != 1:
-                encoder = DownsampledSubformerEncoder(
-                    encoder,
-                    dim=encoder_dim[i],
-                    downsample=downsampling_factor[i],
-                    dropout=dropout,
-                )
-            encoders.append(encoder)
+            encoders.append(encoder)
+
+        mid = len(encoders) // 2
+        encoder = DownsampledSubformerEncoder(
+            [ encoders[mid] ],
+            input_num_channels=encoder_dim[mid],
+            downsample=2
+        )
+        encoder = encoders[mid]
+        for i in range(mid-1, -1, -1):
+            this_list = [ encoders[mid-i],
+                          encoder,
+                          encoders[mid+i] ]
+            encoder = DownsampledSubformerEncoder(
+                this_list,
+                input_num_channels=encoder_dim[max(0, mid-2)],
+                downsample=2,
+            )
+
+        self.encoder = encoder

         self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim,
                                                         dropout_rate=0.15,
                                                         length_factor=1.0)

-        self.encoders = nn.ModuleList(encoders)
-
         #self.downsample_output = SimpleDownsample(max(encoder_dim),
         #                                          downsample=output_downsampling_factor,
         #                                          dropout=dropout)

-    def get_feature_masks(
-            self,
-            x: torch.Tensor) -> List[Union[float, Tensor]]:
-        """
-        In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
-        randomized feature masks, one per encoder.
-        On e.g. 15% of frames, these masks will zero out all enocder dims larger than
-        some supplied number, e.g. >256, so in effect on those frames we are using
-        a smaller encoer dim.
-        We generate the random masks at this level because we want the 2 masks to 'agree'
-        all the way up the encoder stack. This will mean that the 1st mask will have
-        mask values repeated self.zipformer_subsampling_factor times.
-        Args:
-           x: the embeddings (needed for the shape and dtype and device), of shape
-             (1, batch_size, encoder_dims0)
-        """
-        num_encoders = len(self.encoder_dim)
-        if not self.training:
-            return [ 1.0 ] * num_encoders
-
-        (num_frames0, batch_size, _encoder_dims0) = x.shape
-        assert self.encoder_dim[0] == _encoder_dims0
-
-        feature_mask_dropout_prob = 0.125
-
-        # mask1 shape: (1, batch_size, 1)
-        mask1 = (torch.rand(1, batch_size, 1,
-                            device=x.device) >
-                 feature_mask_dropout_prob).to(x.dtype)
-
-        # mask2 has additional sequences masked, about twice the number.
-        mask2 = torch.logical_and(mask1,
-                                  (torch.rand(1, batch_size, 1,
-                                              device=x.device) >
-                                   feature_mask_dropout_prob).to(x.dtype))
-
-        # dim: (1, batch_size, 2)
-        mask = torch.cat((mask1, mask2), dim=-1)
-
-        feature_masks = []
-        for i in range(num_encoders):
-            channels = self.encoder_dim[i]
-            feature_mask = torch.ones(1, batch_size, channels,
-                                      dtype=x.dtype, device=x.device)
-            u1 = self.encoder_unmasked_dim[i]
-            u2 = u1 + (channels - u1) // 2
-            feature_mask[:, :, u1:u2] *= mask[..., 0:1]
-            feature_mask[:, :, u2:] *= mask[..., 1:2]
-            feature_masks.append(feature_mask)
-
-        return feature_masks

     def forward(
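The constructor now builds one nested module instead of a flat `nn.ModuleList`: the middle stack sits innermost, and each surrounding pair of stacks is wrapped around it by a `DownsampledSubformerEncoder` with `downsample=2`, so nesting depth matches the downsampling pyramid above. A toy illustration of that shape only (plain lists stand in for the real modules, and the index arithmetic here is a simplified reading, not the committed code):

```python
# Toy sketch: 'DS2' marks a downsample-by-2 wrapper, strings stand in for
# the per-stack encoders. Illustrates the nesting shape for 5 stacks.
stacks = [f"enc{i}" for i in range(5)]
mid = len(stacks) // 2                 # 2

nested = ["DS2", stacks[mid]]          # innermost: the middle stack
for i in range(1, mid + 1):            # wrap one outer pair per level
    nested = ["DS2", stacks[mid - i], nested, stacks[mid + i]]

print(nested)
# ['DS2', 'enc0', ['DS2', 'enc1', ['DS2', 'enc2'], 'enc3'], 'enc4']
```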
@@ -267,7 +215,6 @@ class Subformer(EncoderInterface):
           of frames in `embeddings` before padding.
         """
         outputs = []
-        feature_masks = self.get_feature_masks(x)

         attn_offset = self._get_attn_offset(x, src_key_padding_mask)
@@ -281,44 +228,17 @@ class Subformer(EncoderInterface):
         pos_emb = self.encoder_pos(x)

-        for i, module in enumerate(self.encoders):
-            x = convert_num_channels(x, self.encoder_dim[i])
-            x = module(x,
-                       pos_emb,
-                       feature_mask=feature_masks[i],
-                       attn_offset=attn_offset,
-                       memory=memory,
-                       memory_key_padding_mask=memory_key_padding_mask,
-                       )
-            outputs.append(x)
-
-        def get_full_dim_output():
-            num_encoders = len(self.encoder_dim)
-            assert len(outputs) == num_encoders
-            output_dim = max(self.encoder_dim)
-            output_pieces = [ outputs[-1] ]
-            cur_dim = self.encoder_dim[-1]
-            for i in range(num_encoders - 2, -1, -1):
-                d = self.encoder_dim[i]
-                if d > cur_dim:
-                    this_output = outputs[i]
-                    output_pieces.append(this_output[..., cur_dim:d])
-                    cur_dim = d
-            assert cur_dim == output_dim
-            return torch.cat(output_pieces, dim=-1)
-
-        # if the last output has the largest dimension, x will be unchanged,
-        # it will be the same as outputs[-1]. Otherwise it will be concatenated
-        # from different pieces of 'outputs', taking each dimension from the
-        # most recent output that has it present.
-        x = get_full_dim_output()
-        #x = self.downsample_output(x)
-        d = self.output_downsampling_factor
-        lengths = (x_lens + d - 1) // d
-        return x, lengths
+        x = self.encoder(x,
+                         pos_emb,
+                         attn_offset=attn_offset,
+                         memory=memory,
+                         memory_key_padding_mask=memory_key_padding_mask,
+                         )
+
+        # d = self.output_downsampling_factor
+        # lengths = (x_lens + d - 1) // d
+        return x, x_lens

     def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
         """
@@ -507,7 +427,6 @@ class SubformerEncoderLayer(nn.Module):
             min_abs=0.1, max_abs=4.0,
         )
-
     def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
         if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
             return None
@@ -650,11 +569,14 @@ class SubformerEncoder(nn.Module):
             final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
         self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
         self.num_layers = num_layers

+        self.bypass = BypassModule(self.embed_dim())
+
         assert 0 <= warmup_begin <= warmup_end
         delta = (1. / num_layers) * (warmup_end - warmup_begin)
@@ -666,11 +588,14 @@ class SubformerEncoder(nn.Module):
                                          default=0.0)
             cur_begin = cur_end

+    def embed_dim(self):
+        return self.layers[0].embed_dim
+
     def forward(
         self,
         src: Tensor,
         pos_emb: Tensor,
-        feature_mask: Union[Tensor, float] = 1.0,
+        feature_mask: Optional[Tensor] = None,
         attn_offset: Optional[Tensor] = None,
         memory: Optional[Tensor] = None,
         memory_key_padding_mask: Optional[Tensor] = None,
@@ -693,10 +618,12 @@ class SubformerEncoder(nn.Module):
         Returns: a Tensor with the same shape as src.
         """
+        src = convert_num_channels(src, self.embed_dim())
         output = src

         rnd_seed = src.numel() + random.randint(0, 1000)

-        output = output * feature_mask
+        if feature_mask is not None:
+            output = output * feature_mask

         for i, mod in enumerate(self.layers):
@@ -708,9 +635,10 @@ class SubformerEncoder(nn.Module):
                 memory_key_padding_mask=memory_key_padding_mask,
             )

-            output = output * feature_mask
+            if feature_mask is not None:
+                output = output * feature_mask

-        return output
+        return self.bypass(src, output)


 class BypassModule(nn.Module):
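`SubformerEncoder.forward` now returns `self.bypass(src, output)` instead of `output`. In this code family, `BypassModule` (the next class below) combines a module's input and output with a learned per-channel scale, roughly `src + scale * (output - src)`; the sketch below assumes that behaviour rather than quoting the class:

```python
# Minimal sketch of a learned bypass combiner (assumed behaviour; the real
# BypassModule also constrains and schedules the scale).
import torch
import torch.nn as nn

class TinyBypass(nn.Module):
    def __init__(self, embed_dim: int, initial_scale: float = 0.5):
        super().__init__()
        self.scale = nn.Parameter(torch.full((embed_dim,), initial_scale))

    def forward(self, src: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
        # Per-channel interpolation between the block input and its output.
        return src + self.scale * (out - src)

x, y = torch.randn(10, 2, 8), torch.randn(10, 2, 8)  # (seq, batch, channels)
print(TinyBypass(8)(x, y).shape)                     # torch.Size([10, 2, 8])
```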
@@ -1014,17 +942,17 @@ class DownsampledSubformerEncoder(nn.Module):
    with the origin input, so that the output has the same shape as the input.
    """
    def __init__(self,
-                encoder: nn.Module,
-                dim: int,
-                downsample: int,
-                dropout: FloatLike):
+                encoders: List[nn.Module],
+                input_num_channels: int,
+                downsample: int):
        super(DownsampledSubformerEncoder, self).__init__()
        self.downsample_factor = downsample
-       self.downsampler = LearnedDownsamplingModule(dim,
+       self.downsampler = LearnedDownsamplingModule(input_num_channels,
                                                     downsample)
-       self.encoder = encoder
-       self.out_combiner = BypassModule(dim, straight_through_rate=0.025)
+       self.encoders = nn.ModuleList(encoders)
+       self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders),
+                                        straight_through_rate=0.0)

    def forward(self,
                src: Tensor,
@@ -1059,15 +987,35 @@ class DownsampledSubformerEncoder(nn.Module):
        attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
                                                              indexes,
                                                              weights)

-       src = self.encoder(
-           src,
-           pos_emb,
-           feature_mask=feature_mask,
-           attn_offset=attn_offset,
-           memory=memory,
-           memory_key_padding_mask=memory_key_padding_mask,
-       )
+       outputs = [ src ]
+       for encoder in self.encoders:
+           src = encoder(
+               src,
+               pos_emb,
+               attn_offset=attn_offset,
+               memory=memory,
+               memory_key_padding_mask=memory_key_padding_mask,
+           )
+           outputs.append(src)
+
+       def get_full_dim_output():
+           num_encoders = len(outputs)
+           output_dim = max(o.shape[-1] for o in outputs)
+           output_pieces = [ outputs[-1] ]
+           cur_dim = outputs[-1].shape[-1]
+           for i in range(num_encoders - 2, -1, -1):
+               d = outputs[i].shape[-1]
+               if d > cur_dim:
+                   this_output = outputs[i]
+                   output_pieces.append(this_output[..., cur_dim:d])
+                   cur_dim = d
+           assert cur_dim == output_dim
+           return torch.cat(output_pieces, dim=-1)
+
+       src = get_full_dim_output()
+       src_orig = convert_num_channels(src_orig, src.shape[-1])

        src = self.downsampler.upsample(src_orig, src, indexes)

        return self.out_combiner(src_orig, src)
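`DownsampledSubformerEncoder.forward` now runs its encoders sequentially and then assembles the widest representation with `get_full_dim_output()`: the last output keeps all of its channels, and any higher channels are filled in from the most recent earlier output that has them. A self-contained sketch of that channel-piecing step with made-up shapes:

```python
# Sketch of the get_full_dim_output() logic with per-encoder outputs of
# differing channel counts (shapes are made up for illustration).
import torch

outputs = [torch.zeros(10, 2, d) for d in (64, 96, 64)]  # (seq, batch, channels)

pieces = [outputs[-1]]
cur_dim = outputs[-1].shape[-1]
for o in reversed(outputs[:-1]):
    d = o.shape[-1]
    if d > cur_dim:                      # take only channels we don't have yet
        pieces.append(o[..., cur_dim:d])
        cur_dim = d

full = torch.cat(pieces, dim=-1)
print(full.shape)                        # torch.Size([10, 2, 96])
```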
@@ -1793,7 +1741,8 @@ def _test_zipformer_main(causal: bool = False):
    memory_dim = 100
    c = Subformer(
-       encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
+       encoder_dim=(64, 96, 64),
+       num_heads=(4, 4, 8),
        causal=causal,
        memory_dim=memory_dim,
    )

View File

@@ -125,15 +125,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Number of subformer encoder layers per stack, comma separated.",
    )

-   parser.add_argument(
-       "--downsampling-factor",
-       type=str,
-       default="1,2,4,8,4,2,1",
-       help="Downsampling factor for each stack of encoder layers.",
-   )
-
    parser.add_argument(
        "--feedforward-dim",
        type=str,
@@ -176,13 +167,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Positional-encoding dimension in encoder stacks: a single int or comma-separated list."
    )

-   parser.add_argument(
-       "--encoder-unmasked-dim",
-       type=str,
-       default="192,192,256,256,256,192,192",
-       help="Unmasked dimensions in the encoders, relates to augmentation during training. "
-       "A single int or comma-separated list. Must be <= each corresponding encoder_dim."
-   )

 def get_parser():
@@ -478,11 +463,8 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
 def get_encoder_model(params: AttributeDict) -> nn.Module:
    #chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
    encoder = Subformer(
-       #output_downsampling_factor=chunk_size,
-       downsampling_factor=_to_int_tuple(params.downsampling_factor),
        num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
        encoder_dim=_to_int_tuple(params.encoder_dim),
-       encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
        query_head_dim=_to_int_tuple(params.query_head_dim),
        pos_dim=int(params.pos_dim),
        value_head_dim=_to_int_tuple(params.value_head_dim),
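With `--downsampling-factor` and `--encoder-unmasked-dim` removed, the per-stack geometry is driven entirely by the comma-separated per-stack options such as `--encoder-dim`, which `get_encoder_model` converts to tuples via `_to_int_tuple`. A sketch of that parsing step (the real helper is defined elsewhere in this file; this is just the obvious implementation, and `Subformer` now additionally requires an odd number of stacks):

```python
# Sketch: comma-separated CLI strings become per-stack tuples.
def to_int_tuple(s: str):
    return tuple(int(v) for v in s.split(","))

print(to_int_tuple("384,512,384"))  # (384, 512, 384)
print(to_int_tuple("4,4,8"))        # (4, 4, 8)
```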