Add print statement

This commit is contained in:
Daniel Povey 2022-09-29 14:15:29 +08:00
parent d398f0ed70
commit d8f7310118

View File

@ -21,6 +21,7 @@ import warnings
from typing import List, Optional, Tuple
import logging
import torch
import random
from encoder_interface import EncoderInterface
from scaling import (
ActivationBalancer,
@ -1090,6 +1091,9 @@ class AttentionCombine(nn.Module):
weights = (stacked_inputs * self.weight).sum(dim=(1,)) + self.bias
if random.random() < 0.002:
logging.info(f"Average weights are {weights.softmax(dim=1).mean(dim=0)}")
if self.training:
# random masking..
mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob),