mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Add print statement
This commit is contained in:
parent
d398f0ed70
commit
d8f7310118
@ -21,6 +21,7 @@ import warnings
|
|||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
import logging
|
import logging
|
||||||
import torch
|
import torch
|
||||||
|
import random
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
from scaling import (
|
from scaling import (
|
||||||
ActivationBalancer,
|
ActivationBalancer,
|
||||||
@ -1090,6 +1091,9 @@ class AttentionCombine(nn.Module):
|
|||||||
|
|
||||||
weights = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
|
weights = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
|
||||||
|
|
||||||
|
if random.random() < 0.002:
|
||||||
|
logging.info(f"Average weights are {weights.softmax(dim=1).mean(dim=0)}")
|
||||||
|
|
||||||
if self.training:
|
if self.training:
|
||||||
# random masking..
|
# random masking..
|
||||||
mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
|
mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user