Fix bug where attn_mask was not passed in.

This commit is contained in:
Daniel Povey 2023-02-11 17:31:21 +08:00
parent e9157535a4
commit 8ccd061051

View File

@@ -479,9 +479,9 @@ class Zipformer(EncoderInterface):
         attn_mask = torch.logical_or(src_c > tgt_c,
                                      src_c < tgt_c - left_context_chunks)
-        if __name__ == "__main__":
+        if __name__ == "__main__" or random.random() < 0.1:
             logging.info(f"attn_mask = {attn_mask}")
+        return attn_mask

 def _whitening_schedule(x: float, ratio: float = 2.0) -> ScheduledFloat: