Try to reproduce the baseline, but with the current code and 2 encoder stacks

Daniel Povey 2022-09-28 20:56:40 +08:00
parent 1005ff35ba
commit df795912ed
2 changed files with 12 additions and 9 deletions


@@ -478,9 +478,12 @@ class AttentionDownsample(torch.nn.Module):
         self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
         # fill in the extra dimensions with a projection of the input
-        self.extra_proj = nn.Linear(in_channels * downsample,
-                                    out_channels - in_channels,
-                                    bias=False)
+        if out_channels > in_channels:
+            self.extra_proj = nn.Linear(in_channels * downsample,
+                                        out_channels - in_channels,
+                                        bias=False)
+        else:
+            self.extra_proj = None
         self.downsample = downsample

     def forward(self,
@@ -509,10 +512,10 @@ class AttentionDownsample(torch.nn.Module):
         # ans1 is the first `in_channels` channels of the output
         ans1 = (src * weights).sum(dim=1)
         src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
-        ans2 = self.extra_proj(src)
-        ans = torch.cat((ans1, ans2), dim=2)
+        if self.extra_proj is not None:
+            ans2 = self.extra_proj(src)
+            ans = torch.cat((ans1, ans2), dim=2)
+        else:
+            ans = ans1
         return ans
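
For context, here is a minimal, self-contained sketch of the guarded extra_proj logic these two hunks produce. The attention-score computation and the assumption that seq_len divides evenly by the downsampling factor are simplifications for illustration; only the if/else branches mirror the diff.

import torch
import torch.nn as nn

class AttentionDownsampleSketch(nn.Module):
    """Sketch of AttentionDownsample's guarded extra projection.

    The attention-weighted sum over each window of `downsample` frames
    yields `in_channels` output channels; extra_proj exists only when
    more output channels than that are needed.
    """
    def __init__(self, in_channels: int, out_channels: int, downsample: int):
        super().__init__()
        self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
        if out_channels > in_channels:
            # fill in the extra dimensions with a projection of the input
            self.extra_proj = nn.Linear(in_channels * downsample,
                                        out_channels - in_channels,
                                        bias=False)
        else:
            self.extra_proj = None
        self.downsample = downsample

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        # src: (seq_len, batch, in_channels); seq_len assumed divisible
        # by self.downsample to keep the sketch short.
        seq_len, batch, in_channels = src.shape
        ds = self.downsample
        d_seq_len = seq_len // ds
        src = src.reshape(d_seq_len, ds, batch, in_channels)
        # Simplified stand-in for the module's real attention scores.
        scores = (src * self.query).sum(dim=-1, keepdim=True)
        weights = scores.softmax(dim=1)
        # ans1 is the first `in_channels` channels of the output
        ans1 = (src * weights).sum(dim=1)
        src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch, ds * in_channels)
        if self.extra_proj is not None:
            ans2 = self.extra_proj(src)
            return torch.cat((ans1, ans2), dim=2)
        return ans1

With in_channels == out_channels == 384, as in the new two-stack baseline below, extra_proj is None and the module returns only the attention-weighted sum, keeping 384 channels.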
@@ -551,7 +554,7 @@ class SimpleCombiner(torch.nn.Module):
                  dim1: int,
                  dim2: int):
         super(SimpleCombiner, self).__init__()
-        assert dim2 > dim1
+        assert dim2 >= dim1
         self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
         self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
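
The relaxed assert matters because the new 384,384 configuration makes the two combined streams the same width (dim1 == dim2), which the old `dim2 > dim1` check rejected. Below is a hedged sketch of a combiner with this constructor; the sigmoid gating and the zero-padding of the narrower input are assumptions for illustration, not the module's actual formula.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleCombinerSketch(nn.Module):
    """Combines two feature streams whose dims may now be equal."""
    def __init__(self, dim1: int, dim2: int):
        super().__init__()
        assert dim2 >= dim1  # was: dim2 > dim1
        self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
        self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)

    def forward(self, src1: torch.Tensor, src2: torch.Tensor) -> torch.Tensor:
        # One scalar gate per frame, derived from both inputs (illustrative).
        w = ((src1 * self.to_weight1).sum(dim=-1, keepdim=True)
             + (src2 * self.to_weight2).sum(dim=-1, keepdim=True)).sigmoid()
        dim1, dim2 = src1.shape[-1], src2.shape[-1]
        if dim2 > dim1:
            # Zero-pad src1 up to dim2; a no-op when dims are equal,
            # which the relaxed assert now permits.
            src1 = F.pad(src1, (0, dim2 - dim1))
        return w * src1 + (1.0 - w) * src2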


@@ -113,7 +113,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-dims",
         type=str,
-        default="320,512,512",
+        default="384,384",
         help="Attention dimension in 2 blocks of conformer encoder layers, comma separated, "
         "and the output dim of the encoder",
     )
@@ -121,7 +121,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--conformer-subsampling-factor",
         type=int,
-        default=4,
+        default=1,
         help="Subsampling factor for 2nd stack of encoder layers.",
     )
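
Taken together, the new defaults describe a two-stack encoder with equal attention dims and no extra subsampling in the second stack. A quick sketch of how these defaults come out of the parser; the split-to-int step is an assumption about how the comma-separated string is consumed downstream.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--encoder-dims", type=str, default="384,384")
parser.add_argument("--conformer-subsampling-factor", type=int, default=1)
args = parser.parse_args([])

# "384,384": one attention dim per encoder stack, comma separated.
encoder_dims = [int(d) for d in args.encoder_dims.split(",")]
assert len(encoder_dims) == 2             # two encoder stacks in this baseline
print(encoder_dims)                       # [384, 384]
print(args.conformer_subsampling_factor)  # 1 -> no subsampling in 2nd stack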