Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-08 08:34:19 +00:00)

Commit b36f3b5c52 ("fixed formatting issues"), parent 2cb0092b09
@@ -240,4 +240,3 @@ def main():

 if __name__ == "__main__":
     main()
@@ -61,10 +61,15 @@ class Decoder(nn.Module):
         )
         # the balancers are to avoid any drift in the magnitude of the
         # embeddings, which would interact badly with parameter averaging.
-        self.balancer = Balancer(decoder_dim, channel_dim=-1,
-                                 min_positive=0.0, max_positive=1.0,
-                                 min_abs=0.5, max_abs=1.0,
-                                 prob=0.05)
+        self.balancer = Balancer(
+            decoder_dim,
+            channel_dim=-1,
+            min_positive=0.0,
+            max_positive=1.0,
+            min_abs=0.5,
+            max_abs=1.0,
+            prob=0.05,
+        )

         self.blank_id = blank_id
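The comment in this hunk explains why the Balancer is attached: it keeps the RMS of the embedding activations inside a fixed band so that averaging model checkpoints does not amplify or shrink the embeddings. The real Balancer lives in icefall's scaling.py and acts on gradients with a small probability; the toy penalty below only illustrates the constraint it enforces (the helper name and example tensors are made up).

```python
import torch


def rms_penalty(x: torch.Tensor, min_abs: float = 0.5, max_abs: float = 1.0) -> torch.Tensor:
    # One RMS value per channel (last dim), reduced over batch/time.
    rms = x.pow(2).mean(dim=tuple(range(x.ndim - 1))).sqrt()
    # Penalise channels whose RMS falls outside [min_abs, max_abs].
    return ((min_abs - rms).clamp(min=0.0) + (rms - max_abs).clamp(min=0.0)).mean()


emb_drifted = torch.randn(8, 16, 512) * 3.0   # embedding scale has drifted upward
emb_ok = torch.randn(8, 16, 512) * 0.7        # RMS ~0.7, inside the [0.5, 1.0] band
print(rms_penalty(emb_drifted) > rms_penalty(emb_ok))   # tensor(True)
```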
@@ -81,10 +86,15 @@ class Decoder(nn.Module):
                 groups=decoder_dim // 4,  # group size == 4
                 bias=False,
             )
-            self.balancer2 = Balancer(decoder_dim, channel_dim=-1,
-                                      min_positive=0.0, max_positive=1.0,
-                                      min_abs=0.5, max_abs=1.0,
-                                      prob=0.05)
+            self.balancer2 = Balancer(
+                decoder_dim,
+                channel_dim=-1,
+                min_positive=0.0,
+                max_positive=1.0,
+                min_abs=0.5,
+                max_abs=1.0,
+                prob=0.05,
+            )

     def forward(self, y: torch.Tensor, need_pad: bool = True) -> torch.Tensor:
         """
@@ -107,9 +117,7 @@ class Decoder(nn.Module):
         if self.context_size > 1:
             embedding_out = embedding_out.permute(0, 2, 1)
             if need_pad is True:
-                embedding_out = F.pad(
-                    embedding_out, pad=(self.context_size - 1, 0)
-                )
+                embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0))
             else:
                 # During inference time, there is no need to do extra padding
                 # as we only need one output
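The single-line F.pad call keeps the same semantics as before: during training the (N, decoder_dim, T) tensor is left-padded along time by context_size - 1 frames so the causal Conv1d still produces one output per label position. A quick shape check with toy sizes (N=2, decoder_dim=4, T=5, context_size=2 are illustrative values, not icefall defaults):

```python
import torch
import torch.nn.functional as F

context_size = 2
embedding_out = torch.randn(2, 4, 5)                       # (N, decoder_dim, T)
padded = F.pad(embedding_out, pad=(context_size - 1, 0))   # left-pad the last (time) dim
print(padded.shape)                                        # torch.Size([2, 4, 6])

# A Conv1d with kernel_size=context_size then yields T outputs again:
conv = torch.nn.Conv1d(4, 4, kernel_size=context_size)
print(conv(padded).shape)                                  # torch.Size([2, 4, 5])
```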
@@ -52,12 +52,13 @@ class Joiner(nn.Module):
         Returns:
           Return a tensor of shape (N, T, s_range, C).
         """
-        assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)
+        assert encoder_out.ndim == decoder_out.ndim, (
+            encoder_out.shape,
+            decoder_out.shape,
+        )

         if project_input:
-            logit = self.encoder_proj(encoder_out) + self.decoder_proj(
-                decoder_out
-            )
+            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
             logit = encoder_out + decoder_out
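The reformatted assert guards the addition that follows: both inputs must have the same number of dimensions so the projected tensors add element-wise. A minimal sketch with toy dimensions (the sizes below are assumptions, not icefall's defaults):

```python
import torch
import torch.nn as nn

# Toy dims: encoder_dim=6, decoder_dim=4, joiner_dim=3, with N=2, T=5, s_range=5.
encoder_proj = nn.Linear(6, 3)
decoder_proj = nn.Linear(4, 3)

encoder_out = torch.randn(2, 5, 5, 6)   # (N, T, s_range, encoder_dim)
decoder_out = torch.randn(2, 5, 5, 4)   # (N, T, s_range, decoder_dim)
assert encoder_out.ndim == decoder_out.ndim, (encoder_out.shape, decoder_out.shape)

logit = encoder_proj(encoder_out) + decoder_proj(decoder_out)
print(logit.shape)                      # torch.Size([2, 5, 5, 3]) == (N, T, s_range, C)
```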
@@ -299,8 +299,8 @@ class ScaledAdam(BatchedOptimizer):
             # the input is groups of parameter or named parameter.
             for cur_group in iterable_or_groups:
                 assert "named_params" in cur_group
-                name_list = [ x[0] for x in cur_group["named_params"] ]
-                p_list = [ x[1] for x in cur_group["named_params"] ]
+                name_list = [x[0] for x in cur_group["named_params"]]
+                p_list = [x[1] for x in cur_group["named_params"]]
                 del cur_group["named_params"]
                 cur_group["params"] = p_list
                 param_groups.append(cur_group)
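The list comprehensions being cleaned up implement a small dict rewrite: each incoming group carries (name, parameter) pairs under "named_params", and ScaledAdam splits them into a name list (useful for diagnostics) and a plain "params" list for the base optimizer. A standalone illustration (the names and lr value are invented):

```python
import torch

p1 = torch.nn.Parameter(torch.zeros(3))
p2 = torch.nn.Parameter(torch.zeros(2, 2))
cur_group = {"named_params": [("encoder.bias", p1), ("decoder.weight", p2)], "lr": 0.045}

name_list = [x[0] for x in cur_group["named_params"]]
p_list = [x[1] for x in cur_group["named_params"]]
del cur_group["named_params"]
cur_group["params"] = p_list

print(name_list)            # ['encoder.bias', 'decoder.weight']
print(sorted(cur_group))    # ['lr', 'params']
```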
@@ -667,8 +667,7 @@ class ScaledAdam(BatchedOptimizer):
             # We have to look at the trained model for parameters at or around the
             # param_max_rms, because sometimes they can indicate a problem with the
             # topology or settings.
-            scale_step = torch.minimum(scale_step,
-                                       (param_max_rms - param_rms) / param_rms)
+            scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)

         delta = state["delta"]
         # the factor of (1-beta1) relates to momentum.
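The reflowed torch.minimum line caps the per-step growth of a parameter's learned scale so its RMS cannot overshoot param_max_rms. A toy numeric check (values invented for illustration):

```python
import torch

# If the parameter's RMS is already close to param_max_rms, the relative growth
# allowed in one step is capped at (param_max_rms - param_rms) / param_rms.
param_rms = torch.tensor([0.5, 1.9, 2.0])
param_max_rms = torch.tensor(2.0)
scale_step = torch.tensor([0.10, 0.10, 0.10])   # proposed relative change in scale

scale_step = torch.minimum(scale_step, (param_max_rms - param_rms) / param_rms)
print(scale_step)   # tensor([0.1000, 0.0526, 0.0000]) -> growth stops at the RMS ceiling
```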
@@ -879,7 +878,8 @@ class Eden(LRScheduler):
         warmup_factor = (
             1.0
             if self.batch >= self.warmup_batches
-            else self.warmup_start + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
+            else self.warmup_start
+            + (1.0 - self.warmup_start) * (self.batch / self.warmup_batches)
             # else 0.5 + 0.5 * (self.batch / self.warmup_batches)
         )
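The else-branch split across two lines is Eden's linear warmup factor: it ramps from warmup_start to 1.0 over warmup_batches and stays at 1.0 afterwards. A quick check with assumed settings (warmup_start=0.5, warmup_batches=500 are placeholders, not necessarily the defaults):

```python
warmup_start, warmup_batches = 0.5, 500


def warmup_factor(batch: int) -> float:
    return (
        1.0
        if batch >= warmup_batches
        else warmup_start + (1.0 - warmup_start) * (batch / warmup_batches)
    )


print([warmup_factor(b) for b in (0, 250, 500, 1000)])   # [0.5, 0.75, 1.0, 1.0]
```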
@@ -100,17 +100,13 @@ class Model(nn.Module):
         self.encoder_embed = encoder_embed
         self.encoder_proj = encoder_proj

-    def forward(
-        self, feature: Tensor, feature_lens: Tensor
-    ) -> Tuple[Tensor, Tensor]:
+    def forward(self, feature: Tensor, feature_lens: Tensor) -> Tuple[Tensor, Tensor]:
         x, x_lens = self.encoder_embed(feature, feature_lens)

         src_key_padding_mask = make_pad_mask(x_lens)
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

-        encoder_out, encoder_out_lens = self.encoder(
-            x, x_lens, src_key_padding_mask
-        )
+        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)

         encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
         logits = self.encoder_proj(encoder_out)
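The collapsed calls do not change the shape bookkeeping of this forward pass: make_pad_mask (an icefall utility) marks frames beyond each utterance's length, and the permutes move between batch-first and time-first layouts for the encoder. The toy_make_pad_mask below is a simplified stand-in for illustration, not the library function:

```python
import torch


def toy_make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    max_len = int(lengths.max())
    # (N, T) bool mask, True where the frame is padding.
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)


x_lens = torch.tensor([5, 3])
print(toy_make_pad_mask(x_lens))
# tensor([[False, False, False, False, False],
#         [False, False, False,  True,  True]])

x = torch.randn(2, 5, 80)   # (N, T, C) from encoder_embed
x = x.permute(1, 0, 2)      # (T, N, C): the encoder expects time first
print(x.shape)              # torch.Size([5, 2, 80])
```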
@@ -168,9 +164,7 @@ def main():


 if __name__ == "__main__":
-    formatter = (
-        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-    )
+    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
     logging.basicConfig(format=formatter, level=logging.INFO)

     main()
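The one-line formatter string configures log records as timestamp, level, file:line, and message; the sample output in the comment below is invented for illustration.

```python
import logging

formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
# Emits something like:
#   2024-01-01 12:00:00,000 INFO [decode.py:164] decoding started
logging.info("decoding started")
```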
File diff suppressed because it is too large.
@@ -374,11 +374,7 @@ def streaming_forward(
       Returns encoder outputs, output lengths, and updated states.
     """
     cached_embed_left_pad = states[-2]
-    (
-        x,
-        x_lens,
-        new_cached_embed_left_pad,
-    ) = model.encoder_embed.streaming_forward(
+    (x, x_lens, new_cached_embed_left_pad,) = model.encoder_embed.streaming_forward(
         x=features,
         x_lens=feature_lens,
         cached_left_pad=cached_embed_left_pad,
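An illustrative sketch (not icefall's actual streaming_forward) of why a cached left pad is threaded through the call above: each chunk is prepended with the last few frames of the previous chunk so the subsampling convolution sees the same left context it would see offline. All names and shapes below are made up for the example:

```python
import torch

left_pad = 2
cached = torch.zeros(1, 80, left_pad)               # (N, C, left_pad); zeros before the first chunk


def process_chunk(chunk: torch.Tensor, cached: torch.Tensor):
    padded = torch.cat([cached, chunk], dim=-1)     # prepend cached frames along time
    new_cached = padded[..., -left_pad:]            # keep the tail for the next call
    return padded, new_cached


chunk1 = torch.randn(1, 80, 8)
out1, cached = process_chunk(chunk1, cached)
print(out1.shape, cached.shape)                     # torch.Size([1, 80, 10]) torch.Size([1, 80, 2])
```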
@@ -107,9 +107,7 @@ class ConvNeXt(nn.Module):
         if layerdrop_rate != 0.0:
             batch_size = x.shape[0]
             mask = (
-                torch.rand(
-                    (batch_size, 1, 1, 1), dtype=x.dtype, device=x.device
-                )
+                torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
                 > layerdrop_rate
             )
         else:
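The mask being reflowed implements per-sample layer drop: one uniform draw per batch element, kept broadcastable over the remaining dimensions, so an entire sample either keeps or skips the block. A self-contained sketch with toy shapes:

```python
import torch

layerdrop_rate = 0.3
x = torch.randn(4, 8, 10, 10)            # (batch, channels, height, width), toy sizes
batch_size = x.shape[0]

mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
print(mask.shape)                        # torch.Size([4, 1, 1, 1])
# Multiplying by the broadcast mask zeroes out whole samples with probability layerdrop_rate.
print((x * mask).abs().sum(dim=(1, 2, 3)) == 0)   # True exactly where mask is False
```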
@@ -275,9 +273,7 @@ class Conv2dSubsampling(nn.Module):
         # many copies of this extra gradient term.
         self.out_whiten = Whiten(
             num_groups=1,
-            whitening_limit=ScheduledFloat(
-                (0.0, 4.0), (20000.0, 8.0), default=4.0
-            ),
+            whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
             prob=(0.025, 0.25),
             grad_scale=0.02,
         )
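whitening_limit here is not a constant: ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0) schedules the value on the training batch count, starting at 4.0 and reaching 8.0 by batch 20000. The toy function below sketches that behaviour under the assumption of piecewise-linear interpolation between the given points; the real class lives in icefall's scaling.py and is driven by the module's batch counter.

```python
def scheduled_value(batch: float, points=((0.0, 4.0), (20000.0, 8.0))) -> float:
    (x0, y0), (x1, y1) = points
    if batch <= x0:
        return y0
    if batch >= x1:
        return y1
    frac = (batch - x0) / (x1 - x0)
    return y0 + frac * (y1 - y0)


print([scheduled_value(b) for b in (0, 10000, 20000, 50000)])   # [4.0, 6.0, 8.0, 8.0]
```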
@@ -400,8 +396,8 @@ class Conv2dSubsampling(nn.Module):
         left_pad = self.convnext.padding[0]
         freq = self.out_width
         channels = self.layer3_channels
-        cached_embed_left_pad = torch.zeros(
-            batch_size, channels, left_pad, freq
-        ).to(device)
+        cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
+            device
+        )

         return cached_embed_left_pad
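The reflowed torch.zeros call simply allocates the streaming left-context cache with shape (batch, channels, left_pad, freq). Toy numbers for a quick shape check (these particular values are placeholders, not the module's real attributes):

```python
import torch

batch_size, channels, left_pad, freq = 2, 128, 2, 21
device = torch.device("cpu")

cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(device)
print(cached_embed_left_pad.shape)   # torch.Size([2, 128, 2, 21])
```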
File diff suppressed because it is too large.