mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
Increase num parameters
This commit is contained in:
parent
047c6ffc58
commit
a397a5973b
@ -339,7 +339,7 @@ class Subformer(EncoderInterface):
|
||||
src_t = t
|
||||
tgt_t = t.unsqueeze(-1)
|
||||
attn_mask = (src_t > tgt_t)
|
||||
ans.masked_fill(attn_mask, float('-inf'))
|
||||
ans.masked_fill_(attn_mask, float('-inf'))
|
||||
|
||||
if src_key_padding_mask is not None:
|
||||
ans = ans * src_key_padding_mask.unsqueeze(1).logical_not()
|
||||
@ -795,10 +795,13 @@ class LearnedDownsamplingModule(nn.Module):
|
||||
# score_balancer is just to keep the magnitudes of the scores in
|
||||
# a fixed range and keep them balanced around zero, to stop
|
||||
# these drifting around.
|
||||
# largish range used to keep grads relatively small and avoid overflow in grads.
|
||||
self.score_balancer = Balancer(1, channel_dim=-1,
|
||||
min_positive=0.4, max_positive=0.6,
|
||||
min_abs=1.0, max_abs=1.2,
|
||||
prob=0.025)
|
||||
min_abs=10.0, max_abs=12.0)
|
||||
|
||||
self.copy_weights1 = nn.Identity()
|
||||
self.copy_weights2 = nn.Identity()
|
||||
|
||||
self.downsampling_factor = downsampling_factor
|
||||
self.intermediate_rate = copy.deepcopy(intermediate_rate)
|
||||
@ -855,12 +858,22 @@ class LearnedDownsamplingModule(nn.Module):
|
||||
left_avg = sscores[:, left-collar:left+collar+1].mean(dim=-1, keepdim=True)
|
||||
|
||||
# the + 0.001 is to avoid possible division by zero in case of ties.
|
||||
weights = (sscores - right_avg) / (left_avg - right_avg + 0.001)
|
||||
sscores = self.copy_weights1(sscores)
|
||||
|
||||
den = (left_avg - right_avg)
|
||||
# the following is to avoid division by near-zero.
|
||||
den = 0.8 * den + 0.2 * den.mean()
|
||||
|
||||
|
||||
#logging.info(f"den = {den}")
|
||||
weights = (sscores - right_avg) / den
|
||||
weights = weights.clamp(min=0.0, max=1.0)
|
||||
|
||||
indexes = indexes[:, :seq_len_reduced]
|
||||
weights = weights[:, :seq_len_reduced]
|
||||
|
||||
weights = self.copy_weights2(weights)
|
||||
|
||||
# re-sort the indexes we kept, on index value, so that
|
||||
# masking for causal models will be in the correct order.
|
||||
|
||||
|
@ -121,7 +121,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--num-encoder-layers",
|
||||
type=str,
|
||||
default="2,4,8,4,2",
|
||||
default="2,4,4,8,4,4,2",
|
||||
help="Number of subformer encoder layers per stack, comma separated.",
|
||||
)
|
||||
|
||||
@ -129,7 +129,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--downsampling-factor",
|
||||
type=str,
|
||||
default="1,2,4,2,1",
|
||||
default="1,2,4,8,4,2,1",
|
||||
help="Downsampling factor for each stack of encoder layers.",
|
||||
)
|
||||
|
||||
@ -137,21 +137,21 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--feedforward-dim",
|
||||
type=str,
|
||||
default="512,768,1024,768,512",
|
||||
default="512,768,1024,1536,1024,768,512",
|
||||
help="Feedforward dimension of the subformer encoder layers, per stack, comma separated.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-heads",
|
||||
type=str,
|
||||
default="4,4,8,4,4",
|
||||
default="4,4,8,16,8,4,4",
|
||||
help="Number of attention heads in the subformer encoder layers: a single int or comma-separated list.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder-dim",
|
||||
type=str,
|
||||
default="256,256,384,256,256",
|
||||
default="256,256,384,512,384,256,256",
|
||||
help="Embedding dimension in encoder stacks: a single int or comma-separated list."
|
||||
)
|
||||
|
||||
@ -179,7 +179,7 @@ def add_model_arguments(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
"--encoder-unmasked-dim",
|
||||
type=str,
|
||||
default="192,192,256,192,192",
|
||||
default="192,192,256,256,256,192,192",
|
||||
help="Unmasked dimensions in the encoders, relates to augmentation during training. "
|
||||
"A single int or comma-separated list. Must be <= each corresponding encoder_dim."
|
||||
)
|
||||
@ -452,7 +452,7 @@ def get_params() -> AttributeDict:
|
||||
"warm_step": 2000,
|
||||
"env_info": get_env_info(),
|
||||
"bytes_per_segment": 2048,
|
||||
"batch_size": 16,
|
||||
"batch_size": 20,
|
||||
"train_file_list": "train.txt",
|
||||
"valid_file_list": "valid.txt",
|
||||
"num_workers": 4,
|
||||
|
Loading…
x
Reference in New Issue
Block a user