mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
enabled pos_embed
This commit is contained in:
parent
b5d6a69cb4
commit
eb7180a0e2
@ -108,7 +108,7 @@ import torch.nn as nn
|
|||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from beam_search import (
|
from beam_search import (
|
||||||
beam_search,
|
beam_search,
|
||||||
deprecated_greedy_search_batch,
|
# deprecated_greedy_search_batch,
|
||||||
fast_beam_search_nbest,
|
fast_beam_search_nbest,
|
||||||
fast_beam_search_nbest_LG,
|
fast_beam_search_nbest_LG,
|
||||||
fast_beam_search_nbest_oracle,
|
fast_beam_search_nbest_oracle,
|
||||||
@ -426,14 +426,10 @@ def decode_one_batch(
|
|||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1:
|
||||||
# hyp_tokens = greedy_search_batch(
|
hyp_tokens = greedy_search_batch(
|
||||||
# model=model,
|
|
||||||
# encoder_out=encoder_out,
|
|
||||||
# encoder_out_lens=encoder_out_lens,
|
|
||||||
# )
|
|
||||||
hyp_tokens = deprecated_greedy_search_batch(
|
|
||||||
model=model,
|
model=model,
|
||||||
encoder_out=encoder_out,
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
)
|
)
|
||||||
for hyp in sp.decode(hyp_tokens):
|
for hyp in sp.decode(hyp_tokens):
|
||||||
hyps.append(hyp.split())
|
hyps.append(hyp.split())
|
||||||
|
@ -333,60 +333,59 @@ class RelPositionMultiheadAttentionWeights(nn.Module):
|
|||||||
attn_scores = torch.matmul(q, k) # (head, batch, lm_seq_len, am_seq_len)
|
attn_scores = torch.matmul(q, k) # (head, batch, lm_seq_len, am_seq_len)
|
||||||
# print(attn_scores.shape)
|
# print(attn_scores.shape)
|
||||||
|
|
||||||
# use_pos_scores = False
|
use_pos_scores = False
|
||||||
# if torch.jit.is_scripting() or torch.jit.is_tracing():
|
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||||
# # We can't put random.random() in the same line
|
# We can't put random.random() in the same line
|
||||||
# use_pos_scores = True
|
use_pos_scores = True
|
||||||
# elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
|
elif not self.training or random.random() >= float(self.pos_emb_skip_rate):
|
||||||
# use_pos_scores = True
|
use_pos_scores = True
|
||||||
|
|
||||||
# if use_pos_scores:
|
if use_pos_scores:
|
||||||
# pos_emb = self.linear_pos(pos_emb)
|
pos_emb = self.linear_pos(pos_emb)
|
||||||
|
|
||||||
# seq_len2 = 2 * am_seq_len - 1
|
seq_len2 = 2 * am_seq_len - 1
|
||||||
# pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim).permute(
|
pos_emb = pos_emb.reshape(-1, seq_len2, num_heads, pos_head_dim)
|
||||||
# 2, 0, 3, 1
|
pos_emb = pos_emb.permute(2, 0, 3, 1)
|
||||||
# )
|
# pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
|
||||||
# # pos shape now: (head, {1 or batch_size}, pos_dim, seq_len2)
|
|
||||||
|
|
||||||
# if for_reference:
|
if for_reference:
|
||||||
# pos_offset = am_seq_len - 1 - offset
|
pos_offset = am_seq_len - 1 - offset
|
||||||
# pos_emb = pos_emb[
|
pos_emb = pos_emb[
|
||||||
# :,
|
:,
|
||||||
# :,
|
:,
|
||||||
# :,
|
:,
|
||||||
# pos_offset : pos_offset + am_seq_len,
|
pos_offset : pos_offset + am_seq_len,
|
||||||
# ]
|
]
|
||||||
# # (head, 1, pos_dim, seq_len2)
|
# (head, 1, pos_dim, seq_len2)
|
||||||
# # (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
|
# (head, batch, time1, pos_dim) x (head, 1, pos_dim, seq_len2) -> (head, batch, time1, seq_len2)
|
||||||
# # [where seq_len2 represents relative position.]
|
# [where seq_len2 represents relative position.]
|
||||||
# pos_scores = torch.matmul(p, pos_emb)
|
pos_scores = torch.matmul(p, pos_emb)
|
||||||
# # print(pos_scores.shape, attn_scores.shape)
|
# print(pos_scores.shape, attn_scores.shape)
|
||||||
# # the following .as_strided() expression converts the last axis of pos_scores from relative
|
# the following .as_strided() expression converts the last axis of pos_scores from relative
|
||||||
# # to absolute position. I don't know whether I might have got the time-offsets backwards or
|
# to absolute position. I don't know whether I might have got the time-offsets backwards or
|
||||||
# # not, but let this code define which way round it is supposed to be.
|
# not, but let this code define which way round it is supposed to be.
|
||||||
# if torch.jit.is_tracing():
|
if torch.jit.is_tracing():
|
||||||
# (num_heads, b_p_dim, time1, n) = pos_scores.shape
|
(num_heads, b_p_dim, time1, n) = pos_scores.shape
|
||||||
# rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
rows = torch.arange(start=time1 - 1, end=-1, step=-1)
|
||||||
# cols = torch.arange(lm_seq_len)
|
cols = torch.arange(lm_seq_len)
|
||||||
# rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
|
rows = rows.repeat(b_p_dim * num_heads).unsqueeze(-1)
|
||||||
# indexes = rows + cols
|
indexes = rows + cols
|
||||||
# pos_scores = pos_scores.reshape(-1, n)
|
pos_scores = pos_scores.reshape(-1, n)
|
||||||
# pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
pos_scores = torch.gather(pos_scores, dim=1, index=indexes)
|
||||||
# pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
|
pos_scores = pos_scores.reshape(num_heads, b_p_dim, time1, lm_seq_len)
|
||||||
# else:
|
else:
|
||||||
# pos_scores = pos_scores.as_strided(
|
pos_scores = pos_scores.as_strided(
|
||||||
# (num_heads, b_p_dim, lm_seq_len, lm_seq_len),
|
(num_heads, b_p_dim, lm_seq_len, lm_seq_len),
|
||||||
# (
|
(
|
||||||
# pos_scores.stride(0),
|
pos_scores.stride(0),
|
||||||
# pos_scores.stride(1),
|
pos_scores.stride(1),
|
||||||
# pos_scores.stride(2) - pos_scores.stride(3),
|
pos_scores.stride(2) - pos_scores.stride(3),
|
||||||
# pos_scores.stride(3),
|
pos_scores.stride(3),
|
||||||
# ),
|
),
|
||||||
# storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
|
storage_offset=pos_scores.stride(3) * (lm_seq_len - 1),
|
||||||
# )
|
)
|
||||||
# # print(pos_scores.shape)
|
# print(pos_scores.shape)
|
||||||
# attn_scores = attn_scores + pos_scores
|
attn_scores = attn_scores + pos_scores
|
||||||
|
|
||||||
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
if torch.jit.is_scripting() or torch.jit.is_tracing():
|
||||||
pass
|
pass
|
||||||
|
@ -84,7 +84,7 @@ class AsrModel(nn.Module):
|
|||||||
|
|
||||||
self.encoder_embed = encoder_embed
|
self.encoder_embed = encoder_embed
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.dropout = nn.Dropout(p=0.1)
|
self.dropout = nn.Dropout(p=0.5)
|
||||||
|
|
||||||
self.use_transducer = use_transducer
|
self.use_transducer = use_transducer
|
||||||
if use_transducer:
|
if use_transducer:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user