Mirror of https://github.com/k2-fsa/icefall.git

Commit 8483ca2e8f ("More partial work"), parent e51a2c9170
@@ -57,6 +57,10 @@ class Subformer(EncoderInterface):
             the whole stack to downsample.)
         encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
             encoder stack (i.e. one per "S" in structure).
+        encoder_chunk_sizes (Tuple[Tuple[int]]): A tuple containing either one tuple or
+            one tuple per encoder stack.  Each element is a tuple of the chunk sizes
+            that we use during training, e.g. (128, 1024); we go through these round-robin
+            in successive layers.
         downsampling_factor (Tuple[int]): downsampling factor for each downsampling
             operation (each open-parenthesis).
         num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
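
Note: the round-robin chunk-size rule documented above is simple to illustrate. A minimal sketch (not icefall code), assuming layer i takes chunk_sizes[i % len(chunk_sizes)]:

    chunk_sizes = (128, 1024)
    num_layers = 5
    per_layer = [chunk_sizes[i % len(chunk_sizes)] for i in range(num_layers)]
    print(per_layer)  # [128, 1024, 128, 1024, 128]
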
@@ -85,7 +89,7 @@ class Subformer(EncoderInterface):
         structure: str = "S(S)S",
         encoder_dim: Tuple[int, ...] = (384, 512, 384),
         downsampling_factor: Tuple[int, ...] = (2,),
-        encoder_chunk_sizes: Tuple[Tuple[int, ...]] = (128,),
+        encoder_chunk_sizes: Tuple[Tuple[int, ...]] = ((128, 1024),),
         num_encoder_layers: Union[int, Tuple[int, ...]] = (4,),
         query_head_dim: Tuple[int, ...] = (24,),
         value_head_dim: Tuple[int, ...] = (12,),
@@ -120,7 +124,7 @@ class Subformer(EncoderInterface):
             return x

         self.encoder_dim = encoder_dim
-        encoder_chunk_size = _to_tuple(encoder_chunk_size)
+        encoder_chunk_sizes = _to_tuple(encoder_chunk_sizes)
         num_encoder_layers = _to_tuple(num_encoder_layers)
         query_head_dim = _to_tuple(query_head_dim)
         value_head_dim = _to_tuple(value_head_dim)
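
Note: only the tail (`return x`) of the local _to_tuple helper is visible in this hunk, so the following is a hedged sketch of the broadcast behavior it presumably implements, renamed to make the assumption explicit:

    def _to_tuple_sketch(x, num_encoders=3):
        # Hypothetical stand-in for the file's local _to_tuple helper:
        # broadcast an int or a length-1 tuple to one value per encoder stack.
        if isinstance(x, int):
            x = (x,)
        if len(x) == 1:
            x = x * num_encoders
        assert len(x) == num_encoders
        return tuple(x)

    assert _to_tuple_sketch(4) == (4, 4, 4)
    assert _to_tuple_sketch((24,)) == (24, 24, 24)
    assert _to_tuple_sketch(((128, 1024),)) == ((128, 1024),) * 3
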
@@ -136,7 +140,17 @@ class Subformer(EncoderInterface):
         # each one will be SubformerEncoder or DownsampledSubformerEncoder
         encoders = []
         downsamplers = []
-        bypass = []
+        bypasses = []

+        layer_indexes = []
+
+        cur_max_dim = encoder_dim[0]
+
+        downsampling_factors_list = []
+        def cur_downsampling_factor():
+            c = 1
+            for d in downsampling_factors_list: c *= d
+            return c
+
         for s in structure:
             if s == 'S':
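
Note: cur_downsampling_factor() returns the product of the factors currently on downsampling_factors_list, i.e. the total downsampling in effect at the current nesting depth of the structure string. A standalone trace (not icefall code; factors chosen arbitrarily):

    structure = "S(S(S)S)S"
    downsampling_factor = (2, 4)
    factors, i = [], 0
    for s in structure:
        if s == '(':
            factors.append(downsampling_factor[i]); i += 1
        elif s == ')':
            factors.pop()
        else:  # 'S': total factor seen by this encoder stack
            total = 1
            for d in factors:
                total *= d
            print(s, total)   # prints totals 1, 2, 8, 2, 1 in order
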
@@ -152,61 +166,45 @@ class Subformer(EncoderInterface):
                     dropout=dropout,
                     causal=causal,
                 )
+                cur_max_dim = max(cur_max_dim, encoder_dim[i])
                 encoder = SubformerEncoder(
                     encoder_layer,
                     num_encoder_layers[i],
+                    embed_dim=cur_max_dim,
                     dropout=dropout,
-                    chunk_size=encoder_chunk_size[i],
+                    chunk_sizes=encoder_chunk_sizes[i],
                     warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                     warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
-                    final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
+                    final_layerdrop_rate=0.035 * (cur_downsampling_factor() ** 0.5),
                 )
+                layer_indexes.append(len(encoders))
                 encoders.append(encoder)

-                pass
             elif s == '(':
-                pass
+                i = len(downsamplers)
+                downsampler = LearnedDownsamplingModule(cur_max_dim,
+                                                        downsampling_factor[i])
+                downsampling_factors_list.append(downsampling_factor[i])
+                layer_indexes.append(len(downsamplers))
+                downsamplers.append(downsampler)
             else:
                 assert s == ')'
+                bypass = BypassModule(cur_max_dim, straight_through_rate=0.0)
+                layer_indexes.append(len(bypasses))
+                bypasses.append(bypass)
+                downsampling_factors_list.pop()
+
+        logging.info(f"cur_downsampling_factor={cur_downsampling_factor()}")

-        num_encoders = len(encoder_dim)
-        assert num_encoders % 2 == 1
-        downsampling_factor = [ 1 ]
-        while len(downsampling_factor) < num_encoders:
-            downsampling_factor = [ 1 ] + [ d * 2 for d in downsampling_factor ] + [ 1 ]
-
-        for i in range(num_encoders):
-
-        mid = len(encoders) // 2
-        encoder = DownsampledSubformerEncoder(
-            [ encoders[mid] ],
-            input_num_channels=encoder_dim[mid-1],
-            downsample=2
-        )
-        for i in range(1, mid+1):
-            this_list = [ encoders[mid-i],
-                          encoder,
-                          encoders[mid+i] ]
-            encoder = DownsampledSubformerEncoder(
-                this_list,
-                input_num_channels=encoder_dim[max(0, mid-i-1)],
-                downsample=2 if i != mid else 1
-            )
-
-        self.encoder = encoder
+        self.layer_indexes = layer_indexes
+        self.structure = structure
+        self.encoders = nn.ModuleList(encoders)
+        self.downsamplers = nn.ModuleList(downsamplers)
+        self.bypasses = nn.ModuleList(bypasses)

         self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim,
                                                         dropout_rate=0.15,
                                                         length_factor=1.0)

-        #self.downsample_output = SimpleDownsample(max(encoder_dim),
-        #                                          downsample=output_downsampling_factor,
-        #                                          dropout=dropout)

     def forward(
         self,
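
Note: layer_indexes records, for each character of `structure`, the index of the module that will consume it in forward(), while cur_max_dim sizes each stack's bypass (and each LearnedDownsamplingModule/BypassModule) to the widest dimension reached so far. A toy replay of just the index bookkeeping:

    structure = "S(S)S"
    n_enc = n_down = n_byp = 0
    layer_indexes = []
    for s in structure:
        if s == 'S':
            layer_indexes.append(n_enc); n_enc += 1
        elif s == '(':
            layer_indexes.append(n_down); n_down += 1
        else:  # ')'
            layer_indexes.append(n_byp); n_byp += 1
    print(list(zip(structure, layer_indexes)))
    # [('S', 0), ('(', 0), ('S', 1), (')', 0), ('S', 2)]
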
@@ -239,7 +237,6 @@ class Subformer(EncoderInterface):
         """
         outputs = []

-        attn_offset = self._get_attn_offset(x, src_key_padding_mask)

         if self.training and memory is not None:
             batch_size = x.shape[1]
@@ -249,14 +246,42 @@ class Subformer(EncoderInterface):
             memory = memory * (torch.rand(batch_size, 1, device=memory.device) >
                                memory_dropout_rate)

-        pos_emb = self.encoder_pos(x)
-
-        x = self.encoder(x,
-                         pos_emb,
-                         attn_offset=attn_offset,
-                         memory=memory,
-                         memory_key_padding_mask=memory_key_padding_mask,
-                         )
+        attn_offsets = [ self._get_attn_offset(x, src_key_padding_mask) ]
+        pos_embs = [ self.encoder_pos(x) ]
+        downsample_info = []
+
+        for s, i in zip(self.structure, self.layer_indexes):
+            if s == 'S':
+                encoder = self.encoders[i]  # one encoder stack
+                x = encoder(x,
+                            pos_embs[-1],
+                            attn_offset=attn_offsets[-1],
+                            memory=memory,
+                            memory_key_padding_mask=memory_key_padding_mask)
+                # x will have the maximum dimension up till now, even if
+                # `encoder` uses lower dim in its layers.
+            elif s == '(':
+                downsampler = self.downsamplers[i]
+
+                indexes, weights, x_new = downsampler(x)
+                downsample_info.append((indexes, weights, x))
+                x = x_new
+
+                pos_embs.append(downsampler.downsample_pos_emb(pos_embs[-1], indexes))
+                attn_offsets.append(downsampler.downsample_attn_offset(attn_offsets[-1],
+                                                                       indexes,
+                                                                       weights))
+            else:
+                assert s == ')'  # upsample
+                indexes, weights, x_orig = downsample_info.pop()
+                _attn_offset = attn_offsets.pop()
+                _pos_emb = pos_embs.pop()
+                x_orig = convert_num_channels(x_orig, x.shape[-1])
+
+                x = LearnedDownsamplingModule.upsample(x_orig, x, indexes, weights)

         # d = self.output_downsampling_factor
         # lengths = (x_lens + d - 1) // d
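
Note: the new forward() treats the structure string as a bracket language, pushing per-resolution state (pos_embs, attn_offsets, downsample_info) at '(' and popping it at ')'. A toy trace in sequence lengths only (factor-2 downsampling is a hypothetical choice here) shows why the output length always matches the input:

    structure = "S(S)S"
    saved_lens = []       # parallels downsample_info
    cur_len = 100
    for s in structure:
        if s == 'S':
            pass                         # encoder stack: length unchanged
        elif s == '(':
            saved_lens.append(cur_len)
            cur_len = (cur_len + 1) // 2  # keep roughly half the frames
        else:                            # ')': upsample back to saved length
            cur_len = saved_lens.pop()
    print(cur_len)  # 100
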
@@ -575,6 +600,9 @@ class SubformerEncoder(nn.Module):
     Args:
         encoder_layer: an instance of the SubformerEncoderLayer() class (required).
         num_layers: the number of sub-encoder-layers in the encoder (required).
+        embed_dim: the embedding dimension to use for the bypass (may exceed the
+            dimension of encoder_layer, as it may not operate on the full
+            dimension).

     Examples::
         >>> encoder_layer = SubformerEncoderLayer(embed_dim=512, nhead=8)
@@ -586,6 +614,7 @@ class SubformerEncoder(nn.Module):
         self,
         encoder_layer: nn.Module,
         num_layers: int,
+        embed_dim: int,
         dropout: float,
         warmup_begin: float,
         warmup_end: float,
@@ -602,7 +631,7 @@ class SubformerEncoder(nn.Module):
         )
         self.num_layers = num_layers

-        self.bypass = BypassModule(self.embed_dim())
+        self.bypass = BypassModule(embed_dim)

         assert 0 <= warmup_begin <= warmup_end

@@ -616,7 +645,7 @@ class SubformerEncoder(nn.Module):
             cur_begin = cur_end

     def embed_dim(self):
-        return self.layers[0].embed_dim
+        return self.bypass.embed_dim()

     def forward(
         self,
@@ -644,14 +673,7 @@ class SubformerEncoder(nn.Module):

         Returns: a Tensor with the same shape as src.
         """
-        src = convert_num_channels(src, self.embed_dim())
-        output = src
-
-        rnd_seed = src.numel() + random.randint(0, 1000)
-
-        #if feature_mask is not None:
-        #    output = output * feature_mask
+        output = convert_num_channels(src, self.layers[0].embed_dim)

         chunk_sizes, chunk_indexes = self._get_chunk_sizes(src)
         b = src.shape[1]  # batch_size
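
Note: convert_num_channels (defined in scaling.py, not shown in this diff) is what reconciles the differing widths here and in the next hunk; the sketch below assumes the usual icefall behavior of truncating or zero-padding the channel dimension:

    import torch

    def convert_num_channels_sketch(x: torch.Tensor, num_channels: int) -> torch.Tensor:
        # Assumed behavior of the helper: truncate or zero-pad the last dim.
        if num_channels <= x.shape[-1]:
            return x[..., :num_channels]
        pad = torch.zeros(*x.shape[:-1], num_channels - x.shape[-1],
                          dtype=x.dtype, device=x.device)
        return torch.cat((x, pad), dim=-1)

    x = torch.randn(10, 2, 384)
    assert convert_num_channels_sketch(x, 512).shape == (10, 2, 512)
    assert convert_num_channels_sketch(x, 256).shape == (10, 2, 256)
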
@@ -678,6 +700,9 @@ class SubformerEncoder(nn.Module):

         output = self._to_chunk_size(output, src.shape[0])

+        output = convert_num_channels(output, self.bypass.embed_dim())
+        src = convert_num_channels(src, self.bypass.embed_dim())
+
         return self.bypass(src, output)

     def _get_chunk_sizes(self, src: Tensor) -> Tuple[List[int], List[int]]:
@@ -784,6 +809,8 @@ class BypassModule(nn.Module):
         self.scale_min = copy.deepcopy(scale_min)
         self.scale_max = copy.deepcopy(scale_max)

+    def embed_dim(self):
+        return self.bypass_scale.numel()

     def _get_bypass_scale(self, batch_size: int):
         # returns bypass-scale of shape (num_channels,),
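
Note: the new embed_dim() infers the module's width from its learned per-channel bypass_scale parameter, which is why SubformerEncoder.embed_dim() can now delegate to it (see the earlier hunk). An illustrative stand-in, assuming the usual bypass form out = src + scale * (output - src); the clamping and straight-through logic of the real BypassModule is not shown in this diff:

    import torch
    import torch.nn as nn

    class TinyBypass(nn.Module):
        # Not the icefall class: a minimal learned per-channel interpolation
        # between a module's input and its output.
        def __init__(self, embed_dim: int):
            super().__init__()
            self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))

        def embed_dim(self) -> int:
            # same trick as the diff: the parameter's size *is* the dimension
            return self.bypass_scale.numel()

        def forward(self, src: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
            return src + (output - src) * self.bypass_scale

    m = TinyBypass(8)
    assert m.embed_dim() == 8
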
@@ -840,7 +867,7 @@ class LearnedDownsamplingModule(nn.Module):
         super().__init__()

         self.to_scores = nn.Linear(embed_dim, 1, bias=False)
-        self.to_scores.lr_factor = 0.5
+        self.to_scores.lr_scale = 0.5
         # score_balancer is just to keep the magnitudes of the scores in
         # a fixed range and keep them balanced around zero, to stop
         # these drifting around.
@@ -1028,7 +1055,9 @@ class LearnedDownsamplingModule(nn.Module):
         return attn_offset


-    def upsample(self, x_orig: Tensor, x: Tensor, indexes: Tensor) -> Tensor:
+    @staticmethod
+    def upsample(x_orig: Tensor, x: Tensor, indexes: Tensor,
+                 weights: Optional[Tensor] = None) -> Tensor:
         """
         Upsamples, reversing the downsample() operation and filling in
         any not-chosen frames with their original value before downsampling
@@ -1038,30 +1067,40 @@ class LearnedDownsamplingModule(nn.Module):
           x_orig: (seq_len, batch_size, num_channels)
           x: (seq_len_reduced, batch_size, num_channels)
           indexes: (batch_size, seq_len_reduced), contains original frame indexes
+          weights: optional tensor

         Downsamples x via indexing with the indexes obtained from the
         forward() function.

         Args:
           x: tensor of shape (seq_len, batch_size, num_channels)
+          weights: a tensor of shape (batch_size, seq_len_reduced) containing weights between
+              0 and 1, where 1 means fully use this x value and 0 means use x_orig
           indexes: integer indexes of shape (batch_size, seq_len_reduced), with elements
               0 <= indexes < seq_len.
         """
         (seq_len, batch_size, num_channels) = x_orig.shape

-        not_kept = torch.ones(batch_size, seq_len, dtype=torch.bool,
-                              device=x.device)
-        not_kept.scatter_(dim=1, index=indexes, value=False)
+        x_weight = 1.0 if weights is None else weights.t().unsqueeze(-1)
+        # x_weight: (seq_len_reduced, batch_size, 1) if a tensor
+        orig_x_weight = torch.ones(batch_size, seq_len,
+                                   device=x.device, dtype=x.dtype)
+        if weights is None:
+            orig_x_weight.scatter_(dim=1, index=indexes, value=0.)
+        else:
+            orig_x_weight.scatter_(dim=1, index=indexes,
+                                   src=(1. - weights).to(x.dtype))

         indexes = indexes.t().unsqueeze(-1).expand(-1, batch_size, num_channels)
         # indexes now: (seq_len_reduced, batch_size, num_channels)

         ans = torch.zeros_like(x_orig)

-        ans.scatter_(dim=0, index=indexes, src=x)
+        ans.scatter_(dim=0, index=indexes, src=(x * x_weight))

         # add in x_orig in the frames that were not originally kept.
-        return ans + x_orig * not_kept.t().unsqueeze(-1)
+        return ans + x_orig * orig_x_weight.t().unsqueeze(-1)


 class DownsampledSubformerEncoder(nn.Module):
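
Note: the weighted upsample can be sanity-checked in isolation. The snippet below (hypothetical shapes, plain PyTorch) re-implements exactly what the new code computes: a kept frame becomes weights * x + (1 - weights) * x_orig, and a frame that was never chosen stays exactly x_orig:

    import torch

    seq_len, seq_red, batch, ch = 6, 3, 2, 4
    x_orig = torch.randn(seq_len, batch, ch)
    x = torch.randn(seq_red, batch, ch)
    indexes = torch.stack([torch.tensor([0, 2, 5]),
                           torch.tensor([1, 3, 4])])      # (batch, seq_red)
    weights = torch.rand(batch, seq_red)

    x_weight = weights.t().unsqueeze(-1)                  # (seq_red, batch, 1)
    orig_x_weight = torch.ones(batch, seq_len)
    orig_x_weight.scatter_(dim=1, index=indexes, src=(1. - weights))

    idx = indexes.t().unsqueeze(-1).expand(-1, batch, ch)  # (seq_red, batch, ch)
    ans = torch.zeros_like(x_orig)
    ans.scatter_(dim=0, index=idx, src=x * x_weight)
    out = ans + x_orig * orig_x_weight.t().unsqueeze(-1)

    # kept frame is a convex combination; e.g. for batch 0, original frame 2
    # came from reduced frame 1:
    expected = weights[0, 1] * x[1, 0] + (1 - weights[0, 1]) * x_orig[2, 0]
    assert torch.allclose(out[2, 0], expected)
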
@@ -1151,7 +1190,7 @@ class DownsampledSubformerEncoder(nn.Module):
         src_orig = convert_num_channels(src_orig, src.shape[-1])

         if hasattr(self, 'downsampler'):
-            src = self.downsampler.upsample(src_orig, src, indexes)
+            src = self.downsampler.upsample(src_orig, src, indexes, weights)

         return self.out_combiner(src_orig, src)
@@ -118,6 +118,7 @@ def set_batch_count(


 def add_model_arguments(parser: argparse.ArgumentParser):
+
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
@@ -147,13 +148,21 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     )

     parser.add_argument(
-        "--encoder-chunk-size",
+        "--encoder-chunk-sizes",
         type=str,
-        default="128",
+        default="128,1024",
         help="Base chunk size for attention in encoder stacks; alternate layers will use this value or "
         "double this value."
     )

+    parser.add_argument(
+        "--encoder-structure",
+        type=str,
+        default="S(S(S(S)S)S)S",
+        help="Structure of encoder, determines order of encoder stacks and (downsampling/upsampling) "
+        "operations."
+    )
+
     parser.add_argument(
         "--query-head-dim",
         type=str,
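
Note: these string-valued options are parsed by the train script's _to_int_tuple helper (used in the next hunk); its behavior is equivalent to this sketch:

    def _to_int_tuple(s: str):
        # comma-separated string -> tuple of ints, e.g. "128,1024" -> (128, 1024)
        return tuple(map(int, s.split(',')))

    assert _to_int_tuple("128,1024") == (128, 1024)
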
@@ -421,9 +430,10 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     #chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
     encoder = Subformer(
+        structure=params.encoder_structure,
         num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
         encoder_dim=_to_int_tuple(params.encoder_dim),
-        encoder_chunk_size=_to_int_tuple(params.encoder_chunk_size),
+        encoder_chunk_sizes=(_to_int_tuple(params.encoder_chunk_sizes),),
         query_head_dim=_to_int_tuple(params.query_head_dim),
         pos_dim=int(params.pos_dim),
         value_head_dim=_to_int_tuple(params.value_head_dim),
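
Note: a quick sanity check of the wiring above, using the new defaults. encoder_chunk_sizes is wrapped in one more tuple, so a single chunk-size cycle is shared by every encoder stack, and the default structure string is balanced:

    structure = "S(S(S(S)S)S)S"           # --encoder-structure default
    encoder_chunk_sizes = ((128, 1024),)  # (_to_int_tuple("128,1024"),)
    assert structure.count('(') == structure.count(')')
    assert structure.count('S') == 7      # seven encoder stacks
    assert len(encoder_chunk_sizes) in (1, structure.count('S'))
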