enabled pos_embed

JinZr 2023-08-14 18:15:01 +08:00
parent b5d6a69cb4
commit eb7180a0e2
3 changed files with 54 additions and 59 deletions

View File

@@ -108,7 +108,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from beam_search import (
    beam_search,
    deprecated_greedy_search_batch,
    # deprecated_greedy_search_batch,
    fast_beam_search_nbest,
    fast_beam_search_nbest_LG,
    fast_beam_search_nbest_oracle,
@@ -426,14 +426,10 @@ def decode_one_batch(
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
        # hyp_tokens = greedy_search_batch(
        # model=model,
        # encoder_out=encoder_out,
        # encoder_out_lens=encoder_out_lens,
        # )
        hyp_tokens = deprecated_greedy_search_batch(
        hyp_tokens = greedy_search_batch(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())

View File

@@ -333,60 +333,59 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
        attn_scores = torch.matmul(q, k)  # (head, batch, lm_seq_len, am_seq_len)
        # print(attn_scores.shape)
        # use_pos_scores = False
        # if torch.jit.is_scripting() or torch.jit.is_tracing():
        # # We can't put random.random() in the same line
        # use_pos_scores = True
        # elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
        # use_pos_scores = True
        use_pos_scores = False
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            # We can't put random.random() in the same line
            use_pos_scores = True
        elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
            use_pos_scores = True
        # if use_pos_scores:
        # pos_emb = self.linear_pos(pos_emb)
        if use_pos_scores:
            pos_emb = self.linear_pos(pos_emb)
        # seq_len2 = 2 * am_seq_len - 1
        # pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
        # 2, 0, 3, 1
        # )
        # # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
            seq_len2 = 2 * am_seq_len - 1
            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim)
            pos_emb = pos_emb.permute(2, 0, 3, 1)
            # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
        # if for_reference:
        # pos_offset = am_seq_len - 1 - offset
        # pos_emb = pos_emb[
        # :,
        # :,
        # :,
        # pos_offset : pos_offset + am_seq_len,
        # ]
        # # (head, 1, pos_dim, seq_len2)
        # # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
        # # [where seq_len2 represents relative position.]
        # pos_scores = torch.matmul(p, pos_emb)
        # # print(pos_scores.shape, attn_scores.shape)
        # # the following .as_strided() expression converts the last axis of pos_scores from relative
        # # to absolute position. I don't know whether I might have got the time-offsets backwards or
        # # not, but let this code define which way round it is supposed to be.
        # if torch.jit.is_tracing():
        # (num_heads, b_p_dim, time1, n) = pos_scores.shape
        # rows = torch.arange(start=time1 - 1, end=-1, step=-1)
        # cols = torch.arange(lm_seq_len)
        # rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
        # indexes = rows + cols
        # pos_scores = pos_scores.reshape(-1, n)
        # pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
        # pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
        # else:
        # pos_scores = pos_scores.as_strided(
        # (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
        # (
        # pos_scores.stride(0),
        # pos_scores.stride(1),
        # pos_scores.stride(2) - pos_scores.stride(3),
        # pos_scores.stride(3),
        # ),
        # storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
        # )
        # # print(pos_scores.shape)
        # attn_scores = attn_scores + pos_scores
            if for_reference:
                pos_offset = am_seq_len - 1 - offset
                pos_emb = pos_emb[
                    :,
                    :,
                    :,
                    pos_offset : pos_offset + am_seq_len,
                ]
                # (head, 1, pos_dim, seq_len2)
            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
            # [where seq_len2 represents relative position.]
            pos_scores = torch.matmul(p, pos_emb)
            # print(pos_scores.shape, attn_scores.shape)
            # the following .as_strided() expression converts the last axis of pos_scores from relative
            # to absolute position. I don't know whether I might have got the time-offsets backwards or
            # not, but let this code define which way round it is supposed to be.
            if torch.jit.is_tracing():
                (num_heads, b_p_dim, time1, n) = pos_scores.shape
                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
                cols = torch.arange(lm_seq_len)
                rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
                indexes = rows + cols
                pos_scores = pos_scores.reshape(-1, n)
                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
                pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
            else:
                pos_scores = pos_scores.as_strided(
                    (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
                    (
                        pos_scores.stride(0),
                        pos_scores.stride(1),
                        pos_scores.stride(2) - pos_scores.stride(3),
                        pos_scores.stride(3),
                    ),
                    storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
                )
            # print(pos_scores.shape)
            attn_scores = attn_scores + pos_scores
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            pass
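
The re-enabled block above converts the last axis of pos_scores from relative to absolute position, either with .as_strided() or, under torch.jit.is_tracing(), with an explicit torch.gather(). Below is a standalone sketch (my own illustration, not code from this commit) that checks both routes against the assumed convention that index r on the last axis means relative position r = j - i + (seq_len - 1); the shapes here are simplified stand-ins for (num_heads, batch, time1, 2 * seq_len - 1).

import torch

num_heads, batch, seq_len = 2, 3, 5
seq_len2 = 2 * seq_len - 1

# Fill the last axis with the relative index itself so the result is easy to check.
pos_scores = (
    torch.arange(seq_len2, dtype=torch.float32)
    .expand(num_heads, batch, seq_len, seq_len2)
    .contiguous()
)

# 1) The .as_strided() route (the non-tracing branch): re-stride the last two axes
#    so that entry [i, j] reads the element at relative index j - i + seq_len - 1.
abs_strided = pos_scores.as_strided(
    (num_heads, batch, seq_len, seq_len),
    (
        pos_scores.stride(0),
        pos_scores.stride(1),
        pos_scores.stride(2) - pos_scores.stride(3),
        pos_scores.stride(3),
    ),
    storage_offset=pos_scores.stride(3) * (seq_len - 1),
)

# 2) The gather route (the torch.jit.is_tracing() branch): build explicit indexes
#    and gather them, which avoids as_strided while tracing.
(h, b, time1, n) = pos_scores.shape
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(b * h).unsqueeze(-1)
indexes = rows + cols
abs_gathered = torch.gather(pos_scores.reshape(-1, n), dim=1, index=indexes)
abs_gathered = abs_gathered.reshape(h, b, time1, seq_len)

# Both routes should agree, and entry [i, j] should equal j - i + seq_len - 1.
i = torch.arange(seq_len).unsqueeze(-1)
j = torch.arange(seq_len).unsqueeze(0)
expected = (j - i + seq_len - 1).float().expand(num_heads, batch, seq_len, seq_len)
assert torch.equal(abs_strided, expected)
assert torch.equal(abs_gathered, expected)
print(abs_strided[0, 0])

Both routes produce the same (num_heads, batch, seq_len, seq_len) tensor, which is why the tracing branch can substitute gather for as_strided without changing the scores that get added to attn_scores.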

View File

@@ -84,7 +84,7 @@ class AsrModel(nn.Module):
        self.encoder_embed = encoder_embed
        self.encoder = encoder
        self.dropout = nn.Dropout(p=0.1)
        self.dropout = nn.Dropout(p=0.5)
        self.use_transducer = use_transducer
        if use_transducer:
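
The last hunk only raises the dropout probability on the encoder output from 0.1 to 0.5. As a quick reminder of what that means in practice, here is a minimal sketch (not code from this repo; the (batch, time, feature) shape is a hypothetical stand-in): in training mode roughly half of the activations are zeroed and the survivors are scaled by 1 / (1 - p) = 2, while in eval mode the module is the identity.

import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)

x = torch.ones(2, 4, 8)  # hypothetical (batch, time, feature) activations

dropout.train()
y_train = dropout(x)     # ~50% of entries become 0, the rest become 2.0

dropout.eval()
y_eval = dropout(x)      # unchanged

print((y_train == 0).float().mean().item())  # close to 0.5
print(torch.equal(y_eval, x))                # True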