Make output dim of Zipformer be max dim

Daniel Povey 2023-01-14 14:29:29 +08:00
parent fb7a967276
commit 167b58baa0
2 changed files with 33 additions and 10 deletions


@@ -542,7 +542,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=int(params.encoder_dim.split(',')[-1]),
+        encoder_dim=int(max(params.encoder_dim.split(','))),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
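
The change in this hunk (and the matching one below) makes the joiner's encoder_dim the largest of the per-stack dims rather than the last one. Below is a minimal sketch of that parsing, using an assumed example value for params.encoder_dim that is not taken from this diff. Note that max() over the string pieces compares lexicographically, so taking the max after converting to int is the more defensive spelling, though the two agree for typical dim strings like the one here:

encoder_dim = "192,256,384,512,384,256"            # assumed example config value

# As in the commit: max over the string pieces, then convert to int.
dim_as_in_commit = int(max(encoder_dim.split(',')))

# Numerically safer variant: convert each piece to int before taking the max.
dim_numeric = max(int(d) for d in encoder_dim.split(','))

assert dim_as_in_commit == dim_numeric == 512
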
@@ -559,7 +559,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=int(params.encoder_dim.split(',')[-1]),
+        encoder_dim=int(max(params.encoder_dim.split(','))),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,


@@ -221,8 +221,8 @@ class Zipformer(EncoderInterface):
         # initializes self.skip_layers and self.skip_modules
         self._init_skip_modules()
-        self.downsample_output = SimpleDownsample(encoder_dim[-1],
-                                                  encoder_dim[-1],
+        self.downsample_output = SimpleDownsample(max(encoder_dim),
+                                                  max(encoder_dim),
                                                   downsample=output_downsampling_factor,
                                                   dropout=dropout)
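
SimpleDownsample's definition is not part of this diff, so the stand-in below is only an assumed sketch of its interface: after this change both its input and output dims are max(encoder_dim), and it reduces the frame rate by output_downsampling_factor (2 in this recipe). Here frames are simply averaged in groups; the real module may weight them differently.

import torch
import torch.nn as nn

class AveragingDownsample(nn.Module):
    # Hypothetical stand-in for SimpleDownsample: same constructor arguments,
    # averages each group of `downsample` consecutive frames.
    def __init__(self, in_dim: int, out_dim: int, downsample: int, dropout: float = 0.0):
        super().__init__()
        assert in_dim == out_dim  # as constructed above: both are max(encoder_dim)
        self.downsample = downsample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (seq_len, batch_size, num_channels)
        seq_len, batch, channels = x.shape
        pad = (-seq_len) % self.downsample        # pad so seq_len divides evenly
        if pad > 0:
            x = torch.cat([x, x[-1:].expand(pad, batch, channels)], dim=0)
        x = x.reshape(-1, self.downsample, batch, channels)
        return x.mean(dim=1)

encoder_dim = (192, 256, 384, 512, 384, 256)      # assumed per-stack dims
m = AveragingDownsample(max(encoder_dim), max(encoder_dim), downsample=2)
print(m(torch.randn(101, 8, max(encoder_dim))).shape)   # torch.Size([51, 8, 512])
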
@@ -324,7 +324,7 @@ class Zipformer(EncoderInterface):
             `x` before padding.
         Returns:
           Return a tuple containing 2 tensors:
-            - embeddings: its shape is (batch_size, output_seq_len, encoder_dim[-1])
+            - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
             - lengths, a tensor of shape (batch_size,) containing the number
               of frames in `embeddings` before padding.
         """
@@ -349,8 +349,9 @@ class Zipformer(EncoderInterface):
             ds = self.downsampling_factor[i]
             if self.skip_layers[i] is not None:
                 # this is how we implement U-net-like skipping of some series of
-                # stacks. The layer_skip_dropout_prob is to discourage it, especially
-                # early in training, from completely ignoring the middle layers.
+                # stacks. The layer_skip_dropout_prob is to discourage it from
+                # completely ignoring the middle layers, especially early in
+                # training.
                 batch_size = x.shape[0]
                 skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
                 if self.training:
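
The hunk above is cut off right after `if self.training:`, so the body of that branch isn't shown. Purely as a sketch of the technique the comment describes, and not the actual icefall code: during training the skip connection is randomly bypassed with probability layer_skip_dropout_prob, so the model cannot learn to route everything around the middle stacks.

import torch

def apply_skip(x: torch.Tensor, skip_x: torch.Tensor,
               layer_skip_dropout_prob: float, training: bool) -> torch.Tensor:
    # Hypothetical sketch: skip_x is the skip-combined output computed above;
    # occasionally fall back to the plain stack output x during training.
    if training and torch.rand(()).item() < layer_skip_dropout_prob:
        return x          # drop the U-net-style skip for this forward pass
    return skip_x         # otherwise (and always at inference) use it

# e.g. x = apply_skip(x, skip_x, layer_skip_dropout_prob=0.1, training=True)
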
@@ -363,8 +364,28 @@ class Zipformer(EncoderInterface):
                        feature_mask=feature_masks[i],
                        src_key_padding_mask=None if mask is None else mask[...,::ds])
             outputs.append(x)
-            # logging.info(f"Memory allocated after stack {i}: {torch.cuda.memory_allocated() // 1000000}M")
+        def get_full_dim_output():
+            num_encoders = len(self.encoder_dim)
+            assert len(outputs) == num_encoders
+            output_dim = max(self.encoder_dim)
+            output_shape = outputs[-1].shape[:-1] + (output_dim,)
+            output_pieces = [ outputs[-1] ]
+            cur_dim = self.encoder_dim[-1]
+            for i in range(num_encoders - 2, -1, -1):
+                d = self.encoder_dim[i]
+                if d > cur_dim:
+                    this_output = outputs[i]
+                    output_pieces.append(this_output[..., cur_dim:d])
+                    cur_dim = d
+            assert cur_dim == output_dim
+            return torch.cat(output_pieces, dim=-1)
+
+        # if the last output has the largest dimension, x will be unchanged,
+        # it will be the same as outputs[-1]. Otherwise it will be concatenated
+        # from different pieces of 'outputs', taking each dimension from the
+        # most recent output that has it present.
+        x = get_full_dim_output()
+        x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
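
The new get_full_dim_output builds the final max(encoder_dim)-channel tensor from the per-stack outputs: it starts from the last stack's output and, walking backwards through the stacks, appends only the channel range [cur_dim:d] from each earlier stack whose dim d exceeds what has been covered so far. A small self-contained re-run of that logic on dummy tensors (the dims and shapes below are made up for illustration):

import torch

encoder_dim = (192, 256, 384, 512, 384, 256)      # assumed per-stack dims
seq_len, batch_size = 50, 8
# dummy per-stack outputs, shaped (seq_len, batch_size, encoder_dim[i])
outputs = [torch.randn(seq_len, batch_size, d) for d in encoder_dim]

def get_full_dim_output():
    # same concatenation logic as the function added in the hunk above
    output_dim = max(encoder_dim)
    output_pieces = [outputs[-1]]                  # start from the last stack
    cur_dim = encoder_dim[-1]
    for i in range(len(encoder_dim) - 2, -1, -1):
        d = encoder_dim[i]
        if d > cur_dim:
            # take only the channels not yet present in the result
            output_pieces.append(outputs[i][..., cur_dim:d])
            cur_dim = d
    assert cur_dim == output_dim
    return torch.cat(output_pieces, dim=-1)

x = get_full_dim_output()
print(x.shape)   # torch.Size([50, 8, 512]): channels 0..255 come from stack 5,
                 # 256..383 from stack 4, and 384..511 from stack 3
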
@@ -375,6 +396,8 @@
         return x, lengths

 def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
     return ScheduledFloat((0.0, x),
                           (20000.0, ratio * x),
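
For context on the last hunk: _whitening_schedule returns a ScheduledFloat that starts at x and reaches ratio * x by step 20000. ScheduledFloat lives in icefall's scaling.py and is not shown here; the helper below is only an assumed sketch of the piecewise-linear interpolation it is used for, with 4.0 as an illustrative value of x:

def schedule_value(step, *points):
    # Assumed sketch (not icefall's ScheduledFloat): piecewise-linear
    # interpolation through (step, value) points, clamped at both ends.
    points = sorted(points)
    if step <= points[0][0]:
        return points[0][1]
    if step >= points[-1][0]:
        return points[-1][1]
    for (s0, v0), (s1, v1) in zip(points, points[1:]):
        if s0 <= step <= s1:
            return v0 + (step - s0) / (s1 - s0) * (v1 - v0)

# _whitening_schedule(4.0): a value of 4.0 at step 0, ramping to 8.0 by step 20000
print(schedule_value(0,     (0.0, 4.0), (20000.0, 8.0)))   # 4.0
print(schedule_value(10000, (0.0, 4.0), (20000.0, 8.0)))   # 6.0
print(schedule_value(50000, (0.0, 4.0), (20000.0, 8.0)))   # 8.0
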