Ditch caching_eval; reduce params more.

This commit is contained in:
Daniel Povey 2022-12-16 00:22:44 +08:00
parent 083e5474c4
commit 53ab18a862

View File

@ -41,7 +41,6 @@ from scaling import (
Identity, # more friendly to backward hooks than nn.Identity(), for diagnostic reasons.
penalize_abs_values_gt,
softmax,
caching_eval,
ScheduledFloat,
FloatLike,
limit_param_value,
@ -318,12 +317,7 @@ class Zipformer(EncoderInterface):
- lengths, a tensor of shape (batch_size,) containing the number
of frames in `embeddings` before padding.
"""
if not torch.jit.is_scripting():
# This saves memory during training, at the expense of re-doing the encoder_embed
# computation in the backward pass.
x = caching_eval(x, self.encoder_embed)
else:
x = self.encoder_embed(x)
x = self.encoder_embed(x)
x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C)
@ -1693,7 +1687,7 @@ class ConvNeXt(nn.Module):
"""
def __init__(self,
channels: int,
hidden_ratio: int = 4,
hidden_ratio: int = 3,
layerdrop_prob: FloatLike = None):
super().__init__()
kernel_size = 7