Try to reproduce baseline but with current code with 2 encoder stacks, as a baseline
This commit is contained in:
parent
1005ff35ba
commit
df795912ed
@ -478,9 +478,12 @@ class AttentionDownsample(torch.nn.Module):
|
|||||||
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
|
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
|
||||||
|
|
||||||
# fill in the extra dimensions with a projection of the input
|
# fill in the extra dimensions with a projection of the input
|
||||||
self.extra_proj = nn.Linear(in_channels * downsample,
|
if out_channels > in_channels:
|
||||||
out_channels - in_channels,
|
self.extra_proj = nn.Linear(in_channels * downsample,
|
||||||
bias=False)
|
out_channels - in_channels,
|
||||||
|
bias=False)
|
||||||
|
else:
|
||||||
|
self.extra_proj = None
|
||||||
self.downsample = downsample
|
self.downsample = downsample
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
@ -509,10 +512,10 @@ class AttentionDownsample(torch.nn.Module):
|
|||||||
# ans1 is the first `in_channels` channels of the output
|
# ans1 is the first `in_channels` channels of the output
|
||||||
ans1 = (src * weights).sum(dim=1)
|
ans1 = (src * weights).sum(dim=1)
|
||||||
src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
|
src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
|
||||||
ans2 = self.extra_proj(src)
|
|
||||||
|
|
||||||
|
if self.extra_proj is not None:
|
||||||
ans = torch.cat((ans1, ans2), dim=2)
|
ans2 = self.extra_proj(src)
|
||||||
|
ans = torch.cat((ans1, ans2), dim=2)
|
||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
@ -551,7 +554,7 @@ class SimpleCombiner(torch.nn.Module):
|
|||||||
dim1: int,
|
dim1: int,
|
||||||
dim2: int):
|
dim2: int):
|
||||||
super(SimpleCombiner, self).__init__()
|
super(SimpleCombiner, self).__init__()
|
||||||
assert dim2 > dim1
|
assert dim2 >= dim1
|
||||||
self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
|
self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
|
||||||
self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
|
self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
|
||||||
|
|
||||||
|
|||||||
@ -113,7 +113,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--encoder-dims",
|
"--encoder-dims",
|
||||||
type=str,
|
type=str,
|
||||||
default="320,512,512",
|
default="384,384",
|
||||||
help="Attention dimension in 2, blocks of conformer encoder layers, comma separated, "
|
help="Attention dimension in 2, blocks of conformer encoder layers, comma separated, "
|
||||||
"and the output dim of the encoder",
|
"and the output dim of the encoder",
|
||||||
)
|
)
|
||||||
@ -121,7 +121,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--conformer-subsampling-factor",
|
"--conformer-subsampling-factor",
|
||||||
type=int,
|
type=int,
|
||||||
default=4,
|
default=1,
|
||||||
help="Subsampling factor for 2nd stack of encoder layers.",
|
help="Subsampling factor for 2nd stack of encoder layers.",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user