From eb7180a0e24b0129bf0997514c1a405452b70a17 Mon Sep 17 00:00:00 2001
From: JinZr <60612200+JinZr@users.noreply.github.com>
Date: Mon, 14 Aug 2023 18:15:01 +0800
Subject: [PATCH] enabled pos_embed

---
 egs/librispeech/ASR/zipformer/decode.py      |  10 +-
 .../alignment_attention_module.py            | 101 +++++++++---------
 .../ASR/zipformer_label_level_algn/model.py  |   2 +-
 3 files changed, 54 insertions(+), 59 deletions(-)

diff --git a/egs/librispeech/ASR/zipformer/decode.py b/egs/librispeech/ASR/zipformer/decode.py
index 41d54ea50..508d67aa8 100755
--- a/egs/librispeech/ASR/zipformer/decode.py
+++ b/egs/librispeech/ASR/zipformer/decode.py
@@ -108,7 +108,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
-    deprecated_greedy_search_batch,
+    # deprecated_greedy_search_batch,
     fast_beam_search_nbest,
     fast_beam_search_nbest_LG,
     fast_beam_search_nbest_oracle,
@@ -426,14 +426,10 @@ def decode_one_batch(
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
-        # hyp_tokens = greedy_search_batch(
-        #     model=model,
-        #     encoder_out=encoder_out,
-        #     encoder_out_lens=encoder_out_lens,
-        # )
-        hyp_tokens = deprecated_greedy_search_batch(
+        hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
+            encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
index df45119b7..22a68aacf 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/alignment_attention_module.py
@@ -333,60 +333,59 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         attn_scores = torch.matmul(q, k)  # (head, batch, lm_seq_len, am_seq_len)
         # print(attn_scores.shape)
 
-        # use_pos_scores = False
-        # if torch.jit.is_scripting() or torch.jit.is_tracing():
-        #     # We can't put random.random() in the same line
-        #     use_pos_scores = True
-        # elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
-        #     use_pos_scores = True
+        use_pos_scores = False
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            # We can't put random.random() in the same line
+            use_pos_scores = True
+        elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
+            use_pos_scores = True
 
-        # if use_pos_scores:
-        #     pos_emb = self.linear_pos(pos_emb)
+        if use_pos_scores:
+            pos_emb = self.linear_pos(pos_emb)
 
-        #     seq_len2 = 2 * am_seq_len - 1
-        #     pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
-        #         2, 0, 3, 1
-        #     )
-        #     # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+            seq_len2 = 2 * am_seq_len - 1
+            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim)
+            pos_emb = pos_emb.permute(2, 0, 3, 1)
+            # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
 
-        #     if for_reference:
-        #         pos_offset = am_seq_len - 1 - offset
-        #         pos_emb = pos_emb[
-        #             :,
-        #             :,
-        #             :,
-        #             pos_offset : pos_offset + am_seq_len,
-        #         ]
-        #         # (head, 1, pos_dim, seq_len2)
-        #     # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
-        #     # [where seq_len2 represents relative position.]
-        #     pos_scores = torch.matmul(p, pos_emb)
-        #     # print(pos_scores.shape, attn_scores.shape)
-        #     # the following .as_strided() expression converts the last axis of pos_scores from relative
-        #     # to absolute position. I don't know whether I might have got the time-offsets backwards or
-        #     # not, but let this code define which way round it is supposed to be.
-        #     if torch.jit.is_tracing():
-        #         (num_heads, b_p_dim, time1, n) = pos_scores.shape
-        #         rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-        #         cols = torch.arange(lm_seq_len)
-        #         rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
-        #         indexes = rows + cols
-        #         pos_scores = pos_scores.reshape(-1, n)
-        #         pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-        #         pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
-        #     else:
-        #         pos_scores = pos_scores.as_strided(
-        #             (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
-        #             (
-        #                 pos_scores.stride(0),
-        #                 pos_scores.stride(1),
-        #                 pos_scores.stride(2) - pos_scores.stride(3),
-        #                 pos_scores.stride(3),
-        #             ),
-        #             storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
-        #         )
-        #     # print(pos_scores.shape)
-        #     attn_scores = attn_scores + pos_scores
+            if for_reference:
+                pos_offset = am_seq_len - 1 - offset
+                pos_emb = pos_emb[
+                    :,
+                    :,
+                    :,
+                    pos_offset : pos_offset + am_seq_len,
+                ]
+                # (head, 1, pos_dim, seq_len2)
+            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
+            # [where seq_len2 represents relative position.]
+            pos_scores = torch.matmul(p, pos_emb)
+            # print(pos_scores.shape, attn_scores.shape)
+            # the following .as_strided() expression converts the last axis of pos_scores from relative
+            # to absolute position. I don't know whether I might have got the time-offsets backwards or
+            # not, but let this code define which way round it is supposed to be.
+            if torch.jit.is_tracing():
+                (num_heads, b_p_dim, time1, n) = pos_scores.shape
+                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+                cols = torch.arange(lm_seq_len)
+                rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
+                indexes = rows + cols
+                pos_scores = pos_scores.reshape(-1, n)
+                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+                pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
+            else:
+                pos_scores = pos_scores.as_strided(
+                    (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
+                    (
+                        pos_scores.stride(0),
+                        pos_scores.stride(1),
+                        pos_scores.stride(2) - pos_scores.stride(3),
+                        pos_scores.stride(3),
+                    ),
+                    storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
+                )
+            # print(pos_scores.shape)
+            attn_scores = attn_scores + pos_scores
 
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
diff --git a/egs/librispeech/ASR/zipformer_label_level_algn/model.py b/egs/librispeech/ASR/zipformer_label_level_algn/model.py
index 7d00bf98b..cb863d30d 100644
--- a/egs/librispeech/ASR/zipformer_label_level_algn/model.py
+++ b/egs/librispeech/ASR/zipformer_label_level_algn/model.py
@@ -84,7 +84,7 @@ class AsrModel(nn.Module):
 
         self.encoder_embed = encoder_embed
         self.encoder = encoder
-        self.dropout = nn.Dropout(p=0.1)
+        self.dropout = nn.Dropout(p=0.5)
 
         self.use_transducer = use_transducer
         if use_transducer:
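
Editor's note (not part of the patch): the block re-enabled above relies on .as_strided() to remap the last axis of pos_scores from relative offsets to absolute key positions. The standalone sketch below, which assumes the same stride layout as the re-enabled code (tensor sizes and names such as seq_len and rel_scores are illustrative only), shows what that view does.

# Editor's sketch, not part of the patch: relative -> absolute remapping
# performed by the .as_strided() call re-enabled above. All names here are
# illustrative; only the stride arithmetic mirrors the patched code.
import torch

num_heads, batch, seq_len = 2, 3, 5
seq_len2 = 2 * seq_len - 1  # relative offsets -(seq_len - 1) .. +(seq_len - 1)

# Fake "pos_scores": each slot along the last axis stores the relative offset
# it represents, so the effect of the strided view is easy to check by eye.
rel = torch.arange(seq_len2, dtype=torch.float32) - (seq_len - 1)
rel_scores = rel.expand(num_heads, batch, seq_len, seq_len2).contiguous()

abs_scores = rel_scores.as_strided(
    (num_heads, batch, seq_len, seq_len),
    (
        rel_scores.stride(0),
        rel_scores.stride(1),
        rel_scores.stride(2) - rel_scores.stride(3),
        rel_scores.stride(3),
    ),
    storage_offset=rel_scores.stride(3) * (seq_len - 1),
)

# abs_scores[h, b, i, j] == j - i: the score for query position i attending to
# key position j is read from relative-offset slot (j - i).
print(abs_scores[0, 0])

The view avoids materializing a gathered copy of the scores; the torch.jit.is_tracing() branch in the patched code computes the same remapping with torch.gather, presumably because strided views of this kind do not export cleanly when tracing.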