small fixes

This commit is contained in:
Fangjun Kuang 2023-07-24 18:09:55 +08:00
parent d5dcca674c
commit c9055e03e3
2 changed files with 3 additions and 3 deletions

View File

@ -320,7 +320,7 @@ class AsrModel(nn.Module):
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
# Compute encoder outputs
encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

View File

@ -219,7 +219,7 @@ class Zipformer2(EncoderInterface):
(num_frames0, batch_size, _encoder_dims0) = x.shape
assert self.encoder_dim[0] == _encoder_dims0
assert self.encoder_dim[0] == _encoder_dims0, (self.encoder_dim[0], _encoder_dims0)
feature_mask_dropout_prob = 0.125
@ -334,7 +334,7 @@ class Zipformer2(EncoderInterface):
x = self._get_full_dim_output(outputs)
x = self.downsample_output(x)
# class Downsample has this rounding behavior..
assert self.output_downsampling_factor == 2
assert self.output_downsampling_factor == 2, self.output_downsampling_factor
if torch.jit.is_scripting() or torch.jit.is_tracing():
lengths = (x_lens + 1) // 2
else: