Mirror of https://github.com/k2-fsa/icefall.git
Implement layerdrop per-sequence for convnext; lower, slower-decreasing layerdrop rate.

commit 5fa8de5c05 (parent 28cac1c2dc)
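In the version before this commit, the whole ConvNeXt residual branch was skipped for the entire minibatch during training with probability layerdrop_prob (the forward returned its input unchanged). After this commit the branch always runs, and its output is instead zeroed independently for each sequence in the batch, so the drop decision is per sequence rather than per batch. The scheduled rate is also lowered and decays more slowly: from 0.1 at batch 0 down to 0.01 by batch 20000, instead of from 0.2 down to 0.025 by batch 16000.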
@@ -1660,14 +1660,14 @@ class ConvNeXt(nn.Module):
     def __init__(self,
                  channels: int,
                  hidden_ratio: int = 4,
-                 layerdrop_prob: FloatLike = None):
+                 layerdrop_rate: FloatLike = None):
         super().__init__()
         kernel_size = 7
         pad = (kernel_size - 1) // 2
         hidden_channels = channels * hidden_ratio
-        if layerdrop_prob is None:
-            layerdrop_prob = ScheduledFloat((0.0, 0.2), (16000.0, 0.025))
-        self.layerdrop_prob = layerdrop_prob
+        if layerdrop_rate is None:
+            layerdrop_rate = ScheduledFloat((0.0, 0.1), (20000.0, 0.01))
+        self.layerdrop_rate = layerdrop_rate
 
         self.depthwise_conv = nn.Conv2d(
             in_channels=channels,
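The default rate is a ScheduledFloat from icefall's scaling.py, which behaves (to a first approximation) as a piecewise-linear function of the training batch count and is read off with float(), as the forward pass below does. A minimal stand-in to show what the two schedules work out to; the helper name and the two-breakpoint restriction are illustrative, not icefall's actual API:

def piecewise_linear(batch_count, points):
    # points: ((x0, y0), (x1, y1)) sorted by x, e.g. ((0.0, 0.1), (20000.0, 0.01)).
    (x0, y0), (x1, y1) = points
    if batch_count <= x0:
        return y0
    if batch_count >= x1:
        return y1
    # Linear interpolation between the two breakpoints.
    return y0 + (y1 - y0) * (batch_count - x0) / (x1 - x0)

old = ((0.0, 0.2), (16000.0, 0.025))
new = ((0.0, 0.1), (20000.0, 0.01))
for bc in (0, 8000, 16000, 20000):
    print(bc, round(piecewise_linear(bc, old), 4), round(piecewise_linear(bc, new), 4))
# e.g. at batch 16000 the old schedule has already reached 0.025,
# while the new one is still at 0.028 and bottoms out at 0.01 by 20000.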
@@ -1702,17 +1702,20 @@ class ConvNeXt(nn.Module):
 
         The returned value has the same shape as x.
         """
-        if torch.jit.is_scripting() or (self.training and random.random() < float(self.layerdrop_prob)):
-            return x
-
         bypass = x
         x = self.depthwise_conv(x)
         x = self.pointwise_conv1(x)
         x = self.hidden_balancer(x)
         x = self.activation(x)
         x = self.pointwise_conv2(x)
-        return bypass + x
+
+        layerdrop_rate = float(self.layerdrop_rate)
+        if not torch.jit.is_scripting() and self.training and layerdrop_rate != 0.0:
+            batch_size = x.shape[0]
+            mask = torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
+            x = x * mask
+
+        return bypass + x
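With this masking, sequences that are "dropped" simply contribute nothing beyond the bypass, while the computation graph stays identical from step to step. A self-contained sketch of the per-sequence behaviour; the tensor shape is chosen for illustration (the Conv2d layers and the (batch_size, 1, 1, 1) mask imply 4-D (batch, channels, height, width) inputs):

import torch

torch.manual_seed(0)
layerdrop_rate = 0.1

x = torch.randn(4, 8, 16, 16)   # (batch, channels, H, W); shape assumed for illustration
bypass = x
branch = torch.randn_like(x)    # stand-in for the depthwise/pointwise conv branch output

# One Bernoulli draw per sequence: True keeps the branch, False drops it.
mask = torch.rand((x.shape[0], 1, 1, 1), dtype=x.dtype, device=x.device) > layerdrop_rate
out = bypass + branch * mask

# Dropped sequences reduce exactly to the identity mapping:
for i in range(x.shape[0]):
    if not mask[i]:
        assert torch.equal(out[i], bypass[i])

Unlike the pre-commit code, which returned early and skipped the branch for the whole minibatch, every sequence still pays the compute cost here; only the regularization decision is made independently per sequence. For a batch size of 1 the two behaviours coincide: a dropped sequence yields exactly the bypass.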