Merge 0eccb2b62cc7ed8f7066bcc2090984ff87535006 into abd9437e6d5419a497707748eb935e50976c3b7b

Daniel Povey 2025-06-27 11:34:23 +00:00 committed by GitHub
commit 65275d5e48
2 changed files with 8 additions and 1 deletion


@@ -22,12 +22,15 @@ import math
from typing import List, Tuple
import numpy as np
import random
from scaling import penalize_abs_values_gt
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import Fp32GroupNorm, Fp32LayerNorm, TransposeLast
class ConvFeatureExtractionModel(nn.Module):
    def __init__(
        self,
@@ -105,4 +108,8 @@ class ConvFeatureExtractionModel(nn.Module):
        for conv in self.conv_layers:
            x = conv(x)
        if self.training and random.random() < 0.2:
            x = penalize_abs_values_gt(x, limit=1000.0, penalty=1.0e-05,
                                       name=(self.name if hasattr(self, 'name') else 'ConvFeatureExtractionModel'))
        return x
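
The new block calls penalize_abs_values_gt from icefall's scaling module on roughly 20% of training batches, penalizing convolutional feature values whose magnitude exceeds 1000 (the name argument appears to be used only for diagnostics). The sketch below illustrates this kind of straight-through penalty under the assumption that the helper leaves the forward value unchanged and only injects a gradient term; the class name and implementation details are illustrative, not the actual icefall code.

import torch
from torch import Tensor


class AbsValuePenalty(torch.autograd.Function):
    """Illustrative sketch only (not icefall's penalize_abs_values_gt):
    identity in the forward pass; in the backward pass, adds the gradient of
    penalty * sum(relu(|x| - limit)), nudging out-of-range values back down."""

    @staticmethod
    def forward(ctx, x: Tensor, limit: float, penalty: float) -> Tensor:
        ctx.save_for_backward(x)
        ctx.limit, ctx.penalty = limit, penalty
        return x.view_as(x)  # identity; forward value of x is unchanged

    @staticmethod
    def backward(ctx, grad_out: Tensor):
        (x,) = ctx.saved_tensors
        # d/dx [penalty * relu(|x| - limit)] = penalty * sign(x) wherever |x| > limit
        extra = ctx.penalty * x.sign() * (x.abs() > ctx.limit).to(x.dtype)
        return grad_out + extra, None, None


# Usage sketch, mirroring the limit/penalty values in the diff:
x = (torch.randn(4, 512, 100) * 2000.0).requires_grad_(True)
y = AbsValuePenalty.apply(x, 1000.0, 1.0e-05)
assert torch.equal(y, x)   # forward value is unchanged
y.sum().backward()         # x.grad now includes the small penalty term
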


@@ -789,7 +789,7 @@ class Zipformer2EncoderLayer(nn.Module):
        selected_attn_weights = attn_weights[0:1]
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
-       elif not self.training and random.random() < float(self.const_attention_rate):
+       elif self.training and random.random() < float(self.const_attention_rate):
            # Make attention weights constant. The intention is to
            # encourage these modules to do something similar to an
            # averaging-over-time operation.
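
The second change flips the gate on Zipformer2's "constant attention" regularization: with not self.training it fired at evaluation/inference time, where random behaviour is unwanted, and never during training, where the averaging-over-time encouragement is meant to act; the fix restricts it to training. A rough sketch of the idea follows; the function name, shapes, and exact reweighting are illustrative, not the actual zipformer.py code.

import random
import torch


def maybe_make_attention_constant(attn_weights: torch.Tensor,
                                  const_attention_rate: float,
                                  training: bool) -> torch.Tensor:
    """Sketch: with probability const_attention_rate during training, replace
    the (already-normalized) attention weights with a uniform distribution over
    the positions that currently have nonzero weight, so each query effectively
    averages over time while masked-out positions stay at zero."""
    if training and random.random() < const_attention_rate:
        mask = (attn_weights > 0.0).to(attn_weights.dtype)
        attn_weights = mask / mask.sum(dim=-1, keepdim=True).clamp(min=1.0)
    return attn_weights
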