Mirror of https://github.com/k2-fsa/icefall.git
add random combiner for training deeper model
commit 9bb0c7988f (parent 8bd700cff2)
@@ -116,7 +116,7 @@ from beam_search import (
     greedy_search_batch,
     modified_beam_search,
 )
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -302,6 +302,8 @@ def get_parser():
         fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
     )
 
+    add_model_arguments(parser)
+
     return parser
 
 
@@ -354,13 +356,6 @@ def decode_one_batch(
     supervisions = batch["supervisions"]
     feature_lens = supervisions["num_frames"].to(device)
 
-    # feature_lens += params.left_context
-    # feature = torch.nn.functional.pad(
-    #     feature,
-    #     pad=(0, 0, 0, params.left_context),
-    #     value=LOG_EPS,
-    # )
-
     encoder_out, encoder_out_lens = model.encoder(
         x=feature, x_lens=feature_lens
     )
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import copy
-from typing import Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from encoder_interface import EncoderInterface
@@ -47,6 +47,9 @@ class RNN(EncoderInterface):
         Dropout rate (default=0.1).
       layer_dropout (float):
         Dropout value for model-level warmup (default=0.075).
+      aux_layer_period (int):
+        Period of auxiliary layers used for random combination during training.
+        If not larger than 0, the random combiner is not used.
     """
 
     def __init__(
@@ -58,6 +61,7 @@ class RNN(EncoderInterface):
         num_encoder_layers: int = 12,
         dropout: float = 0.1,
         layer_dropout: float = 0.075,
+        aux_layer_period: int = 3,
     ) -> None:
         super(RNN, self).__init__()
 
@@ -79,7 +83,19 @@ class RNN(EncoderInterface):
         encoder_layer = RNNEncoderLayer(
             d_model, dim_feedforward, dropout, layer_dropout
         )
-        self.encoder = RNNEncoder(encoder_layer, num_encoder_layers)
+        self.encoder = RNNEncoder(
+            encoder_layer,
+            num_encoder_layers,
+            aux_layers=list(
+                range(
+                    num_encoder_layers // 3,
+                    num_encoder_layers - 1,
+                    aux_layer_period,
+                )
+            )
+            if aux_layer_period > 0
+            else None,
+        )
 
     def forward(
         self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
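The list above determines which layer outputs are tapped for the combiner. A quick standalone sketch (not part of the commit) of that selection, using the defaults wired up later in train.py (num_encoder_layers=20, aux_layer_period=3):

# Sketch only, not part of the diff: reproduce the aux-layer selection done in
# RNN.__init__ above, using the defaults added to train.py in this commit.
num_encoder_layers = 20
aux_layer_period = 3

aux_layers = list(
    range(num_encoder_layers // 3, num_encoder_layers - 1, aux_layer_period)
)
print(aux_layers)  # [6, 9, 12, 15, 18]

# RNNEncoder.__init__ (next hunk) appends the final layer index, so the
# RandomCombine module sees len(aux_layers) + 1 = 6 inputs.
print(aux_layers + [num_encoder_layers - 1])  # [6, 9, 12, 15, 18, 19]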
@@ -306,13 +322,31 @@ class RNNEncoder(nn.Module):
         The number of sub-encoder-layers in the encoder (required).
     """
 
-    def __init__(self, encoder_layer: nn.Module, num_layers: int) -> None:
+    def __init__(
+        self,
+        encoder_layer: nn.Module,
+        num_layers: int,
+        aux_layers: Optional[List[int]] = None,
+    ) -> None:
         super(RNNEncoder, self).__init__()
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
         self.num_layers = num_layers
 
+        self.use_random_combiner = False
+        if aux_layers is not None:
+            assert len(set(aux_layers)) == len(aux_layers)
+            assert num_layers - 1 not in aux_layers
+            self.use_random_combiner = True
+            self.aux_layers = aux_layers + [num_layers - 1]
+            self.combiner = RandomCombine(
+                num_inputs=len(self.aux_layers),
+                final_weight=0.5,
+                pure_prob=0.333,
+                stddev=2.0,
+            )
+
     def forward(self, src: torch.Tensor, warmup: float = 1.0) -> torch.Tensor:
         """
         Pass the input through the encoder layer in turn.
@@ -328,8 +362,16 @@ class RNNEncoder(nn.Module):
         """
         output = src
 
-        for layer_index, mod in enumerate(self.layers):
+        outputs = []
+
+        for i, mod in enumerate(self.layers):
             output = mod(output, warmup=warmup)
+            if self.use_random_combiner:
+                if i in self.aux_layers:
+                    outputs.append(output)
 
+        if self.use_random_combiner:
+            output = self.combiner(outputs)
+
         return output
 
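Because the final layer index is always appended to `self.aux_layers`, the last element of `outputs` is the ordinary top-layer output, and `RandomCombine` (added in the next hunk) returns exactly that element when it is not in training mode, so decoding behaviour is unchanged. A minimal sketch of that check, assuming the `lstm.py` modified in this diff is importable:

# Sketch only, not part of the diff. Assumes lstm.py as modified in this
# commit is on the path so RandomCombine can be imported.
import torch

from lstm import RandomCombine

combiner = RandomCombine(
    num_inputs=6, final_weight=0.5, pure_prob=0.333, stddev=2.0
)
combiner.eval()  # not training -> the combiner just forwards its last input

xs = [torch.randn(20, 4, 512) for _ in range(6)]  # (seq_len, batch, d_model)
assert torch.equal(combiner(xs), xs[-1])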
@@ -459,6 +501,244 @@ class Conv2dSubsampling(nn.Module):
         return x
 
 
+class RandomCombine(nn.Module):
+    """
+    This module combines a list of Tensors, all with the same shape, to
+    produce a single output of that same shape which, in training time,
+    is a random combination of all the inputs; but which in test time
+    will be just the last input.
+
+    The idea is that the list of Tensors will be a list of outputs of multiple
+    conformer layers.  This has a similar effect as iterated loss.  (See:
+    DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER
+    NETWORKS).
+    """
+
+    def __init__(
+        self,
+        num_inputs: int,
+        final_weight: float = 0.5,
+        pure_prob: float = 0.5,
+        stddev: float = 2.0,
+    ) -> None:
+        """
+        Args:
+          num_inputs:
+            The number of tensor inputs, which equals the number of layers'
+            outputs that are fed into this module.  E.g. in an 18-layer neural
+            net if we output layers 16, 12, 18, num_inputs would be 3.
+          final_weight:
+            The amount of weight or probability we assign to the
+            final layer when randomly choosing layers or when choosing
+            continuous layer weights.
+          pure_prob:
+            The probability, on each frame, with which we choose
+            only a single layer to output (rather than an interpolation).
+          stddev:
+            A standard deviation that we add to log-probs for computing
+            randomized weights.
+
+        The method of choosing which layers, or combinations of layers, to use,
+        is conceptually as follows::
+
+            With probability `pure_prob`::
+                With probability `final_weight`: choose final layer,
+                Else: choose random non-final layer.
+            Else::
+                Choose initial log-weights that correspond to assigning
+                weight `final_weight` to the final layer and equal
+                weights to other layers; then add Gaussian noise
+                with variance `stddev` to these log-weights, and normalize
+                to weights (note: the average weight assigned to the
+                final layer here will not be `final_weight` if stddev>0).
+        """
+        super().__init__()
+        assert 0 <= pure_prob <= 1, pure_prob
+        assert 0 < final_weight < 1, final_weight
+        assert num_inputs >= 1
+
+        self.num_inputs = num_inputs
+        self.final_weight = final_weight
+        self.pure_prob = pure_prob
+        self.stddev = stddev
+
+        self.final_log_weight = (
+            torch.tensor(
+                (final_weight / (1 - final_weight)) * (self.num_inputs - 1)
+            )
+            .log()
+            .item()
+        )
+
+    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
+        """Forward function.
+        Args:
+          inputs:
+            A list of Tensor, e.g. from various layers of a transformer.
+            All must be the same shape, of (*, num_channels).
+        Returns:
+          A Tensor of shape (*, num_channels).  In test mode
+          this is just the final input.
+        """
+        num_inputs = self.num_inputs
+        assert len(inputs) == num_inputs
+        if not self.training or torch.jit.is_scripting():
+            return inputs[-1]
+
+        # Shape of weights: (*, num_inputs)
+        num_channels = inputs[0].shape[-1]
+        num_frames = inputs[0].numel() // num_channels
+
+        ndim = inputs[0].ndim
+        # stacked_inputs: (num_frames, num_channels, num_inputs)
+        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
+            (num_frames, num_channels, num_inputs)
+        )
+
+        # weights: (num_frames, num_inputs)
+        weights = self._get_random_weights(
+            inputs[0].dtype, inputs[0].device, num_frames
+        )
+
+        weights = weights.reshape(num_frames, num_inputs, 1)
+        # ans: (num_frames, num_channels, 1)
+        ans = torch.matmul(stacked_inputs, weights)
+        # ans: (*, num_channels)
+
+        ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,))
+
+        # The following if causes errors for torch script in torch 1.6.0
+        # if __name__ == "__main__":
+        #     # for testing only...
+        #     print("Weights = ", weights.reshape(num_frames, num_inputs))
+        return ans
+
+    def _get_random_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ) -> torch.Tensor:
+        """Return a tensor of random weights, of shape
+        `(num_frames, self.num_inputs)`.
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired.
+        Returns:
+          A tensor of shape (num_frames, self.num_inputs), such that
+          `ans.sum(dim=1)` is all ones.
+        """
+        pure_prob = self.pure_prob
+        if pure_prob == 0.0:
+            return self._get_random_mixed_weights(dtype, device, num_frames)
+        elif pure_prob == 1.0:
+            return self._get_random_pure_weights(dtype, device, num_frames)
+        else:
+            p = self._get_random_pure_weights(dtype, device, num_frames)
+            m = self._get_random_mixed_weights(dtype, device, num_frames)
+            return torch.where(
+                torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m
+            )
+
+    def _get_random_pure_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ):
+        """Return a tensor of random one-hot weights, of shape
+        `(num_frames, self.num_inputs)`.
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired.
+        Returns:
+          A one-hot tensor of shape `(num_frames, self.num_inputs)`, with
+          exactly one weight equal to 1.0 on each frame.
+        """
+        final_prob = self.final_weight
+
+        # final contains self.num_inputs - 1 in all elements
+        final = torch.full((num_frames,), self.num_inputs - 1, device=device)
+        # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights.  # noqa
+        nonfinal = torch.randint(
+            self.num_inputs - 1, (num_frames,), device=device
+        )
+
+        indexes = torch.where(
+            torch.rand(num_frames, device=device) < final_prob, final, nonfinal
+        )
+        ans = torch.nn.functional.one_hot(
+            indexes, num_classes=self.num_inputs
+        ).to(dtype=dtype)
+        return ans
+
+    def _get_random_mixed_weights(
+        self, dtype: torch.dtype, device: torch.device, num_frames: int
+    ):
+        """Return a tensor of random weights, of shape
+        `(num_frames, self.num_inputs)`.
+        Args:
+          dtype:
+            The data-type desired for the answer, e.g. float, double.
+          device:
+            The device needed for the answer.
+          num_frames:
+            The number of sets of weights desired.
+        Returns:
+          A tensor of shape (num_frames, self.num_inputs), whose elements
+          are in [0..1] and sum to one over the second axis, i.e.
+          `ans.sum(dim=1)` is all ones.
+        """
+        logprobs = (
+            torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device)
+            * self.stddev
+        )
+        logprobs[:, -1] += self.final_log_weight
+        return logprobs.softmax(dim=1)
+
+
+def _test_random_combine(final_weight: float, pure_prob: float, stddev: float):
+    print(
+        f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}"  # noqa
+    )
+    num_inputs = 3
+    num_channels = 50
+    m = RandomCombine(
+        num_inputs=num_inputs,
+        final_weight=final_weight,
+        pure_prob=pure_prob,
+        stddev=stddev,
+    )
+
+    x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)]
+
+    y = m(x)
+    assert y.shape == x[0].shape
+    assert torch.allclose(y, x[0])  # .. since actually all ones.
+
+
+def _test_random_combine_main():
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.0)
+    _test_random_combine(0.999, 0, 0.0)
+    _test_random_combine(0.5, 0, 0.3)
+    _test_random_combine(0.5, 1, 0.3)
+    _test_random_combine(0.5, 0.5, 0.3)
+
+    feature_dim = 50
+    c = RNN(num_features=feature_dim, d_model=128)
+    batch_size = 5
+    seq_len = 20
+    # Just make sure the forward pass runs.
+    f = c(
+        torch.randn(batch_size, seq_len, feature_dim),
+        torch.full((batch_size,), seq_len, dtype=torch.int64),
+    )
+    f  # to remove flake8 warnings
+
+
 if __name__ == "__main__":
     feature_dim = 50
     m = RNN(num_features=feature_dim, d_model=128)
@@ -470,3 +750,5 @@ if __name__ == "__main__":
         torch.full((batch_size,), seq_len, dtype=torch.int64),
         warmup=0.5,
     )
+
+    _test_random_combine_main()
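The `final_log_weight` computed in `__init__` is what biases the continuous weights toward the final layer. A small arithmetic check (not part of the commit) for the values RNNEncoder passes in (`num_inputs=6`, `final_weight=0.5`): with the Gaussian noise removed, the softmax gives the final layer exactly weight 0.5 and each auxiliary layer 0.1.

# Sketch only, not part of the diff: the mixed-weight arithmetic with the
# noise term removed (stddev = 0), for num_inputs=6 and final_weight=0.5.
import math

import torch

num_inputs = 6
final_weight = 0.5

# log((0.5 / 0.5) * 5) = log(5) ~= 1.609
final_log_weight = math.log(
    (final_weight / (1 - final_weight)) * (num_inputs - 1)
)

logprobs = torch.zeros(1, num_inputs)  # no Gaussian noise added
logprobs[:, -1] += final_log_weight
weights = logprobs.softmax(dim=1)

print(weights)             # ~[[0.1, 0.1, 0.1, 0.1, 0.1, 0.5]]
print(weights.sum(dim=1))  # [1.]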
@@ -42,7 +42,7 @@ from decode_stream import DecodeStream
 from kaldifeat import Fbank, FbankOptions
 from lhotse import CutSet
 from torch.nn.utils.rnn import pad_sequence
-from train import get_params, get_transducer_model
+from train import add_model_arguments, get_params, get_transducer_model
 
 from icefall.checkpoint import (
     average_checkpoints,
@@ -177,6 +177,8 @@ def get_parser():
         help="The number of streams that can be decoded in parallel.",
     )
 
+    add_model_arguments(parser)
+
     return parser
 
 
@@ -434,9 +436,7 @@ def decode_dataset(
     decode_results = []
     # Contain decode streams currently running.
     decode_streams = []
-    initial_states = model.encoder.get_init_state(
-        params.left_context, device=device
-    )
+    initial_states = model.encoder.get_init_states(device=device)
     for num, cut in enumerate(cuts):
         # each utterance has a DecodeStream.
         decode_stream = DecodeStream(
@@ -57,12 +57,12 @@ import torch
 import torch.multiprocessing as mp
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
-from lstm import RNN
 from decoder import Decoder
 from joiner import Joiner
 from lhotse.cut import Cut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from lstm import RNN
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -86,6 +86,24 @@ LRSchedulerType = Union[
 ]
 
 
+def add_model_arguments(parser: argparse.ArgumentParser):
+    parser.add_argument(
+        "--num-encoder-layers",
+        type=int,
+        default=20,
+        help="Number of RNN encoder layers.",
+    )
+
+    parser.add_argument(
+        "--aux-layer-period",
+        type=int,
+        default=3,
+        help="""Period of auxiliary layers used for random combination during training.
+        If not larger than 0, the random combiner is not used.
+        """,
+    )
+
+
 def get_parser():
     parser = argparse.ArgumentParser(
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
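Since `add_model_arguments` is also wired into `get_parser` in the decoding scripts (earlier hunks), all of these scripts accept the same two flags. A standalone sketch (not part of the commit) of how they parse, including the value that disables the combiner:

# Sketch only, not part of the diff: the two options added by
# add_model_arguments, parsed on their own.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num-encoder-layers", type=int, default=20)
parser.add_argument("--aux-layer-period", type=int, default=3)

args = parser.parse_args(["--num-encoder-layers", "30", "--aux-layer-period", "0"])
print(args.num_encoder_layers)  # 30
print(args.aux_layer_period)    # 0 -> RNN passes aux_layers=None, combiner off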
@@ -279,6 +297,8 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    add_model_arguments(parser)
+
     return parser
 
 
@@ -341,7 +361,6 @@ def get_params() -> AttributeDict:
             "subsampling_factor": 4,
             "encoder_dim": 512,
             "dim_feedforward": 2048,
-            "num_encoder_layers": 12,
             # parameters for decoder
             "decoder_dim": 512,
             # parameters for joiner
@@ -363,6 +382,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         d_model=params.encoder_dim,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
+        aux_layer_period=params.aux_layer_period,
     )
     return encoder
 
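To tie the pieces together: `get_encoder_model` now forwards `params.aux_layer_period` into the RNN encoder, so the combiner is active by default and switched off when the flag is not larger than 0. A rough construction sketch (not part of the commit; `num_features=80` is an assumed feature dimension, and `lstm.RNN` is the class modified earlier in this diff):

# Sketch only, not part of the diff. Assumes the modified lstm.py is importable
# and that 80-dim features are used; adjust to the recipe's actual params.
from lstm import RNN

combined = RNN(
    num_features=80,
    d_model=512,
    dim_feedforward=2048,
    num_encoder_layers=20,
    aux_layer_period=3,  # combiner over layer outputs [6, 9, 12, 15, 18, 19]
)
assert combined.encoder.use_random_combiner

plain = RNN(
    num_features=80,
    d_model=512,
    dim_feedforward=2048,
    num_encoder_layers=20,
    aux_layer_period=0,  # aux_layers=None -> random combiner disabled
)
assert not plain.encoder.use_random_combiner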