enabled pos_embed

repo:   https://github.com/k2-fsa/icefall.git
commit: eb7180a0e2
parent: b5d6a69cb4
@@ -108,7 +108,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from beam_search import (
     beam_search,
-    deprecated_greedy_search_batch,
+    # deprecated_greedy_search_batch,
     fast_beam_search_nbest,
     fast_beam_search_nbest_LG,
     fast_beam_search_nbest_oracle,
@@ -426,14 +426,10 @@ def decode_one_batch(
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
     elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
-        # hyp_tokens = greedy_search_batch(
-        #     model=model,
-        #     encoder_out=encoder_out,
-        #     encoder_out_lens=encoder_out_lens,
-        # )
-        hyp_tokens = deprecated_greedy_search_batch(
+        hyp_tokens = greedy_search_batch(
             model=model,
             encoder_out=encoder_out,
             encoder_out_lens=encoder_out_lens,
         )
         for hyp in sp.decode(hyp_tokens):
             hyps.append(hyp.split())
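For context on the greedy_search_batch call restored above: the sketch below is a toy, frame-synchronous batched greedy transducer search with max_sym_per_frame=1. The embedding "decoder" and linear "joiner" are stand-ins invented for illustration, not icefall's modules; only the batched control flow is meant to mirror what the call above does.

import torch
import torch.nn as nn

vocab_size, blank_id, enc_dim, dec_dim = 10, 0, 8, 8
decoder = nn.Embedding(vocab_size, dec_dim)        # stand-in for the prediction network
joiner = nn.Linear(enc_dim + dec_dim, vocab_size)  # stand-in for the joint network

def toy_greedy_search_batch(encoder_out, encoder_out_lens, max_sym_per_frame=1):
    """encoder_out: (N, T, enc_dim). Returns one token-id list per utterance."""
    N, T, _ = encoder_out.shape
    hyps = [[] for _ in range(N)]
    last_tok = torch.full((N,), blank_id, dtype=torch.long)
    for t in range(T):
        for _ in range(max_sym_per_frame):
            dec_out = decoder(last_tok)                                    # (N, dec_dim)
            logits = joiner(torch.cat([encoder_out[:, t], dec_out], dim=-1))
            tok = logits.argmax(dim=-1)                                    # (N,)
            emitted = False
            for i in range(N):
                # skip padded frames and blanks; otherwise emit the symbol
                if t < int(encoder_out_lens[i]) and int(tok[i]) != blank_id:
                    hyps[i].append(int(tok[i]))
                    last_tok[i] = tok[i]
                    emitted = True
            if not emitted:
                break
    return hyps

enc_out = torch.randn(2, 5, enc_dim)
print(toy_greedy_search_batch(enc_out, torch.tensor([5, 3])))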
@@ -333,60 +333,59 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
         attn_scores = torch.matmul(q, k)  # (head, batch, lm_seq_len, am_seq_len)
         # print(attn_scores.shape)

-        # use_pos_scores = False
-        # if torch.jit.is_scripting() or torch.jit.is_tracing():
-        #     # We can't put random.random() in the same line
-        #     use_pos_scores = True
-        # elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
-        #     use_pos_scores = True
+        use_pos_scores = False
+        if torch.jit.is_scripting() or torch.jit.is_tracing():
+            # We can't put random.random() in the same line
+            use_pos_scores = True
+        elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
+            use_pos_scores = True

-        # if use_pos_scores:
-        #     pos_emb = self.linear_pos(pos_emb)
+        if use_pos_scores:
+            pos_emb = self.linear_pos(pos_emb)

-        #     seq_len2 = 2 * am_seq_len - 1
-        #     pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
-        #         2, 0, 3, 1
-        #     )
-        #     # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
+            seq_len2 = 2 * am_seq_len - 1
+            pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim)
+            pos_emb = pos_emb.permute(2, 0, 3, 1)
+            # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)

-        #     if for_reference:
-        #         pos_offset = am_seq_len - 1 - offset
-        #         pos_emb = pos_emb[
-        #             :,
-        #             :,
-        #             :,
-        #             pos_offset : pos_offset + am_seq_len,
-        #         ]
-        #         # (head, 1, pos_dim, seq_len2)
-        #     # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
-        #     # [where seq_len2 represents relative position.]
-        #     pos_scores = torch.matmul(p, pos_emb)
-        #     # print(pos_scores.shape, attn_scores.shape)
-        #     # the following .as_strided() expression converts the last axis of pos_scores from relative
-        #     # to absolute position. I don't know whether I might have got the time-offsets backwards or
-        #     # not, but let this code define which way round it is supposed to be.
-        #     if torch.jit.is_tracing():
-        #         (num_heads, b_p_dim, time1, n) = pos_scores.shape
-        #         rows = torch.arange(start=time1 - 1, end=-1, step=-1)
-        #         cols = torch.arange(lm_seq_len)
-        #         rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
-        #         indexes = rows + cols
-        #         pos_scores = pos_scores.reshape(-1, n)
-        #         pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
-        #         pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
-        #     else:
-        #         pos_scores = pos_scores.as_strided(
-        #             (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
-        #             (
-        #                 pos_scores.stride(0),
-        #                 pos_scores.stride(1),
-        #                 pos_scores.stride(2) - pos_scores.stride(3),
-        #                 pos_scores.stride(3),
-        #             ),
-        #             storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
-        #         )
-        #     # print(pos_scores.shape)
-        #     attn_scores = attn_scores + pos_scores
+            if for_reference:
+                pos_offset = am_seq_len - 1 - offset
+                pos_emb = pos_emb[
+                    :,
+                    :,
+                    :,
+                    pos_offset : pos_offset + am_seq_len,
+                ]
+                # (head, 1, pos_dim, seq_len2)
+            # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
+            # [where seq_len2 represents relative position.]
+            pos_scores = torch.matmul(p, pos_emb)
+            # print(pos_scores.shape, attn_scores.shape)
+            # the following .as_strided() expression converts the last axis of pos_scores from relative
+            # to absolute position. I don't know whether I might have got the time-offsets backwards or
+            # not, but let this code define which way round it is supposed to be.
+            if torch.jit.is_tracing():
+                (num_heads, b_p_dim, time1, n) = pos_scores.shape
+                rows = torch.arange(start=time1 - 1, end=-1, step=-1)
+                cols = torch.arange(lm_seq_len)
+                rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
+                indexes = rows + cols
+                pos_scores = pos_scores.reshape(-1, n)
+                pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
+                pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
+            else:
+                pos_scores = pos_scores.as_strided(
+                    (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
+                    (
+                        pos_scores.stride(0),
+                        pos_scores.stride(1),
+                        pos_scores.stride(2) - pos_scores.stride(3),
+                        pos_scores.stride(3),
+                    ),
+                    storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
+                )
+            # print(pos_scores.shape)
+            attn_scores = attn_scores + pos_scores

         if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
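The two branches enabled above (torch.gather under torch.jit.is_tracing(), .as_strided() otherwise) are meant to perform the same relative-to-absolute position shift on the last axis of pos_scores. The self-contained check below uses made-up sizes and a plain contiguous tensor to show the two paths agreeing; it mirrors the code above but is not part of the commit.

import torch

num_heads, batch, seq_len = 2, 3, 4
# last axis has size 2*seq_len - 1 and is indexed by relative position
pos_scores = torch.randn(num_heads, batch, seq_len, 2 * seq_len - 1)

# as_strided path (the `else:` branch above)
strided = pos_scores.as_strided(
    (num_heads, batch, seq_len, seq_len),
    (
        pos_scores.stride(0),
        pos_scores.stride(1),
        pos_scores.stride(2) - pos_scores.stride(3),
        pos_scores.stride(3),
    ),
    storage_offset=pos_scores.stride(3) * (seq_len - 1),
)

# gather path (the `torch.jit.is_tracing()` branch above)
n = pos_scores.shape[-1]
rows = torch.arange(start=seq_len - 1, end=-1, step=-1)
cols = torch.arange(seq_len)
rows = rows.repeat(batch * num_heads).unsqueeze(-1)
indexes = rows + cols
flat = pos_scores.reshape(-1, n)
gathered = torch.gather(flat, dim=1, index=indexes)
gathered = gathered.reshape(num_heads, batch, seq_len, seq_len)

# both select element [h, b, t, seq_len - 1 - t + c] of pos_scores
assert torch.equal(strided, gathered)
print("as_strided and gather paths match")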
@@ -84,7 +84,7 @@ class AsrModel(nn.Module):

         self.encoder_embed = encoder_embed
         self.encoder = encoder
-        self.dropout = nn.Dropout(p=0.1)
+        self.dropout = nn.Dropout(p=0.5)

         self.use_transducer = use_transducer
         if use_transducer:
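A quick reminder of what the added line's p=0.5 means for the encoder-output dropout, using toy tensors rather than the model: in training mode roughly half of the activations are zeroed and the survivors are scaled by 1/(1-p); in eval mode the module is a no-op.

import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)
x = torch.ones(2, 4)

dropout.train()
print(dropout(x))  # roughly half zeros, survivors scaled to 1/(1-0.5) = 2.0

dropout.eval()
print(dropout(x))  # identical to x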