mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-09 10:02:22 +00:00)

commit 65275d5e48
Merge 0eccb2b62cc7ed8f7066bcc2090984ff87535006 into abd9437e6d5419a497707748eb935e50976c3b7b
@@ -22,12 +22,15 @@ import math
 from typing import List, Tuple
 
 import numpy as np
+import random
+
+from scaling import penalize_abs_values_gt
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from utils import Fp32GroupNorm, Fp32LayerNorm, TransposeLast
 
 
 class ConvFeatureExtractionModel(nn.Module):
     def __init__(
         self,
@@ -105,4 +108,8 @@ class ConvFeatureExtractionModel(nn.Module):
         for conv in self.conv_layers:
             x = conv(x)
 
+        if self.training and random.random() < 0.2:
+            x = penalize_abs_values_gt(x, limit=1000.0, penalty=1.0e-05,
+                                       name=(self.name if hasattr(self, 'name') else 'ConvFeatureExtractionModel'))
+
         return x
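Note on the new call: penalize_abs_values_gt comes from icefall's scaling.py. As used above it is expected to return x unchanged in the forward pass while, in the backward pass, adding a small gradient term that pushes activations whose absolute value exceeds limit=1000.0 back toward the limit, so the conv front-end is discouraged from producing very large outputs. The sketch below only illustrates that identity-forward / penalty-backward idea; the class name _PenalizeAbsGt and the exact gradient form are assumptions, not the scaling.py implementation.

import torch

class _PenalizeAbsGt(torch.autograd.Function):
    # Identity in the forward pass; the backward pass adds penalty * sign(x)
    # wherever |x| > limit, i.e. the gradient of penalty * sum(relu(|x| - limit)).
    @staticmethod
    def forward(ctx, x, limit, penalty):
        ctx.save_for_backward(x)
        ctx.limit = limit
        ctx.penalty = penalty
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        extra = ctx.penalty * x.sign() * (x.abs() > ctx.limit).to(x.dtype)
        return grad_output + extra, None, None

# usage sketch: y = _PenalizeAbsGt.apply(x, 1000.0, 1.0e-05)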
@@ -789,7 +789,7 @@ class Zipformer2EncoderLayer(nn.Module):
         selected_attn_weights = attn_weights[0:1]
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
-        elif not self.training and random.random() < float(self.const_attention_rate):
+        elif self.training and random.random() < float(self.const_attention_rate):
             # Make attention weights constant.  The intention is to
             # encourage these modules to do something similar to an
             # averaging-over-time operation.
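Note on the condition change: the edit makes the constant-attention branch fire during training, with probability const_attention_rate, rather than at inference time. In that branch the surrounding code keeps a single head of attn_weights and replaces its weights with a uniform average over the positions it was allowed to attend to, which is the averaging-over-time behaviour the comment describes. A minimal sketch of that masking-and-renormalizing step, assuming attention weights of shape (num_heads, batch, tgt_len, src_len) with zeros at disallowed positions:

import torch

attn_weights = torch.softmax(torch.randn(4, 2, 5, 5), dim=-1)  # stand-in attention weights
selected = attn_weights[0:1]                      # keep only the first head
mask = (selected > 0.0).to(selected.dtype)        # 1.0 wherever attention was possible
selected = mask / mask.sum(dim=-1, keepdim=True)  # uniform weight over the allowed positions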