mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Try to reproduce baseline but with current code with 2 encoder stacks, as a baseline
This commit is contained in:
parent
1005ff35ba
commit
df795912ed
@ -478,9 +478,12 @@ class AttentionDownsample(torch.nn.Module):
|
||||
self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5))
|
||||
|
||||
# fill in the extra dimensions with a projection of the input
|
||||
self.extra_proj = nn.Linear(in_channels * downsample,
|
||||
out_channels - in_channels,
|
||||
bias=False)
|
||||
if out_channels > in_channels:
|
||||
self.extra_proj = nn.Linear(in_channels * downsample,
|
||||
out_channels - in_channels,
|
||||
bias=False)
|
||||
else:
|
||||
self.extra_proj = None
|
||||
self.downsample = downsample
|
||||
|
||||
def forward(self,
|
||||
@ -509,10 +512,10 @@ class AttentionDownsample(torch.nn.Module):
|
||||
# ans1 is the first `in_channels` channels of the output
|
||||
ans1 = (src * weights).sum(dim=1)
|
||||
src = src.permute(0, 2, 1, 3).reshape(d_seq_len, batch_size, ds * in_channels)
|
||||
ans2 = self.extra_proj(src)
|
||||
|
||||
|
||||
ans = torch.cat((ans1, ans2), dim=2)
|
||||
if self.extra_proj is not None:
|
||||
ans2 = self.extra_proj(src)
|
||||
ans = torch.cat((ans1, ans2), dim=2)
|
||||
return ans
|
||||
|
||||
|
||||
@ -551,7 +554,7 @@ class SimpleCombiner(torch.nn.Module):
|
||||
dim1: int,
|
||||
dim2: int):
|
||||
super(SimpleCombiner, self).__init__()
|
||||
assert dim2 > dim1
|
||||
assert dim2 >= dim1
|
||||
self.to_weight1 = nn.Parameter(torch.randn(dim1) * 0.01)
|
||||
self.to_weight2 = nn.Parameter(torch.randn(dim2) * 0.01)
|
||||
|
||||
|
||||
@ -113,7 +113,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--encoder-dims",
|
||||
type=str,
|
||||
default="320,512,512",
|
||||
default="384,384",
|
||||
help="Attention dimension in 2, blocks of conformer encoder layers, comma separated, "
|
||||
"and the output dim of the encoder",
|
||||
)
|
||||
@ -121,7 +121,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--conformer-subsampling-factor",
|
||||
type=int,
|
||||
default=4,
|
||||
default=1,
|
||||
help="Subsampling factor for 2nd stack of encoder layers.",
|
||||
)
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user