Make output dim of Zipformer be max dim
This commit is contained in:
parent fb7a967276
commit 167b58baa0
@@ -542,7 +542,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=int(params.encoder_dim.split(',')[-1]),
+        encoder_dim=int(max(params.encoder_dim.split(','))),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -559,7 +559,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=int(params.encoder_dim.split(',')[-1]),
+        encoder_dim=int(max(params.encoder_dim.split(','))),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
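Both hunks above change how the joiner's and transducer's encoder_dim is derived from the comma-separated --encoder-dim option: instead of the last stack's dimension, the maximum over all stacks is used. A minimal sketch of what that expression computes; the example --encoder-dim value below is hypothetical, not a project default:

    # Sketch only, not icefall's code; "256,384,512,384,256" is a made-up value.
    encoder_dim = "256,384,512,384,256"

    old_dim = int(encoder_dim.split(',')[-1])   # dimension of the last stack -> 256
    new_dim = int(max(encoder_dim.split(',')))  # maximum dimension -> 512
    # max() compares the split strings; for dims with the same number of
    # digits this agrees with the numeric maximum.
    print(old_dim, new_dim)  # 256 512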
@@ -221,8 +221,8 @@ class Zipformer(EncoderInterface):
         # initializes self.skip_layers and self.skip_modules
         self._init_skip_modules()
 
-        self.downsample_output = SimpleDownsample(encoder_dim[-1],
-                                                  encoder_dim[-1],
+        self.downsample_output = SimpleDownsample(max(encoder_dim),
+                                                  max(encoder_dim),
                                                   downsample=output_downsampling_factor,
                                                   dropout=dropout)
@@ -324,7 +324,7 @@ class Zipformer(EncoderInterface):
             `x` before padding.
         Returns:
           Return a tuple containing 2 tensors:
-            - embeddings: its shape is (batch_size, output_seq_len, encoder_dim[-1])
+            - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
             - lengths, a tensor of shape (batch_size,) containing the number
               of frames in `embeddings` before padding.
         """
@@ -349,8 +349,9 @@ class Zipformer(EncoderInterface):
             ds = self.downsampling_factor[i]
             if self.skip_layers[i] is not None:
                 # this how we implement U-net-like skipping of some series of
-                # stacks. The layer_skip_dropout_prob is to discourage it, especially
-                # early in training, from completely ignoring the middle layers.
+                # stacks. The layer_skip_dropout_prob is to discourage it from
+                # completely ignoring the middle layers, especially early in
+                # training,
                 batch_size = x.shape[0]
                 skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
                 if self.training:
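The rewritten comment above describes U-net-like skipping between encoder stacks, with layer_skip_dropout_prob discouraging the model from bypassing the middle layers, especially early in training. A hedged sketch of that idea follows; the function name, default probability and (batch_size, seq_len, dim) layout are assumptions for illustration, not a copy of icefall's implementation:

    import torch

    def apply_skip(x: torch.Tensor,
                   skip_x: torch.Tensor,
                   layer_skip_dropout_prob: float = 0.05,
                   training: bool = True) -> torch.Tensor:
        # During training, for a random subset of sequences keep the
        # non-skipped activations x, so the middle stacks still get used
        # and receive gradient.
        if not training:
            return skip_x
        batch_size = x.shape[0]
        keep_skip = (torch.rand(batch_size, 1, 1, device=x.device)
                     > layer_skip_dropout_prob)
        return torch.where(keep_skip, skip_x, x)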
@@ -363,8 +364,28 @@ class Zipformer(EncoderInterface):
                           feature_mask=feature_masks[i],
                           src_key_padding_mask=None if mask is None else mask[...,::ds])
             outputs.append(x)
             # logging.info(f"Memory allocated after stack {i}: {torch.cuda.memory_allocated() // 1000000}M")
 
+        def get_full_dim_output():
+            num_encoders = len(self.encoder_dim)
+            assert len(outputs) == num_encoders
+            output_dim = max(self.encoder_dim)
+            output_shape = outputs[-1].shape[:-1] + (output_dim,)
+            output_pieces = [ outputs[-1] ]
+            cur_dim = self.encoder_dim[-1]
+            for i in range(num_encoders - 2, -1, -1):
+                d = self.encoder_dim[i]
+                if d > cur_dim:
+                    this_output = outputs[i]
+                    output_pieces.append(this_output[..., cur_dim:d])
+                    cur_dim = d
+            assert cur_dim == output_dim
+            return torch.cat(output_pieces, dim=-1)
+
+        # if the last output has the largest dimension, x will be unchanged,
+        # it will be the same as outputs[-1]. Otherwise it will be concatenated
+        # from different pieces of 'outputs', taking each dimension from the
+        # most recent output that has it present.
+        x = get_full_dim_output()
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
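The new get_full_dim_output() above builds an output whose feature dimension is max(encoder_dim): it starts from the last stack's output and, walking backwards through the stacks, copies any channels it is still missing from the most recent wider stack. A self-contained sketch with hypothetical per-stack dimensions:

    from typing import List
    import torch

    def full_dim_output(outputs: List[torch.Tensor],
                        encoder_dim: List[int]) -> torch.Tensor:
        # Sketch of the concatenation done in the commit; assumes all
        # outputs share the same (seq_len, batch_size) leading shape.
        output_dim = max(encoder_dim)
        pieces = [outputs[-1]]
        cur_dim = encoder_dim[-1]
        for i in range(len(encoder_dim) - 2, -1, -1):
            d = encoder_dim[i]
            if d > cur_dim:
                # take only the channels [cur_dim:d] that are still missing
                pieces.append(outputs[i][..., cur_dim:d])
                cur_dim = d
        assert cur_dim == output_dim
        return torch.cat(pieces, dim=-1)

    encoder_dim = [256, 384, 512, 384, 256]                  # hypothetical dims
    outputs = [torch.randn(10, 2, d) for d in encoder_dim]   # (seq, batch, dim)
    print(full_dim_output(outputs, encoder_dim).shape)       # torch.Size([10, 2, 512])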
@@ -375,6 +396,8 @@ class Zipformer(EncoderInterface):
         return x, lengths
 
 
+
+
 def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
     return ScheduledFloat((0.0, x),
                           (20000.0, ratio * x),
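The last hunk only shows context around _whitening_schedule; the ScheduledFloat there expresses a piecewise-linear schedule from x at batch 0 to ratio * x at batch 20000. A small illustrative helper follows; it is not icefall's ScheduledFloat, and clamping outside the listed points is an assumption:

    def scheduled_value(step: float, x: float, ratio: float = 2.0) -> float:
        # Linear interpolation between (0, x) and (20000, ratio * x),
        # clamped outside that range.
        x0, y0 = 0.0, x
        x1, y1 = 20000.0, ratio * x
        if step <= x0:
            return y0
        if step >= x1:
            return y1
        t = (step - x0) / (x1 - x0)
        return y0 + t * (y1 - y0)

    print(scheduled_value(0, 4.0))      # 4.0
    print(scheduled_value(10000, 4.0))  # 6.0
    print(scheduled_value(20000, 4.0))  # 8.0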