diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
index 8ebe627a5..b98ec1c0e 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py
@@ -542,7 +542,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
 
 def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
-        encoder_dim=int(params.encoder_dim.split(',')[-1]),
+        encoder_dim=max(int(d) for d in params.encoder_dim.split(',')),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
@@ -559,7 +559,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         encoder=encoder,
         decoder=decoder,
         joiner=joiner,
-        encoder_dim=int(params.encoder_dim.split(',')[-1]),
+        encoder_dim=max(int(d) for d in params.encoder_dim.split(',')),
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
index fc70150c9..29b9d84f1 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py
@@ -221,10 +221,10 @@ class Zipformer(EncoderInterface):
 
         # initializes self.skip_layers and self.skip_modules
         self._init_skip_modules()
 
-        self.downsample_output = SimpleDownsample(encoder_dim[-1],
-                                                  encoder_dim[-1],
-                                                  downsample=output_downsampling_factor,
-                                                  dropout=dropout)
+        self.downsample_output = SimpleDownsample(max(encoder_dim),
+                                                  max(encoder_dim),
+                                                  downsample=output_downsampling_factor,
+                                                  dropout=dropout)
 
     def _init_skip_modules(self):
@@ -324,7 +324,7 @@ class Zipformer(EncoderInterface):
           `x` before padding.
         Returns:
           Return a tuple containing 2 tensors:
-            - embeddings: its shape is (batch_size, output_seq_len, encoder_dim[-1])
+            - embeddings: its shape is (batch_size, output_seq_len, max(encoder_dim))
             - lengths, a tensor of shape (batch_size,) containing the number of
               frames in `embeddings` before padding.
         """
@@ -349,8 +349,9 @@ class Zipformer(EncoderInterface):
             ds = self.downsampling_factor[i]
             if self.skip_layers[i] is not None:
                 # this how we implement U-net-like skipping of some series of
-                # stacks.  The layer_skip_dropout_prob is to discourage it, especially
-                # early in training, from completely ignoring the middle layers.
+                # stacks.  The layer_skip_dropout_prob is to discourage it from
+                # completely ignoring the middle layers, especially early in
+                # training.
                 batch_size = x.shape[0]
                 skip_x = self.skip_modules[i](outputs[self.skip_layers[i]], x)
                 if self.training:
@@ -363,8 +364,28 @@ class Zipformer(EncoderInterface):
                        feature_mask=feature_masks[i],
                        src_key_padding_mask=None if mask is None else mask[...,::ds])
             outputs.append(x)
-        # logging.info(f"Memory allocated after stack {i}: {torch.cuda.memory_allocated() // 1000000}M")
 
+        def get_full_dim_output():
+            num_encoders = len(self.encoder_dim)
+            assert len(outputs) == num_encoders
+            output_dim = max(self.encoder_dim)
+            output_shape = outputs[-1].shape[:-1] + (output_dim,)
+            output_pieces = [ outputs[-1] ]
+            cur_dim = self.encoder_dim[-1]
+            for i in range(num_encoders - 2, -1, -1):
+                d = self.encoder_dim[i]
+                if d > cur_dim:
+                    this_output = outputs[i]
+                    output_pieces.append(this_output[..., cur_dim:d])
+                    cur_dim = d
+            assert cur_dim == output_dim
+            return torch.cat(output_pieces, dim=-1)
+
+        # If the last output has the largest dimension, x will be unchanged;
+        # it will be the same as outputs[-1].  Otherwise it will be concatenated
+        # from different pieces of 'outputs', taking each dimension from the
+        # most recent output that has it present.
+        x = get_full_dim_output()
         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
         assert self.output_downsampling_factor == 2
@@ -375,6 +396,8 @@ class Zipformer(EncoderInterface):
         return x, lengths
 
 
+
+
 def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat:
     return ScheduledFloat((0.0, x), (20000.0, ratio * x),
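A quick standalone sketch (not part of the patch) of how the train.py side derives the joiner's encoder_dim from the comma-separated params.encoder_dim string; the example value below is hypothetical:

# Toy illustration; this encoder_dim string is a made-up value.
encoder_dim = "192,256,384,512,384,256"

# Parse to int before taking the max: max() over the raw strings
# compares lexicographically, e.g. max('384', '1024') == '384'.
print(max(int(d) for d in encoder_dim.split(',')))  # 512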
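And a standalone sketch (also not part of the patch) of what the new get_full_dim_output() helper computes: it starts from the last stack's output and borrows only the still-missing channels from the most recent earlier stack that has them. The per-stack dims, seq_len and batch_size here are made-up toy values:

import torch

encoder_dim = [192, 384, 256]   # per-stack output dims (toy values)
seq_len, batch_size = 10, 4
outputs = [torch.randn(seq_len, batch_size, d) for d in encoder_dim]

output_pieces = [outputs[-1]]   # start from the last stack (dim 256)
cur_dim = encoder_dim[-1]
for i in range(len(encoder_dim) - 2, -1, -1):
    d = encoder_dim[i]
    if d > cur_dim:
        # take only channels [cur_dim:d], i.e. the ones still missing,
        # from the most recent earlier stack that has them
        output_pieces.append(outputs[i][..., cur_dim:d])
        cur_dim = d
assert cur_dim == max(encoder_dim)
full = torch.cat(output_pieces, dim=-1)
print(full.shape)               # torch.Size([10, 4, 384])

With these toy dims, outputs[-1] supplies channels 0..255 and the middle (widest) stack supplies channels 256..383, so the result always has max(encoder_dim) channels.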