mirror of https://github.com/k2-fsa/icefall.git
Make output dim of Zipformer be max dim
parent fb7a967276
commit 167b58baa0
@@ -542,7 +542,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=int(params.encoder_dim.split(',')[-1]),
+        encoder_dim=int(max(params.encoder_dim.split(','))),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -559,7 +559,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=int(params.encoder_dim.split(',')[-1]),
+        encoder_dim=int(max(params.encoder_dim.split(','))),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
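Both hunks change the same expression in the same way. As a minimal standalone sketch (not repository code; the example --encoder-dim value is hypothetical), the joiner's encoder_dim is now the largest of the comma-separated per-stack encoder dimensions rather than the last one:

    # Illustrative only; "256,384,512,384,256" is a made-up --encoder-dim value.
    encoder_dim = "256,384,512,384,256"

    old_value = int(encoder_dim.split(',')[-1])    # last stack's dim -> 256
    new_value = int(max(encoder_dim.split(',')))   # largest dim      -> 512

    # Note: max() over the split substrings compares them as strings; this agrees
    # with the numeric maximum as long as all dims have the same number of digits
    # (max(encoder_dim.split(','), key=int) would compare numerically).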
@@ -221,10 +221,10 @@ class Zipformer(EncoderInterface):
         # initializes self.skip_layers and self.skip_modules
         self._init_skip_modules()
 
-        self.downsample_output = SimpleDownsample(encoder_dim[-1],
-                                                  encoder_dim[-1],
+        self.downsample_output = SimpleDownsample(max(encoder_dim),
+                                                  max(encoder_dim),
                                                   downsample=output_downsampling_factor,
                                                   dropout=dropout)
 
 
     def _init_skip_modules(self):
@@ -324,7 +324,7 @@ class Zipformer(EncoderInterface):
             `x` before padding.
         Returns:
           Return a tuple containing 2 tensors:
-            - embeddings: its shape is (batch_size, output_seq_len, encoder_dim[-1])
+            - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
            - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding.
         """
@@ -349,8 +349,9 @@ class Zipformer(EncoderInterface):
             ds = self.downsampling_factor[i]
             if self.skip_layers[i] is not None:
                 # this how we implement U-net-like skipping of some series of
-                # stacks. The layer_skip_dropout_prob is to discourage it, especially
-                # early in training, from completely ignoring the middle layers.
+                # stacks. The layer_skip_dropout_prob is to discourage it from
+                # completely ignoring the middle layers, especially early in
+                # training,
                 batch_size = x.shape[0]
                 skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
                 if self.training:
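The reworded comment describes the role of layer_skip_dropout_prob: during training the skip-combined value is sometimes replaced by the plain stack input, so the model cannot lean on the U-net-like skip and ignore the middle stacks. The following is only an illustrative sketch of that idea; the function name, the (batch, time, channels) layout, and the per-example masking are assumptions, not the repository's implementation.

    import torch

    def apply_skip_with_dropout(skip_x: torch.Tensor,
                                x: torch.Tensor,
                                layer_skip_dropout_prob: float,
                                training: bool) -> torch.Tensor:
        # skip_x: result of combining an earlier stack's output with x.
        # During training, with probability layer_skip_dropout_prob per batch
        # element, keep the plain x instead of the skip-combined value.
        if not training or layer_skip_dropout_prob == 0.0:
            return skip_x
        batch_size = x.shape[0]  # assumes (batch, time, channels) layout
        keep_skip = (torch.rand(batch_size, 1, 1, device=x.device)
                     >= layer_skip_dropout_prob)
        return torch.where(keep_skip, skip_x, x)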
@@ -363,8 +364,28 @@ class Zipformer(EncoderInterface):
                        feature_mask=feature_masks[i],
                        src_key_padding_mask=None if mask is None else mask[...,::ds])
             outputs.append(x)
-            # logging.info(f"Memory allocated after stack {i}: {torch.cuda.memory_allocated() // 1000000}M")
 
+        def get_full_dim_output():
+            num_encoders = len(self.encoder_dim)
+            assert len(outputs) == num_encoders
+            output_dim = max(self.encoder_dim)
+            output_shape = outputs[-1].shape[:-1] + (output_dim,)
+            output_pieces = [ outputs[-1] ]
+            cur_dim = self.encoder_dim[-1]
+            for i in range(num_encoders - 2, -1, -1):
+                d = self.encoder_dim[i]
+                if d > cur_dim:
+                    this_output = outputs[i]
+                    output_pieces.append(this_output[..., cur_dim:d])
+                    cur_dim = d
+            assert cur_dim == output_dim
+            return torch.cat(output_pieces, dim=-1)
+
+        # if the last output has the largest dimension, x will be unchanged,
+        # it will be the same as outputs[-1]. Otherwise it will be concatenated
+        # from different pieces of 'outputs', taking each dimension from the
+        # most recent output that has it present.
+        x = get_full_dim_output()
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
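The added comment in the hunk above summarizes what get_full_dim_output() does: if the last stack already has the largest dimension, the result is just outputs[-1]; otherwise each channel index is taken from the most recent stack output that covers it. A small self-contained sketch with dummy tensors (the per-stack dims and shapes are made up) shows the same concatenation logic outside the class:

    import torch

    # Hypothetical per-stack output dims and dummy (batch, time, dim) outputs.
    encoder_dim = (256, 384, 512, 384, 256)
    outputs = [torch.randn(2, 50, d) for d in encoder_dim]

    def get_full_dim_output(outputs, encoder_dim):
        # Start from the last stack's output and walk backwards, appending the
        # channel range [cur_dim:d] from the most recent stack that has it.
        output_dim = max(encoder_dim)
        output_pieces = [outputs[-1]]
        cur_dim = encoder_dim[-1]
        for i in range(len(encoder_dim) - 2, -1, -1):
            d = encoder_dim[i]
            if d > cur_dim:
                output_pieces.append(outputs[i][..., cur_dim:d])
                cur_dim = d
        assert cur_dim == output_dim
        return torch.cat(output_pieces, dim=-1)

    full = get_full_dim_output(outputs, encoder_dim)
    print(full.shape)  # torch.Size([2, 50, 512]): dims 0:256 come from outputs[-1],
                       # dims 256:384 from outputs[3], dims 384:512 from outputs[2].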
@@ -375,6 +396,8 @@ class Zipformer(EncoderInterface):
         return x, lengths
 
 
+
+
 def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
     return ScheduledFloat((0.0, x),
                           (20000.0, ratio * x),