Ditch caching_eval; reduce params more.
This commit is contained in:
parent 083e5474c4
commit 53ab18a862
@@ -41,7 +41,6 @@ from scaling import (
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
     penalize_abs_values_gt,
     softmax,
-    caching_eval,
     ScheduledFloat,
     FloatLike,
     limit_param_value,
@@ -318,12 +317,7 @@ class Zipformer(EncoderInterface):
             - lengths, a tensor of shape (batch_size,) containing the number
               of frames in `embeddings` before padding.
         """
-        if not torch.jit.is_scripting():
-            # This saves memory during training, at the expense of re-doing the encoder_embed
-            # computation in the backward pass.
-            x = caching_eval(x, self.encoder_embed)
-        else:
-            x = self.encoder_embed(x)
+        x = self.encoder_embed(x)

         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

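caching_eval is defined in scaling.py and is not shown in this diff; going by the removed comment, it traded activation memory for recomputing encoder_embed in the backward pass, which is the same trade-off offered by PyTorch's stock torch.utils.checkpoint. The sketch below illustrates that trade-off on a toy stand-in module; ToyEmbed, its dimensions, and the checkpoint call are illustrative assumptions, not the removed helper or the actual Zipformer encoder_embed.

```python
# Illustrative sketch of the memory-vs-recompute trade-off that the removed
# branch provided, using PyTorch's built-in checkpoint utility.
# ToyEmbed is a stand-in module, not the real encoder_embed.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyEmbed(nn.Module):
    def __init__(self, in_dim: int = 80, out_dim: int = 384):
        super().__init__()
        self.proj = nn.Sequential(nn.Linear(in_dim, out_dim), nn.ReLU())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


embed = ToyEmbed()
x = torch.randn(8, 100, 80, requires_grad=True)  # (N, T, C)

if not torch.jit.is_scripting():
    # Activations inside `embed` are not stored; they are recomputed
    # during backward, saving memory at the cost of extra compute.
    y = checkpoint(embed, x, use_reentrant=False)
else:
    y = embed(x)

y.sum().backward()
```

After this commit the recompute path is gone and encoder_embed is always run as a plain forward call.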
@@ -1693,7 +1687,7 @@ class ConvNeXt(nn.Module):
     """
     def __init__(self,
                  channels: int,
-                 hidden_ratio: int = 4,
+                 hidden_ratio: int = 3,
                  layerdrop_prob: FloatLike = None):
         super().__init__()
         kernel_size = 7
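Lowering the default hidden_ratio from 4 to 3 mainly shrinks the two pointwise convolutions around the ConvNeXt block's hidden layer, which is where the "reduce params more" part of the commit message comes from. The rough count below assumes the common ConvNeXt layout (depthwise k x k conv, 1x1 expand to hidden_ratio * channels, 1x1 reduce back to channels); the exact layer layout and the channels value used are assumptions for illustration, not read from this diff.

```python
# Back-of-the-envelope parameter count for one ConvNeXt-style block,
# assuming: depthwise conv + pointwise expand + pointwise reduce (with biases).
def convnext_block_params(channels: int, hidden_ratio: int, kernel_size: int = 7) -> int:
    hidden = channels * hidden_ratio
    depthwise = channels * kernel_size * kernel_size + channels  # groups=channels
    pointwise_up = channels * hidden + hidden                    # 1x1 conv: C -> hidden
    pointwise_down = hidden * channels + channels                # 1x1 conv: hidden -> C
    return depthwise + pointwise_up + pointwise_down


for ratio in (4, 3):
    print(ratio, convnext_block_params(channels=384, hidden_ratio=ratio))
# For an illustrative channels=384, going from ratio 4 to 3 removes roughly
# 2 * 384 * 384 + 384 ≈ 0.3M parameters per block (~25% of the pointwise weights).
```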