mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Code cleanup

This commit is contained in:
parent 671e9ee5bd
commit 0a76215fd7
@@ -58,9 +58,6 @@ class Subformer(EncoderInterface):
           encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
              encoder stack.
           num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
           encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
              the encoder stacks for purposes of per-frame dropout (recommend 256 for
              now).
           query_head_dim (int or Tuple[int]): dimension of query and key per attention
              head; per stack, if a tuple.
           value_head_dim (int or Tuple[int]): dimension of value in each attention head
@@ -83,11 +80,8 @@ class Subformer(EncoderInterface):
    """
    def __init__(
        self,
        output_downsampling_factor: int = 2,
        downsampling_factor: Tuple[int] = (2, 4),
        encoder_dim: Union[int, Tuple[int]] = 384,
        encoder_dim: Union[int, Tuple[int]] = (384, 512, 384),
        num_encoder_layers: Union[int, Tuple[int]] = 4,
        encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
        query_head_dim: Union[int, Tuple[int]] = 24,
        value_head_dim: Union[int, Tuple[int]] = 12,
        num_heads: Union[int, Tuple[int]] = 8,
@@ -106,19 +100,16 @@ class Subformer(EncoderInterface):

        def _to_tuple(x):
            """Converts a single int or a 1-tuple of an int to a tuple with the same length
            as downsampling_factor"""
            as encoder_dim"""
            if isinstance(x, int):
                x = (x,)
            if len(x) == 1:
                x = x * len(downsampling_factor)
                x = x * len(encoder_dim)
            else:
                assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
                assert len(x) == len(encoder_dim) and isinstance(x[0], int)
            return x

        self.output_downsampling_factor = output_downsampling_factor  # int
        self.downsampling_factor = downsampling_factor  # tuple
        self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
        self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim)  # tuple
        self.encoder_dim = encoder_dim
        num_encoder_layers = _to_tuple(num_encoder_layers)
        query_head_dim = _to_tuple(query_head_dim)
        value_head_dim = _to_tuple(value_head_dim)
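The `_to_tuple` helper broadcasts a scalar hyperparameter to one value per encoder stack. A standalone sketch of the same behavior (hypothetical name `to_tuple`, with the reference length passed in explicitly rather than closed over, as in the real helper):

    from typing import Tuple, Union

    def to_tuple(x: Union[int, Tuple[int, ...]], ref_len: int) -> Tuple[int, ...]:
        # An int or 1-tuple is repeated ref_len times; a full tuple must
        # already have exactly ref_len entries.
        if isinstance(x, int):
            x = (x,)
        if len(x) == 1:
            x = x * ref_len
        else:
            assert len(x) == ref_len and isinstance(x[0], int)
        return x

    print(to_tuple(256, 3))        # (256, 256, 256)
    print(to_tuple((4, 5, 6), 3))  # (4, 5, 6)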
@@ -126,13 +117,17 @@ class Subformer(EncoderInterface):
        feedforward_dim = _to_tuple(feedforward_dim)
        self.causal = causal

        for u, d in zip(encoder_unmasked_dim, encoder_dim):
            assert u <= d

        # each one will be SubformerEncoder or DownsampledSubformerEncoder
        encoders = []

        num_encoders = len(downsampling_factor)
        num_encoders = len(encoder_dim)
        assert num_encoders % 2 == 1
        downsampling_factor = [ 1 ]
        while len(downsampling_factor) < num_encoders:
            downsampling_factor = [ 1 ] + [ d * 2 for d in downsampling_factor ] + [ 1 ]

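The while-loop above now derives the per-stack downsampling factors from the number of stacks instead of taking them as a constructor argument: starting from [1] and repeatedly wrapping the doubled schedule in 1s yields a symmetric schedule that peaks in the middle. A minimal sketch:

    def make_downsampling_schedule(num_encoders: int) -> list:
        # Requires an odd number of stacks; the middle stack is the most
        # heavily downsampled.
        assert num_encoders % 2 == 1
        factors = [1]
        while len(factors) < num_encoders:
            factors = [1] + [d * 2 for d in factors] + [1]
        return factors

    print(make_downsampling_schedule(5))  # [1, 2, 4, 2, 1]
    print(make_downsampling_schedule(7))  # [1, 2, 4, 8, 4, 2, 1]

For seven stacks this reproduces the old --downsampling-factor default of "1,2,4,8,4,2,1" that the training-script hunk below deletes.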
        for i in range(num_encoders):

            encoder_layer = SubformerEncoderLayer(
@@ -158,83 +153,36 @@ class Subformer(EncoderInterface):
                final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
            )

            if downsampling_factor[i] != 1:
                encoder = DownsampledSubformerEncoder(
                    encoder,
                    dim=encoder_dim[i],
                    downsample=downsampling_factor[i],
                    dropout=dropout,
                )

            encoders.append(encoder)

        mid = len(encoders) // 2
        encoder = DownsampledSubformerEncoder(
            [ encoders[mid] ],
            input_num_channels=encoder_dim[mid],
            downsample=2
        )

        encoder = encoders[mid]
        for i in range(mid-1, -1, -1):
            this_list = [ encoders[mid-i],
                          encoder,
                          encoders[mid+i] ]
            encoder = DownsampledSubformerEncoder(
                this_list,
                input_num_channels=encoder_dim[max(0, mid-2)],
                downsample=2,
            )

        self.encoder = encoder

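The loop above composes the stacks into a nested, U-Net-like structure: the middle stack sits innermost, and each surrounding pair of stacks wraps it in a further DownsampledSubformerEncoder. An illustrative sketch of that composition with placeholder names (this walks the pairs (mid-k, mid+k) outward from the middle; a reading of the intent, not the module code):

    def nest(names: list):
        mid = len(names) // 2
        inner = names[mid]
        for k in range(1, mid + 1):
            # Each step wraps the current core with the next pair out.
            inner = [names[mid - k], inner, names[mid + k]]
        return inner

    print(nest(["e0", "e1", "e2", "e3", "e4"]))
    # ['e0', ['e1', 'e2', 'e3'], 'e4']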
        self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim,
                                                        dropout_rate=0.15,
                                                        length_factor=1.0)

        self.encoders = nn.ModuleList(encoders)

        #self.downsample_output = SimpleDownsample(max(encoder_dim),
        #                                          downsample=output_downsampling_factor,
        #                                          dropout=dropout)

    def get_feature_masks(
            self,
            x: torch.Tensor) -> List[Union[float, Tensor]]:
        """
        In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
        randomized feature masks, one per encoder.
        On e.g. 15% of frames, these masks will zero out all encoder dims larger than
        some supplied number, e.g. >256, so in effect on those frames we are using
        a smaller encoder dim.

        We generate the random masks at this level because we want the 2 masks to 'agree'
        all the way up the encoder stack. This will mean that the 1st mask will have
        mask values repeated self.zipformer_subsampling_factor times.

        Args:
          x: the embeddings (needed for the shape and dtype and device), of shape
             (1, batch_size, encoder_dims0)
        """
        num_encoders = len(self.encoder_dim)
        if not self.training:
            return [ 1.0 ] * num_encoders

        (num_frames0, batch_size, _encoder_dims0) = x.shape

        assert self.encoder_dim[0] == _encoder_dims0

        feature_mask_dropout_prob = 0.125

        # mask1 shape: (1, batch_size, 1)
        mask1 = (torch.rand(1, batch_size, 1,
                            device=x.device) >
                 feature_mask_dropout_prob).to(x.dtype)

        # mask2 has additional sequences masked, about twice the number.
        mask2 = torch.logical_and(mask1,
                                  (torch.rand(1, batch_size, 1,
                                              device=x.device) >
                                   feature_mask_dropout_prob).to(x.dtype))

        # dim: (1, batch_size, 2)
        mask = torch.cat((mask1, mask2), dim=-1)

        feature_masks = []
        for i in range(num_encoders):
            channels = self.encoder_dim[i]
            feature_mask = torch.ones(1, batch_size, channels,
                                      dtype=x.dtype, device=x.device)
            u1 = self.encoder_unmasked_dim[i]
            u2 = u1 + (channels - u1) // 2

            feature_mask[:, :, u1:u2] *= mask[..., 0:1]
            feature_mask[:, :, u2:] *= mask[..., 1:2]

            feature_masks.append(feature_mask)

        return feature_masks

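A numeric sketch of the masking scheme get_feature_masks describes, for a single stack with encoder_dim=384 and encoder_unmasked_dim=256 (values from the defaults above):

    import torch

    batch_size, channels, u1 = 2, 384, 256
    p = 0.125  # feature_mask_dropout_prob

    # Per-sequence keep/drop decisions; mask2 drops about twice as often.
    mask1 = (torch.rand(1, batch_size, 1) > p).float()
    mask2 = torch.logical_and(mask1.bool(),
                              torch.rand(1, batch_size, 1) > p).float()

    feature_mask = torch.ones(1, batch_size, channels)
    u2 = u1 + (channels - u1) // 2      # here 256 + 64 = 320
    feature_mask[:, :, u1:u2] *= mask1  # dims [256, 320) follow mask1
    feature_mask[:, :, u2:] *= mask2    # dims [320, 384) follow mask2
    # dims [0, 256), the "unmasked dim", are never dropped.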
    def forward(
@@ -267,7 +215,6 @@ class Subformer(EncoderInterface):
           of frames in `embeddings` before padding.
        """
        outputs = []
        feature_masks = self.get_feature_masks(x)

        attn_offset = self._get_attn_offset(x, src_key_padding_mask)

@@ -281,44 +228,17 @@ class Subformer(EncoderInterface):

        pos_emb = self.encoder_pos(x)

        for i, module in enumerate(self.encoders):
            x = convert_num_channels(x, self.encoder_dim[i])
        x = self.encoder(x,
                         pos_emb,
                         attn_offset=attn_offset,
                         memory=memory,
                         memory_key_padding_mask=memory_key_padding_mask,
                         )

            x = module(x,
                       pos_emb,
                       feature_mask=feature_masks[i],
                       attn_offset=attn_offset,
                       memory=memory,
                       memory_key_padding_mask=memory_key_padding_mask,
                       )
            outputs.append(x)
        # d = self.output_downsampling_factor
        # lengths = (x_lens + d - 1) // d

        def get_full_dim_output():
            num_encoders = len(self.encoder_dim)
            assert len(outputs) == num_encoders
            output_dim = max(self.encoder_dim)
            output_pieces = [ outputs[-1] ]
            cur_dim = self.encoder_dim[-1]
            for i in range(num_encoders - 2, -1, -1):
                d = self.encoder_dim[i]
                if d > cur_dim:
                    this_output = outputs[i]
                    output_pieces.append(this_output[..., cur_dim:d])
                    cur_dim = d
            assert cur_dim == output_dim
            return torch.cat(output_pieces, dim=-1)

        # if the last output has the largest dimension, x will be unchanged,
        # it will be the same as outputs[-1]. Otherwise it will be concatenated
        # from different pieces of 'outputs', taking each dimension from the
        # most recent output that has it present.
        x = get_full_dim_output()
        #x = self.downsample_output(x)

        d = self.output_downsampling_factor
        lengths = (x_lens + d - 1) // d

        return x, lengths
        return x, x_lens

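A toy run of the get_full_dim_output() logic, with the per-stack dims (64, 96, 64) used by the test at the bottom of this diff: the last 64-dim output provides dims [0, 64), and the 96-dim middle output back-fills dims [64, 96):

    import torch

    encoder_dim = (64, 96, 64)
    T, B = 10, 2
    outputs = [torch.randn(T, B, d) for d in encoder_dim]

    output_pieces = [outputs[-1]]
    cur_dim = encoder_dim[-1]
    for i in range(len(encoder_dim) - 2, -1, -1):
        d = encoder_dim[i]
        if d > cur_dim:
            # Take only the channels not yet covered.
            output_pieces.append(outputs[i][..., cur_dim:d])
            cur_dim = d

    x = torch.cat(output_pieces, dim=-1)
    print(x.shape)  # torch.Size([10, 2, 96])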
    def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
        """
@@ -507,7 +427,6 @@ class SubformerEncoderLayer(nn.Module):
                min_abs=0.1, max_abs=4.0,
            )

    def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
        if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
            return None
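The rest of get_sequence_dropout_mask is not shown in this hunk. A plausible sketch of such a mask: one keep/drop decision per sequence in the batch, broadcast over time and channels (assuming src of shape (seq_len, batch_size, embed_dim); an assumption, not the function body):

    import torch

    def sequence_dropout_mask(x: torch.Tensor, dropout_rate: float) -> torch.Tensor:
        batch_size = x.shape[1]
        # 1.0 keeps a sequence's residual branch, 0.0 drops it.
        return (torch.rand(batch_size, 1, device=x.device) > dropout_rate).to(x.dtype)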
@@ -650,11 +569,14 @@ class SubformerEncoder(nn.Module):
        final_layerdrop_rate: float = 0.05,
    ) -> None:
        super().__init__()

        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
        )
        self.num_layers = num_layers

        self.bypass = BypassModule(self.embed_dim())

        assert 0 <= warmup_begin <= warmup_end

        delta = (1. / num_layers) * (warmup_end - warmup_begin)
@@ -666,11 +588,14 @@ class SubformerEncoder(nn.Module):
                                       default=0.0)
            cur_begin = cur_end

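The delta computed above splits the interval [warmup_begin, warmup_end] into one window per layer; cur_begin and cur_end (partially visible in the second hunk) walk those windows, so each layer's layerdrop warms up over its own slice of training. A sketch with illustrative names:

    def layer_warmup_windows(num_layers: int, warmup_begin: float, warmup_end: float):
        delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
        windows, cur_begin = [], warmup_begin
        for _ in range(num_layers):
            cur_end = cur_begin + delta
            windows.append((cur_begin, cur_end))
            cur_begin = cur_end
        return windows

    print(layer_warmup_windows(4, 1000.0, 2000.0))
    # [(1000.0, 1250.0), (1250.0, 1500.0), (1500.0, 1750.0), (1750.0, 2000.0)]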
    def embed_dim(self):
        return self.layers[0].embed_dim

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        feature_mask: Union[Tensor, float] = 1.0,
        feature_mask: Optional[Tensor] = None,
        attn_offset: Optional[Tensor] = None,
        memory: Optional[Tensor] = None,
        memory_key_padding_mask: Optional[Tensor] = None,
@@ -693,11 +618,13 @@ class SubformerEncoder(nn.Module):

        Returns: a Tensor with the same shape as src.
        """
        src = convert_num_channels(src, self.embed_dim())
        output = src

        rnd_seed = src.numel() + random.randint(0, 1000)

        output = output * feature_mask
        if feature_mask is not None:
            output = output * feature_mask

        for i, mod in enumerate(self.layers):
            output = mod(
@@ -708,9 +635,10 @@ class SubformerEncoder(nn.Module):
                memory_key_padding_mask=memory_key_padding_mask,
            )

        output = output * feature_mask
        if feature_mask is not None:
            output = output * feature_mask

        return output
        return self.bypass(src, output)


class BypassModule(nn.Module):
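The new return path combines each stack's input with its output via BypassModule, a learned per-channel interpolation. A minimal sketch of the idea (the real module also handles scale clamping and a straight_through_rate; those details are omitted here, so treat this as an approximation):

    import torch
    import torch.nn as nn

    class SimpleBypass(nn.Module):
        def __init__(self, embed_dim: int):
            super().__init__()
            # Per-channel mixing weight; 0 keeps the input, 1 keeps the output.
            self.scale = nn.Parameter(torch.full((embed_dim,), 0.5))

        def forward(self, src: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
            return src + self.scale * (out - src)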
@@ -1014,17 +942,17 @@ class DownsampledSubformerEncoder(nn.Module):
    with the original input, so that the output has the same shape as the input.
    """
    def __init__(self,
                 encoder: nn.Module,
                 dim: int,
                 downsample: int,
                 dropout: FloatLike):
                 encoders: List[nn.Module],
                 input_num_channels: int,
                 downsample: int):
        super(DownsampledSubformerEncoder, self).__init__()
        self.downsample_factor = downsample
        self.downsampler = LearnedDownsamplingModule(dim,
        self.downsampler = LearnedDownsamplingModule(input_num_channels,
                                                     downsample)
        self.encoder = encoder
        self.encoders = nn.ModuleList(encoders)

        self.out_combiner = BypassModule(dim, straight_through_rate=0.025)
        self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders),
                                         straight_through_rate=0.0)

    def forward(self,
                src: Tensor,
@@ -1059,15 +987,35 @@ class DownsampledSubformerEncoder(nn.Module):
        attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
                                                              indexes,
                                                              weights)
        outputs = [ src ]

        for encoder in self.encoders:
            src = encoder(
                src,
                pos_emb,
                attn_offset=attn_offset,
                memory=memory,
                memory_key_padding_mask=memory_key_padding_mask,
            )
            outputs.append(src)

        def get_full_dim_output():
            num_encoders = len(outputs)
            output_dim = max(o.shape[-1] for o in outputs)
            output_pieces = [ outputs[-1] ]
            cur_dim = outputs[-1].shape[-1]
            for i in range(num_encoders - 2, -1, -1):
                d = outputs[i].shape[-1]
                if d > cur_dim:
                    this_output = outputs[i]
                    output_pieces.append(this_output[..., cur_dim:d])
                    cur_dim = d
            assert cur_dim == output_dim
            return torch.cat(output_pieces, dim=-1)

        src = get_full_dim_output()
        src_orig = convert_num_channels(src_orig, src.shape[-1])

        src = self.encoder(
            src,
            pos_emb,
            feature_mask=feature_mask,
            attn_offset=attn_offset,
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        src = self.downsampler.upsample(src_orig, src, indexes)

        return self.out_combiner(src_orig, src)
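For orientation, the downsampler's contract as used in this forward(): downsample_attn_offset restricts attention to the kept frames, the inner encoders run on the shortened sequence, and upsample scatters the result back to full length before the bypass combination. A much-simplified stand-in (not the actual LearnedDownsamplingModule, which learns the per-frame weights) that keeps the highest-weight frames and restores them in order:

    import torch

    def downsample(src: torch.Tensor, weights: torch.Tensor, factor: int):
        # src: (T, B, C); weights: (B, T) scores for how much each frame matters.
        T, B, _ = src.shape
        k = (T + factor - 1) // factor
        indexes = weights.topk(k, dim=1).indices.sort(dim=1).values  # (B, k), time-ordered
        ds = torch.stack([src[indexes[b], b] for b in range(B)], dim=1)  # (k, B, C)
        return ds, indexes

    def upsample(src_orig: torch.Tensor, ds: torch.Tensor, indexes: torch.Tensor):
        out = src_orig.clone()
        for b in range(src_orig.shape[1]):
            out[indexes[b], b] = ds[:, b]  # scatter kept frames back in place
        return out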
@@ -1793,7 +1741,8 @@ def _test_zipformer_main(causal: bool = False):
    memory_dim = 100

    c = Subformer(
        encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
        encoder_dim=(64, 96, 64),
        num_heads=(4, 4, 8),
        causal=causal,
        memory_dim=memory_dim,
    )

@@ -125,15 +125,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Number of subformer encoder layers per stack, comma separated.",
    )

    parser.add_argument(
        "--downsampling-factor",
        type=str,
        default="1,2,4,8,4,2,1",
        help="Downsampling factor for each stack of encoder layers.",
    )

    parser.add_argument(
        "--feedforward-dim",
        type=str,
@@ -176,13 +167,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="Positional-encoding dimension in encoder stacks: a single int or comma-separated list."
    )

    parser.add_argument(
        "--encoder-unmasked-dim",
        type=str,
        default="192,192,256,256,256,192,192",
        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
        "A single int or comma-separated list. Must be <= each corresponding encoder_dim."
    )


def get_parser():
@@ -478,11 +463,8 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
def get_encoder_model(params: AttributeDict) -> nn.Module:
    #chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
    encoder = Subformer(
        #output_downsampling_factor=chunk_size,
        downsampling_factor=_to_int_tuple(params.downsampling_factor),
        num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
        encoder_dim=_to_int_tuple(params.encoder_dim),
        encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
        query_head_dim=_to_int_tuple(params.query_head_dim),
        pos_dim=int(params.pos_dim),
        value_head_dim=_to_int_tuple(params.value_head_dim),
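All of the comma-separated CLI values above are converted with _to_int_tuple; a sketch of that helper as used in these scripts:

    def _to_int_tuple(s: str):
        return tuple(map(int, s.split(",")))

    print(_to_int_tuple("192,192,256,256,256,192,192"))
    # (192, 192, 256, 256, 256, 192, 192)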