mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add print statement
This commit is contained in:
parent
d398f0ed70
commit
d8f7310118
@ -21,6 +21,7 @@ import warnings
|
||||
from typing import List, Optional, Tuple
|
||||
import logging
|
||||
import torch
|
||||
import random
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import (
|
||||
ActivationBalancer,
|
||||
@ -1090,6 +1091,9 @@ class AttentionCombine(nn.Module):
|
||||
|
||||
weights = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
|
||||
|
||||
if random.random() < 0.002:
|
||||
logging.info(f"Average weights are {weights.softmax(dim=1).mean(dim=0)}")
|
||||
|
||||
if self.training:
|
||||
# random masking..
|
||||
mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user