enabled pos_embed

This commit is contained in:
JinZr 2023-08-14 18:15:01 +08:00
parent b5d6a69cb4
commit eb7180a0e2
3 changed files with 54 additions and 59 deletions

View File

@ -108,7 +108,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
deprecated_greedy_search_batch, # deprecated_greedy_search_batch,
fast_beam_search_nbest, fast_beam_search_nbest,
fast_beam_search_nbest_LG, fast_beam_search_nbest_LG,
fast_beam_search_nbest_oracle, fast_beam_search_nbest_oracle,
@ -426,14 +426,10 @@ def decode_one_batch(
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
# hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
# model=model,
# encoder_out=encoder_out,
# encoder_out_lens=encoder_out_lens,
# )
hyp_tokens = deprecated_greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
for hyp in sp.decode(hyp_tokens): for hyp in sp.decode(hyp_tokens):
hyps.append(hyp.split()) hyps.append(hyp.split())

View File

@ -333,60 +333,59 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
attn_scores = torch.matmul(q, k) # (head, batch, lm_seq_len, am_seq_len) attn_scores = torch.matmul(q, k) # (head, batch, lm_seq_len, am_seq_len)
# print(attn_scores.shape) # print(attn_scores.shape)
# use_pos_scores = False use_pos_scores = False
# if torch.jit.is_scripting() or torch.jit.is_tracing(): if torch.jit.is_scripting() or torch.jit.is_tracing():
# # We can't put random.random() in the same line # We can't put random.random() in the same line
# use_pos_scores = True use_pos_scores = True
# elif not self.training or random.random() >= float(self.pos_emb_skip_rate): elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
# use_pos_scores = True use_pos_scores = True
# if use_pos_scores: if use_pos_scores:
# pos_emb = self.linear_pos(pos_emb) pos_emb = self.linear_pos(pos_emb)
# seq_len2 = 2 * am_seq_len - 1 seq_len2 = 2 * am_seq_len - 1
# pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute( pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim)
# 2, 0, 3, 1 pos_emb = pos_emb.permute(2, 0, 3, 1)
# ) # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
# # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
# if for_reference: if for_reference:
# pos_offset = am_seq_len - 1 - offset pos_offset = am_seq_len - 1 - offset
# pos_emb = pos_emb[ pos_emb = pos_emb[
# :, :,
# :, :,
# :, :,
# pos_offset : pos_offset + am_seq_len, pos_offset : pos_offset + am_seq_len,
# ] ]
# # (head, 1, pos_dim, seq_len2) # (head, 1, pos_dim, seq_len2)
# # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2) # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
# # [where seq_len2 represents relative position.] # [where seq_len2 represents relative position.]
# pos_scores = torch.matmul(p, pos_emb) pos_scores = torch.matmul(p, pos_emb)
# # print(pos_scores.shape, attn_scores.shape) # print(pos_scores.shape, attn_scores.shape)
# # the following .as_strided() expression converts the last axis of pos_scores from relative # the following .as_strided() expression converts the last axis of pos_scores from relative
# # to absolute position. I don't know whether I might have got the time-offsets backwards or # to absolute position. I don't know whether I might have got the time-offsets backwards or
# # not, but let this code define which way round it is supposed to be. # not, but let this code define which way round it is supposed to be.
# if torch.jit.is_tracing(): if torch.jit.is_tracing():
# (num_heads, b_p_dim, time1, n) = pos_scores.shape (num_heads, b_p_dim, time1, n) = pos_scores.shape
# rows = torch.arange(start=time1 - 1, end=-1, step=-1) rows = torch.arange(start=time1 - 1, end=-1, step=-1)
# cols = torch.arange(lm_seq_len) cols = torch.arange(lm_seq_len)
# rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1) rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
# indexes = rows + cols indexes = rows + cols
# pos_scores = pos_scores.reshape(-1, n) pos_scores = pos_scores.reshape(-1, n)
# pos_scores = torch.gather(pos_scores, dim=1, index=indexes) pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
# pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len) pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
# else: else:
# pos_scores = pos_scores.as_strided( pos_scores = pos_scores.as_strided(
# (num_heads, b_p_dim, lm_seq_len, lm_seq_len), (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
# ( (
# pos_scores.stride(0), pos_scores.stride(0),
# pos_scores.stride(1), pos_scores.stride(1),
# pos_scores.stride(2) - pos_scores.stride(3), pos_scores.stride(2) - pos_scores.stride(3),
# pos_scores.stride(3), pos_scores.stride(3),
# ), ),
# storage_offset=pos_scores.stride(3) * (lm_seq_len - 1), storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
# ) )
# # print(pos_scores.shape) # print(pos_scores.shape)
# attn_scores = attn_scores + pos_scores attn_scores = attn_scores + pos_scores
if torch.jit.is_scripting() or torch.jit.is_tracing(): if torch.jit.is_scripting() or torch.jit.is_tracing():
pass pass

View File

@ -84,7 +84,7 @@ class AsrModel(nn.Module):
self.encoder_embed = encoder_embed self.encoder_embed = encoder_embed
self.encoder = encoder self.encoder = encoder
self.dropout = nn.Dropout(p=0.1) self.dropout = nn.Dropout(p=0.5)
self.use_transducer = use_transducer self.use_transducer = use_transducer
if use_transducer: if use_transducer: