Ditch caching_eval; reduce params more.
parent 083e5474c4, commit 53ab18a862
@@ -41,7 +41,6 @@ from scaling import (
     Identity,  # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
     penalize_abs_values_gt,
     softmax,
-    caching_eval,
     ScheduledFloat,
     FloatLike,
     limit_param_value,
@@ -318,12 +317,7 @@ class Zipformer(EncoderInterface):
           - lengths, a tensor of shape (batch_size,) containing the number
             of frames in `embeddings` before padding.
         """
-        if not torch.jit.is_scripting():
-            # This saves memory during training, at the expense of re-doing the encoder_embed
-            # computation in the backward pass.
-            x = caching_eval(x, self.encoder_embed)
-        else:
-            x = self.encoder_embed(x)
+        x = self.encoder_embed(x)
 
         x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)
 
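Aside: the removed branch wrapped encoder_embed in caching_eval, which (per the deleted comment) saved memory during training by re-doing the embedding computation in the backward pass. Below is a minimal sketch of that same trade-off using PyTorch's stock torch.utils.checkpoint; the helper name embed_with_recompute is hypothetical, and this is not the scaling.py implementation that was removed.

import torch
from torch.utils.checkpoint import checkpoint

def embed_with_recompute(encoder_embed: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    # In training, avoid storing encoder_embed's intermediate activations;
    # they are recomputed during the backward pass, trading compute for memory.
    if encoder_embed.training:
        return checkpoint(encoder_embed, x, use_reentrant=False)
    # In eval (or when scripting/exporting), just run the module normally.
    return encoder_embed(x)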
@@ -1693,7 +1687,7 @@ class ConvNeXt(nn.Module):
     """
     def __init__(self,
                  channels: int,
-                 hidden_ratio: int = 4,
+                 hidden_ratio: int = 3,
                  layerdrop_prob: FloatLike = None):
         super().__init__()
         kernel_size = 7
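Note on the hidden_ratio change: this is the "reduce params more" part of the commit title, since the block's hidden channel count is channels * hidden_ratio. The sketch below gives a rough, purely illustrative parameter estimate, assuming the ConvNeXt block is a depthwise kernel_size x kernel_size conv followed by two pointwise (1x1) convs that expand to hidden_ratio * channels and project back; normalization and other terms are ignored, and channels=384 is a hypothetical value, not taken from this diff.

def convnext_param_estimate(channels: int, hidden_ratio: int, kernel_size: int = 7) -> int:
    # Rough parameter count under the assumed block structure:
    # depthwise conv (channels x k x k), 1x1 conv channels->hidden, 1x1 conv hidden->channels.
    hidden = channels * hidden_ratio
    depthwise = channels * kernel_size * kernel_size
    pointwise1 = channels * hidden + hidden   # weights + biases
    pointwise2 = hidden * channels + channels
    return depthwise + pointwise1 + pointwise2

# With channels=384: hidden_ratio=4 -> ~1.20M params per block,
# hidden_ratio=3 -> ~0.91M, i.e. roughly a 25% reduction in these layers.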