Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-12-11 06:55:27 +00:00)
Commit a397a5973b: "Increase num parameters"
Parent: 047c6ffc58
@@ -339,7 +339,7 @@ class Subformer(EncoderInterface):
             src_t = t
             tgt_t = t.unsqueeze(-1)
             attn_mask = (src_t > tgt_t)
-            ans.masked_fill(attn_mask, float('-inf'))
+            ans.masked_fill_(attn_mask, float('-inf'))

         if src_key_padding_mask is not None:
             ans = ans * src_key_padding_mask.unsqueeze(1).logical_not()
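The only change in this hunk is masked_fill -> masked_fill_. Torch's masked_fill() is out-of-place: it returns a new tensor and leaves its input untouched, so the original line computed the masked scores and then discarded them, and the causal mask was never applied to `ans`. The in-place masked_fill_() fixes that. A minimal, self-contained sketch of the pattern (the shapes are illustrative assumptions, not taken from the Subformer source):

import torch

# `t` stands in for per-frame positions; `ans` for attention scores
# of shape (batch, tgt, src).
t = torch.arange(4)
ans = torch.zeros(1, 4, 4)

src_t = t                    # shape (src,)
tgt_t = t.unsqueeze(-1)      # shape (tgt, 1)
attn_mask = (src_t > tgt_t)  # (tgt, src): True where the source frame is in the future

buggy = ans.masked_fill(attn_mask, float('-inf'))  # returns a new tensor; `ans` unchanged
ans.masked_fill_(attn_mask, float('-inf'))         # mutates `ans` in place
assert torch.equal(ans, buggy)                     # same values, now actually stored in `ans`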
@@ -795,10 +795,13 @@ class LearnedDownsamplingModule(nn.Module):
         # score_balancer is just to keep the magnitudes of the scores in
         # a fixed range and keep them balanced around zero, to stop
         # these drifting around.
+        # largish range used to keep grads relatively small and avoid overflow in grads.
         self.score_balancer = Balancer(1, channel_dim=-1,
                                        min_positive=0.4, max_positive=0.6,
-                                       min_abs=1.0, max_abs=1.2,
-                                       prob=0.025)
+                                       min_abs=10.0, max_abs=12.0)
+
+        self.copy_weights1 = nn.Identity()
+        self.copy_weights2 = nn.Identity()

         self.downsampling_factor = downsampling_factor
         self.intermediate_rate = copy.deepcopy(intermediate_rate)
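Two things change in this hunk. First, the Balancer's target magnitude for the downsampling scores is raised from [1.0, 1.2] to [10.0, 12.0] (and the prob argument is dropped); per the new comment, the larger range keeps gradients relatively small and avoids overflow. Second, two nn.Identity modules, copy_weights1 and copy_weights2, are added. An Identity is a no-op in the forward pass, but it gives a named point in the module graph to attach forward hooks to; a hedged sketch of that pattern (the hook body is an assumption, not icefall's actual tooling):

import torch
import torch.nn as nn

copy_weights1 = nn.Identity()

def log_stats(module, inputs, output):
    # Runs on every forward pass through the Identity; purely diagnostic.
    print(f"mean={output.mean():.4f}, max_abs={output.abs().max():.4f}")

copy_weights1.register_forward_hook(log_stats)

sscores = 11.0 * torch.randn(2, 16)  # magnitudes roughly near the new [10, 12] range
sscores = copy_weights1(sscores)     # numerically a no-op; the hook observes the tensor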
@@ -855,12 +858,22 @@ class LearnedDownsamplingModule(nn.Module):
         left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)

         # the + 0.001 is to avoid possible division by zero in case of ties.
-        weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)
+        sscores = self.copy_weights1(sscores)
+
+        den = (left_avg - right_avg)
+        # the following is to avoid division by near-zero.
+        den = 0.8 * den + 0.2 * den.mean()
+
+
+        #logging.info(f"den = {den}")
+        weights = (sscores - right_avg) / den
         weights = weights.clamp(min=0.0, max=1.0)

         indexes = indexes[:, :seq_len_reduced]
         weights = weights[:, :seq_len_reduced]
+
+        weights = self.copy_weights2(weights)

         # re-sort the indexes we kept, on index value, so that
         # masking for causal models will be in the correct order.
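This hunk replaces the fixed +0.001 fudge term with a denominator blended toward its batch mean: a constant 1e-3 offset is little protection when left_avg - right_avg itself collapses toward zero (near-ties), whereas mixing in 20% of the mean gap pulls any single degenerate row toward the typically healthy average. A worked sketch with illustrative numbers (not from a real run):

import torch

left_avg  = torch.tensor([[2.0], [1.0001]])  # second row is a near-tie
right_avg = torch.tensor([[1.0], [1.0]])

den = left_avg - right_avg                   # [[1.0], [1e-4]]

old_den = den + 0.001                        # [[1.0010], [0.0011]] -- still tiny
new_den = 0.8 * den + 0.2 * den.mean()       # [[0.9000], [0.1001]] -- bounded away from 0

The new sscores = self.copy_weights1(sscores) and weights = self.copy_weights2(weights) lines route the tensors through the Identity modules added above, and the commented-out logging line suggests they were used while debugging this denominator. The remaining hunks are in the recipe's training script.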
@@ -121,7 +121,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num-encoder-layers",
         type=str,
-        default="2,4,8,4,2",
+        default="2,4,4,8,4,4,2",
         help="Number of subformer encoder layers per stack, comma separated.",
     )
@@ -129,7 +129,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--downsampling-factor",
         type=str,
-        default="1,2,4,2,1",
+        default="1,2,4,8,4,2,1",
         help="Downsampling factor for each stack of encoder layers.",
     )
@@ -137,21 +137,21 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--feedforward-dim",
         type=str,
-        default="512,768,1024,768,512",
+        default="512,768,1024,1536,1024,768,512",
         help="Feedforward dimension of the subformer encoder layers, per stack, comma separated.",
     )

     parser.add_argument(
         "--num-heads",
         type=str,
-        default="4,4,8,4,4",
+        default="4,4,8,16,8,4,4",
         help="Number of attention heads in the subformer encoder layers: a single int or comma-separated list.",
     )

     parser.add_argument(
         "--encoder-dim",
         type=str,
-        default="256,256,384,256,256",
+        default="256,256,384,512,384,256,256",
         help="Embedding dimension in encoder stacks: a single int or comma-separated list."
     )
@@ -179,7 +179,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--encoder-unmasked-dim",
         type=str,
-        default="192,192,256,192,192",
+        default="192,192,256,256,256,192,192",
         help="Unmasked dimensions in the encoders, relates to augmentation during training. "
         "A single int or comma-separated list. Must be <= each corresponding encoder_dim."
     )
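All five per-stack defaults grow from 5 entries to 7 by inserting a deeper, wider, more heavily downsampled stack in the middle (8 layers at 8x downsampling, encoder dim 512, feedforward dim 1536, 16 heads); this is what the commit title's "Increase num parameters" refers to. A sketch of how such comma-separated settings are typically consumed, one int per encoder stack (the helper name is hypothetical; icefall has its own parsing utilities):

def to_int_tuple(s: str) -> tuple:
    # Parse a comma-separated per-stack setting, e.g. "2,4,4,8,4,4,2".
    return tuple(int(x) for x in s.split(","))

num_encoder_layers  = to_int_tuple("2,4,4,8,4,4,2")        # was "2,4,8,4,2"
downsampling_factor = to_int_tuple("1,2,4,8,4,2,1")        # was "1,2,4,2,1"
encoder_dim         = to_int_tuple("256,256,384,512,384,256,256")

# One entry per stack; the lists stay symmetric around the new middle stack.
assert len(num_encoder_layers) == len(downsampling_factor) == len(encoder_dim) == 7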
@@ -452,7 +452,7 @@ def get_params() -> AttributeDict:
             "warm_step": 2000,
             "env_info": get_env_info(),
             "bytes_per_segment": 2048,
-            "batch_size": 16,
+            "batch_size": 20,
             "train_file_list": "train.txt",
             "valid_file_list": "valid.txt",
             "num_workers": 4,
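The last hunk bumps the default training batch size from 16 to 20. For context, get_params() in icefall recipes returns an AttributeDict, a dict that also allows attribute-style access; a minimal stand-in (not icefall's actual implementation) to show how these entries are read:

class AttributeDict(dict):
    # Minimal stand-in for icefall.utils.AttributeDict.
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

params = AttributeDict({
    "batch_size": 20,          # raised from 16 in this commit
    "bytes_per_segment": 2048,
    "num_workers": 4,
})

print(params.batch_size)       # 20 -- attribute-style access
print(params["num_workers"])   # 4  -- plain dict access still works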