mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)

commit 0a76215fd7 ("Code cleanup"), parent 671e9ee5bd
@@ -58,9 +58,6 @@ class Subformer(EncoderInterface):
         encoder_dim (Tuple[int]): embedding dimension of each of the encoder stacks, one per
            encoder stack.
         num_encoder_layers (int or Tuple[int]): number of encoder layers for each stack
-        encoder_unmasked_dim (int or Tuple[int]): unmasked dimension in each of
-           the encoder stacks for purposes of per-frame dropout (recommend 256 for
-           now).
         query_head_dim (int or Tuple[int]): dimension of query and key per attention
            head: per stack, if a tuple.
         value_head_dim (int or Tuple[int]): dimension of value in each attention head
@@ -83,11 +80,8 @@ class Subformer(EncoderInterface):
     """
     def __init__(
         self,
-        output_downsampling_factor: int = 2,
-        downsampling_factor: Tuple[int] = (2, 4),
-        encoder_dim: Union[int, Tuple[int]] = 384,
+        encoder_dim: Union[int, Tuple[int]] = (384, 512, 384),
         num_encoder_layers: Union[int, Tuple[int]] = 4,
-        encoder_unmasked_dim: Union[int, Tuple[int]] = 256,
         query_head_dim: Union[int, Tuple[int]] = 24,
        value_head_dim: Union[int, Tuple[int]] = 12,
         num_heads: Union[int, Tuple[int]] = 8,
@@ -106,19 +100,16 @@ class Subformer(EncoderInterface):
 
         def _to_tuple(x):
             """ Converts a single int or a 1-tuple of an int to a tuple with the same length
-            as downsampling_factor"""
+            as encoder_dim"""
             if isinstance(x, int):
                 x = (x,)
             if len(x) == 1:
-                x = x * len(downsampling_factor)
+                x = x * len(encoder_dim)
             else:
-                assert len(x) == len(downsampling_factor) and isinstance(x[0], int)
+                assert len(x) == len(encoder_dim) and isinstance(x[0], int)
             return x
 
-        self.output_downsampling_factor = output_downsampling_factor  # int
-        self.downsampling_factor = downsampling_factor  # tuple
-        self.encoder_dim = encoder_dim = _to_tuple(encoder_dim)  # tuple
-        self.encoder_unmasked_dim = encoder_unmasked_dim = _to_tuple(encoder_unmasked_dim)  # tuple
+        self.encoder_dim = encoder_dim
         num_encoder_layers = _to_tuple(num_encoder_layers)
         query_head_dim = _to_tuple(query_head_dim)
         value_head_dim = _to_tuple(value_head_dim)
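
Note: `_to_tuple` now broadcasts a scalar hyperparameter against `encoder_dim` rather than the removed `downsampling_factor`. A minimal standalone sketch of that behavior, with `encoder_dim` hard-coded purely for illustration:

    encoder_dim = (384, 512, 384)   # stands in for the constructor argument

    def _to_tuple(x):
        # a bare int or a 1-tuple is repeated once per encoder stack;
        # anything longer must already match len(encoder_dim)
        if isinstance(x, int):
            x = (x,)
        if len(x) == 1:
            x = x * len(encoder_dim)
        else:
            assert len(x) == len(encoder_dim) and isinstance(x[0], int)
        return x

    assert _to_tuple(24) == (24, 24, 24)
    assert _to_tuple((4, 6, 4)) == (4, 6, 4)
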
@@ -126,13 +117,17 @@ class Subformer(EncoderInterface):
         feedforward_dim = _to_tuple(feedforward_dim)
         self.causal = causal
 
-        for u,d in zip(encoder_unmasked_dim, encoder_dim):
-            assert u <= d
 
         # each one will be SubformerEncoder or DownsampledSubformerEncoder
         encoders = []
 
-        num_encoders = len(downsampling_factor)
+        num_encoders = len(encoder_dim)
+        assert num_encoders % 2 == 1
+        downsampling_factor = [ 1 ]
+        while len(downsampling_factor) < num_encoders:
+            downsampling_factor = [ 1 ] + [ d * 2 for d in downsampling_factor ] + [ 1 ]
 
         for i in range(num_encoders):
 
             encoder_layer = SubformerEncoderLayer(
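
Note: the new loop derives a symmetric pyramid of per-stack downsampling factors from the number of stacks; since the list grows by two entries per pass, it can only terminate exactly on an odd length, which is what `assert num_encoders % 2 == 1` guarantees. Traced by hand:

    # num_encoders = 3:  [1] -> [1, 2, 1]
    # num_encoders = 5:  [1] -> [1, 2, 1] -> [1, 2, 4, 2, 1]
    # num_encoders = 7:  [1] -> ... -> [1, 2, 4, 8, 4, 2, 1]
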
@@ -158,83 +153,36 @@ class Subformer(EncoderInterface):
                 final_layerdrop_rate=0.035 * (downsampling_factor[i] ** 0.5),
             )
 
-            if downsampling_factor[i] != 1:
-                encoder = DownsampledSubformerEncoder(
-                    encoder,
-                    dim=encoder_dim[i],
-                    downsample=downsampling_factor[i],
-                    dropout=dropout,
-                )
 
             encoders.append(encoder)
 
+        mid = len(encoders) // 2
+        encoder = DownsampledSubformerEncoder(
+            [ encoders[mid] ],
+            input_num_channels=encoder_dim[mid],
+            downsample=2
+        )
+
+        encoder = encoders[mid]
+        for i in range(mid-1, -1, -1):
+            this_list = [ encoders[mid-i],
+                          encoder,
+                          encoders[mid+i] ]
+            encoder = DownsampledSubformerEncoder(
+                this_list,
+                input_num_channels=encoder_dim[max(0, mid-2)],
+                downsample=2,
+            )
+
+        self.encoder = encoder
 
         self.encoder_pos = CompactRelPositionalEncoding(64, pos_dim,
                                                         dropout_rate=0.15,
                                                         length_factor=1.0)
 
-        self.encoders = nn.ModuleList(encoders)
-
         #self.downsample_output = SimpleDownsample(max(encoder_dim),
         #                                          downsample=output_downsampling_factor,
         #                                          dropout=dropout)
 
-    def get_feature_masks(
-            self,
-            x: torch.Tensor) -> List[Union[float, Tensor]]:
-        """
-        In eval mode, returns [1.0] * num_encoders; in training mode, returns a number of
-        randomized feature masks, one per encoder.
-        On e.g. 15% of frames, these masks will zero out all encoder dims larger than
-        some supplied number, e.g. >256, so in effect on those frames we are using
-        a smaller encoder dim.
-
-        We generate the random masks at this level because we want the 2 masks to 'agree'
-        all the way up the encoder stack.  This will mean that the 1st mask will have
-        mask values repeated self.zipformer_subsampling_factor times.
-
-        Args:
-           x: the embeddings (needed for the shape and dtype and device), of shape
-             (1, batch_size, encoder_dims0)
-        """
-        num_encoders = len(self.encoder_dim)
-        if not self.training:
-            return [ 1.0 ] * num_encoders
-
-        (num_frames0, batch_size, _encoder_dims0) = x.shape
-
-        assert self.encoder_dim[0] == _encoder_dims0
-
-        feature_mask_dropout_prob = 0.125
-
-        # mask1 shape: (1, batch_size, 1)
-        mask1 = (torch.rand(1, batch_size, 1,
-                            device=x.device) >
-                 feature_mask_dropout_prob).to(x.dtype)
-
-        # mask2 has additional sequences masked, about twice the number.
-        mask2 = torch.logical_and(mask1,
-                                  (torch.rand(1, batch_size, 1,
-                                              device=x.device) >
-                                   feature_mask_dropout_prob).to(x.dtype))
-
-        # dim: (1, batch_size, 2)
-        mask = torch.cat((mask1, mask2), dim=-1)
-
-        feature_masks = []
-        for i in range(num_encoders):
-            channels = self.encoder_dim[i]
-            feature_mask = torch.ones(1, batch_size, channels,
-                                      dtype=x.dtype, device=x.device)
-            u1 = self.encoder_unmasked_dim[i]
-            u2 = u1 + (channels - u1) // 2
-
-            feature_mask[:, :, u1:u2] *= mask[..., 0:1]
-            feature_mask[:, :, u2:] *= mask[..., 1:2]
-
-            feature_masks.append(feature_mask)
-
-        return feature_masks
-
-
     def forward(
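
Note: with the flat `nn.ModuleList` gone, the model is now one recursively nested module. As `DownsampledSubformerEncoder.forward` (further down in this diff) shows, each wrapper downsamples its input, runs its child encoders in sequence, upsamples back, and combines the result with the original input through a `BypassModule`; nesting those wrappers around the middle stack therefore yields a U-Net-like topology in which each nesting level halves the frame rate and the middle stack runs at the lowest rate. Schematically, for five stacks (an illustration of the intended shape only, not a literal trace of the index arithmetic):

    DownsampledSubformerEncoder(downsample=2)        # outer level: half frame rate
      +- earlier stack
      +- DownsampledSubformerEncoder(downsample=2)   # inner level: quarter frame rate
      |    +- middle stack (lowest frame rate)
      +- later stack
    # each level upsamples and bypass-combines on the way back out
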
@@ -267,7 +215,6 @@ class Subformer(EncoderInterface):
            of frames in `embeddings` before padding.
         """
         outputs = []
-        feature_masks = self.get_feature_masks(x)
 
         attn_offset = self._get_attn_offset(x, src_key_padding_mask)
 
@@ -281,44 +228,17 @@ class Subformer(EncoderInterface):
 
         pos_emb = self.encoder_pos(x)
 
-        for i, module in enumerate(self.encoders):
-            x = convert_num_channels(x, self.encoder_dim[i])
+        x = self.encoder(x,
+                         pos_emb,
+                         attn_offset=attn_offset,
+                         memory=memory,
+                         memory_key_padding_mask=memory_key_padding_mask,
+                         )
 
-            x = module(x,
-                       pos_emb,
-                       feature_mask=feature_masks[i],
-                       attn_offset=attn_offset,
-                       memory=memory,
-                       memory_key_padding_mask=memory_key_padding_mask,
-                       )
-            outputs.append(x)
+        # d = self.output_downsampling_factor
+        # lengths = (x_lens + d - 1) // d
 
-        def get_full_dim_output():
-            num_encoders = len(self.encoder_dim)
-            assert len(outputs) == num_encoders
-            output_dim = max(self.encoder_dim)
-            output_pieces = [ outputs[-1] ]
-            cur_dim = self.encoder_dim[-1]
-            for i in range(num_encoders - 2, -1, -1):
-                d = self.encoder_dim[i]
-                if d > cur_dim:
-                    this_output = outputs[i]
-                    output_pieces.append(this_output[..., cur_dim:d])
-                    cur_dim = d
-            assert cur_dim == output_dim
-            return torch.cat(output_pieces, dim=-1)
+        return x, x_lens
 
-        # if the last output has the largest dimension, x will be unchanged,
-        # it will be the same as outputs[-1].  Otherwise it will be concatenated
-        # from different pieces of 'outputs', taking each dimension from the
-        # most recent output that has it present.
-        x = get_full_dim_output()
-        #x = self.downsample_output(x)
-
-        d = self.output_downsampling_factor
-        lengths = (x_lens + d - 1) // d
-
-        return x, lengths
 
     def _get_attn_offset(self, x: Tensor, src_key_padding_mask: Optional[Tensor]) -> Optional[Tensor]:
         """
@@ -507,7 +427,6 @@ class SubformerEncoderLayer(nn.Module):
             min_abs=0.1, max_abs=4.0,
         )
 
-
     def get_sequence_dropout_mask(self, x: Tensor, dropout_rate: float) -> Optional[Tensor]:
         if dropout_rate == 0.0 or not self.training or torch.jit.is_scripting():
             return None
@@ -650,11 +569,14 @@ class SubformerEncoder(nn.Module):
         final_layerdrop_rate: float = 0.05,
     ) -> None:
         super().__init__()
 
         self.layers = nn.ModuleList(
             [copy.deepcopy(encoder_layer) for i in range(num_layers)]
         )
         self.num_layers = num_layers
 
+        self.bypass = BypassModule(self.embed_dim())
+
         assert 0 <= warmup_begin <= warmup_end
 
         delta = (1. / num_layers) * (warmup_end - warmup_begin)
@@ -666,11 +588,14 @@ class SubformerEncoder(nn.Module):
                                          default=0.0)
             cur_begin = cur_end
 
+    def embed_dim(self):
+        return self.layers[0].embed_dim
+
     def forward(
         self,
         src: Tensor,
         pos_emb: Tensor,
-        feature_mask: Union[Tensor, float] = 1.0,
+        feature_mask: Optional[Tensor] = None,
         attn_offset: Optional[Tensor] = None,
         memory: Optional[Tensor] = None,
         memory_key_padding_mask: Optional[Tensor] = None,
@@ -693,11 +618,13 @@ class SubformerEncoder(nn.Module):
 
         Returns: a Tensor with the same shape as src.
         """
+        src = convert_num_channels(src, self.embed_dim())
         output = src
 
         rnd_seed = src.numel() + random.randint(0, 1000)
 
-        output = output * feature_mask
+        if feature_mask is not None:
+            output = output * feature_mask
 
         for i, mod in enumerate(self.layers):
             output = mod(
@@ -708,9 +635,10 @@ class SubformerEncoder(nn.Module):
                 memory_key_padding_mask=memory_key_padding_mask,
             )
 
-        output = output * feature_mask
+        if feature_mask is not None:
+            output = output * feature_mask
 
-        return output
+        return self.bypass(src, output)
 
 
 class BypassModule(nn.Module):
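
Note: `SubformerEncoder` now returns `self.bypass(src, output)` instead of `output`. A minimal sketch of a per-channel learned bypass in this style (an assumption modeled on the zipformer-style `BypassModule`; the real module also implements the `straight_through_rate` option seen below):

    import torch
    import torch.nn as nn

    class SimpleBypass(nn.Module):
        """Learned per-channel interpolation between a module's input and output."""
        def __init__(self, embed_dim: int):
            super().__init__()
            # one interpolation weight per channel, initialized to 0.5
            self.bypass_scale = nn.Parameter(torch.full((embed_dim,), 0.5))

        def forward(self, src: torch.Tensor, output: torch.Tensor) -> torch.Tensor:
            # scale == 0 passes src through unchanged; scale == 1 keeps output
            return src + self.bypass_scale * (output - src)
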
@@ -1014,17 +942,17 @@ class DownsampledSubformerEncoder(nn.Module):
     with the original input, so that the output has the same shape as the input.
     """
     def __init__(self,
-                 encoder: nn.Module,
-                 dim: int,
-                 downsample: int,
-                 dropout: FloatLike):
+                 encoders: List[nn.Module],
+                 input_num_channels: int,
+                 downsample: int):
         super(DownsampledSubformerEncoder, self).__init__()
         self.downsample_factor = downsample
-        self.downsampler = LearnedDownsamplingModule(dim,
+        self.downsampler = LearnedDownsamplingModule(input_num_channels,
                                                      downsample)
-        self.encoder = encoder
+        self.encoders = nn.ModuleList(encoders)
 
-        self.out_combiner = BypassModule(dim, straight_through_rate=0.025)
+        self.out_combiner = BypassModule(max(e.embed_dim() for e in encoders),
+                                         straight_through_rate=0.0)
 
     def forward(self,
                 src: Tensor,
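
Note: the wrapper now owns a list of child encoders and sizes its `out_combiner` by the widest of them. `LearnedDownsamplingModule` is used below through a small contract: it chooses a weighted subset of frames (yielding `indexes` and `weights`), and `upsample` scatters the processed frames back over the original-rate sequence. A rough, hypothetical sketch of that contract (not the icefall implementation):

    import torch

    def downsample(src: torch.Tensor, indexes: torch.Tensor,
                   weights: torch.Tensor) -> torch.Tensor:
        # src: (seq_len, batch, dim); indexes, weights: (batch, ds_seq_len)
        ds = src.permute(1, 0, 2).gather(
            1, indexes.unsqueeze(-1).expand(-1, -1, src.shape[-1]))
        return (ds * weights.unsqueeze(-1)).permute(1, 0, 2)

    def upsample(src_orig: torch.Tensor, src: torch.Tensor,
                 indexes: torch.Tensor) -> torch.Tensor:
        # start from the original-rate sequence and overwrite the kept frames
        out = src_orig.permute(1, 0, 2).clone()
        out.scatter_(1, indexes.unsqueeze(-1).expand(-1, -1, src.shape[-1]),
                     src.permute(1, 0, 2))
        return out.permute(1, 0, 2)
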
@@ -1059,15 +987,35 @@ class DownsampledSubformerEncoder(nn.Module):
         attn_offset = self.downsampler.downsample_attn_offset(attn_offset,
                                                               indexes,
                                                               weights)
-        src = self.encoder(
-            src,
-            pos_emb,
-            feature_mask=feature_mask,
-            attn_offset=attn_offset,
-            memory=memory,
-            memory_key_padding_mask=memory_key_padding_mask,
-        )
+        outputs = [ src ]
+
+        for encoder in self.encoders:
+            src = encoder(
+                src,
+                pos_emb,
+                attn_offset=attn_offset,
+                memory=memory,
+                memory_key_padding_mask=memory_key_padding_mask,
+            )
+            outputs.append(src)
+
+        def get_full_dim_output():
+            num_encoders = len(outputs)
+            output_dim = max(o.shape[-1] for o in outputs)
+            output_pieces = [ outputs[-1] ]
+            cur_dim = outputs[-1].shape[-1]
+            for i in range(num_encoders - 2, -1, -1):
+                d = outputs[i].shape[-1]
+                if d > cur_dim:
+                    this_output = outputs[i]
+                    output_pieces.append(this_output[..., cur_dim:d])
+                    cur_dim = d
+            assert cur_dim == output_dim
+            return torch.cat(output_pieces, dim=-1)
+
+        src = get_full_dim_output()
+        src_orig = convert_num_channels(src_orig, src.shape[-1])
+
         src = self.downsampler.upsample(src_orig, src, indexes)
 
         return self.out_combiner(src_orig, src)
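
Note: `get_full_dim_output()` reproduces, per wrapper, the logic deleted from `Subformer.forward` above: the concatenated result takes each channel from the most recent output that has it. A small worked example of just that selection:

    import torch

    # suppose the child encoders produced outputs of 256, 384 and 256 dims
    outputs = [torch.zeros(10, 2, 256),     # first encoder
               torch.ones(10, 2, 384),      # middle (widest) encoder
               2 * torch.ones(10, 2, 256)]  # last encoder

    output_pieces = [outputs[-1]]           # all 256 channels of the last output
    cur_dim = outputs[-1].shape[-1]
    for o in reversed(outputs[:-1]):        # same traversal as the loop above
        d = o.shape[-1]
        if d > cur_dim:
            output_pieces.append(o[..., cur_dim:d])  # channels 256:384, middle output
            cur_dim = d
    full = torch.cat(output_pieces, dim=-1)  # shape (10, 2, 384)
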
@@ -1793,7 +1741,8 @@ def _test_zipformer_main(causal: bool = False):
     memory_dim = 100
 
     c = Subformer(
-        encoder_dim=(64, 96), encoder_unmasked_dim=(48, 64), num_heads=(4, 4),
+        encoder_dim=(64, 96, 64),
+        num_heads=(4, 4, 8),
         causal=causal,
         memory_dim=memory_dim,
     )
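
Note: the updated test reflects the new constructor contract: an odd number of stacks (the constructor now asserts this), the middle stack widest, and no `encoder_unmasked_dim`. A construction sketch along the same lines:

    c = Subformer(
        encoder_dim=(64, 96, 64),   # odd number of stacks, middle widest
        num_heads=(4, 4, 8),        # one head count per stack
        causal=False,
        memory_dim=100,
    )
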
@@ -125,15 +125,6 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Number of subformer encoder layers per stack, comma separated.",
     )
 
-
-    parser.add_argument(
-        "--downsampling-factor",
-        type=str,
-        default="1,2,4,8,4,2,1",
-        help="Downsampling factor for each stack of encoder layers.",
-    )
-
-
     parser.add_argument(
         "--feedforward-dim",
         type=str,
@@ -176,13 +167,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
         help="Positional-encoding dimension in encoder stacks: a single int or comma-separated list."
     )
 
-    parser.add_argument(
-        "--encoder-unmasked-dim",
-        type=str,
-        default="192,192,256,256,256,192,192",
-        help="Unmasked dimensions in the encoders, relates to augmentation during training.  "
-        "A single int or comma-separated list.  Must be <= each corresponding encoder_dim."
-    )
-
 
 def get_parser():
@@ -478,11 +463,8 @@ def get_encoder_embed(params: AttributeDict) -> nn.Module:
 def get_encoder_model(params: AttributeDict) -> nn.Module:
     #chunk_size = _to_int_tuple(params.downsampling_factor)[-1]
     encoder = Subformer(
-        #output_downsampling_factor=chunk_size,
-        downsampling_factor=_to_int_tuple(params.downsampling_factor),
         num_encoder_layers=_to_int_tuple(params.num_encoder_layers),
         encoder_dim=_to_int_tuple(params.encoder_dim),
-        encoder_unmasked_dim=_to_int_tuple(params.encoder_unmasked_dim),
         query_head_dim=_to_int_tuple(params.query_head_dim),
         pos_dim=int(params.pos_dim),
         value_head_dim=_to_int_tuple(params.value_head_dim),
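
Note: each comma-separated command-line value is converted with `_to_int_tuple` into the tuple the `Subformer` constructor expects; in icefall training scripts this helper is typically defined along these lines (shown as an assumption, for context):

    def _to_int_tuple(s: str) -> tuple:
        return tuple(map(int, s.split(",")))

    _to_int_tuple("2,4,3,2,4")   # -> (2, 4, 3, 2, 4)
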