Bug fixes, config changes

Daniel Povey 2023-01-12 22:11:42 +08:00
parent d3b3592986
commit bac72718f0
2 changed files with 15 additions and 6 deletions


@@ -123,7 +123,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
-        default="2,2,4,6,4,2",
+        default="2,4,4,4,4,4",
         help="Number of zipformer encoder layers per stack, comma separated.",
     )
@@ -139,7 +139,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="384,768,1024,1536,1024,768",
+        default="384,512,1024,1536,1024,512",
         help="Feedforward dimension of the zipformer encoder layers, per stack, comma separated.",
     )
@@ -160,7 +160,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-dim",
         type=str,
-        default="192,256,320,384,320,256",
+        default="192,192,256,320,256,192",
         help="Embedding dimension in encoder stacks: a single int or comma-separated list."
     )
@@ -195,7 +195,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-unmasked-dim",
         type=str,
-        default="164,192,256,256,256,192",
+        default="192,192,224,224,224,192",
        help="Unmasked dimensions in the encoders, relates to augmentation during training. "
        "A single int or comma-separated list. Must be <= each corresponding encoder_dim."
     )
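All four options take a comma-separated string with one value per encoder stack, and the new unmasked-dim defaults (192,192,224,224,224,192) stay elementwise <= the new encoder dims (192,192,256,320,256,192), as the help text requires. A minimal sketch of the parsing these defaults imply; the helper name to_int_tuple is hypothetical and the actual training script may use a different helper:

import argparse

def to_int_tuple(s: str):
    # "2,4,4,4,4,4" -> (2, 4, 4, 4, 4, 4); a bare "256" yields (256,)
    return tuple(int(v) for v in s.split(","))

parser = argparse.ArgumentParser()
parser.add_argument("--encoder-dim", type=str, default="192,192,256,320,256,192")
args = parser.parse_args([])
encoder_dim = to_int_tuple(args.encoder_dim)
assert len(encoder_dim) == 6  # one entry per encoder stack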


@@ -215,6 +215,7 @@ class Zipformer(EncoderInterface):
             encoder.lr_scale = downsampling_factor[i] ** -0.33
             encoders.append(encoder)
         self.encoders = nn.ModuleList(encoders)

         # initializes self.skip_layers and self.skip_modules
@@ -327,8 +328,12 @@ class Zipformer(EncoderInterface):
           - lengths, a tensor of shape (batch_size,) containing the number
             of frames in `embeddings` before padding.
         """
+        logging.info(f"Memory allocated at entry: {torch.cuda.memory_allocated() // 1000000}M")
         x = self.encoder_embed(x)
+        logging.info(f"Memory allocated after encoder_embed: {torch.cuda.memory_allocated() // 1000000}M")

         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

         with warnings.catch_warnings():
@@ -358,6 +363,7 @@ class Zipformer(EncoderInterface):
                 feature_mask=feature_masks[i],
                 src_key_padding_mask=None if mask is None else mask[...,::ds])
             outputs.append(x)
+            logging.info(f"Memory allocated after stack {i}: {torch.cuda.memory_allocated() // 1000000}M")

         x = self.downsample_output(x)
         # class Downsample has this rounding behavior..
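The added logging reports CUDA memory in decimal megabytes: torch.cuda.memory_allocated() returns the number of bytes held by live tensors on the current device (not the whole cached pool). A standalone sketch of the same pattern, assuming a CUDA device is present:

import logging
import torch

def log_cuda_mem(tag: str) -> None:
    # memory_allocated() counts bytes owned by live tensors;
    # integer division by 1000000 gives decimal megabytes.
    if torch.cuda.is_available():
        logging.info(f"Memory allocated {tag}: {torch.cuda.memory_allocated() // 1000000}M")

log_cuda_mem("after encoder_embed")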
@@ -834,6 +840,7 @@ class SimpleDownsample(torch.nn.Module):
         else:
             self.extra_proj = None
         self.downsample = downsample
+        self.out_channels = out_channels

     def forward(self,
                 src: Tensor) -> Tensor:
@@ -867,6 +874,8 @@ class SimpleDownsample(torch.nn.Module):
         if self.extra_proj is not None:
             ans2 = self.extra_proj(src)
             ans = torch.cat((ans, ans2), dim=2)
+        ans = ans[..., :self.out_channels]

         return ans
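Caching out_channels in __init__ lets forward() trim the result to the requested width after the optional concatenation with extra_proj's output. A toy illustration of the slice itself, with made-up shapes in the (time, batch, channel) layout this module operates on:

import torch

# Slicing with [..., :out_channels] trims only the trailing channel
# axis, and is a no-op when the tensor is already narrow enough.
ans = torch.randn(50, 8, 384)  # (time, batch, channels)
out_channels = 256
assert ans[..., :out_channels].shape == (50, 8, 256)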
@@ -941,7 +950,7 @@ class SimpleCombiner(torch.nn.Module):
                                        dtype=src1.dtype)),
                               dim=-1)
         else:
-            src1 = src1[:src2_dim]
+            src1 = src1[...,:src2_dim]

         src1 = src1 * weight1
         src2 = src2 * (1.0 - weight1)
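This looks like one of the bug fixes named in the commit message: src1[:src2_dim] slices the first (time) axis, silently dropping frames whenever the sequence is longer than src2_dim, whereas src1[...,:src2_dim] trims the channel axis as intended. A small demonstration, assuming the (seq_len, batch, channels) layout used elsewhere in this file:

import torch

src1 = torch.randn(300, 8, 384)  # (time, batch, channels)
src2_dim = 256

wrong = src1[:src2_dim]          # (256, 8, 384): drops 44 time frames
right = src1[..., :src2_dim]     # (300, 8, 256): trims channels
assert wrong.shape == (256, 8, 384)
assert right.shape == (300, 8, 256)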
@@ -1917,7 +1926,7 @@ class Conv2dSubsampling(nn.Module):
         out_channels: int,
         layer1_channels: int = 8,
         layer2_channels: int = 32,
-        layer3_channels: int = 96,
+        layer3_channels: int = 64,
         dropout: FloatLike = 0.1,
     ) -> None:
         """