mirror of https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00

commit 3631361b95
parent 70d603dc28

minor updates
@@ -907,9 +907,9 @@ def deprecated_greedy_search_batch_for_cross_attn(
         logits = model.joiner(
             current_encoder_out,
             decoder_out.unsqueeze(1),
-            attn_encoder_out if t > 0 else torch.zeros_like(current_encoder_out),
-            None,
-            apply_attn=True,
+            attn_encoder_out if t < 0 else torch.zeros_like(current_encoder_out),
+            encoder_out_lens,
+            apply_attn=False,
             project_input=False,
         )
         # logits'shape (batch_size, 1, 1, vocab_size)
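Note on the hunk above: the frame index t is never negative in the decoding loop, so the "t < 0" ternary always selects the zero tensor, which is consistent with switching the joiner call to apply_attn=False. A minimal sketch of what the attention argument now evaluates to (shapes are illustrative, not taken from the repository):

import torch

# Illustrative shapes only: (batch, 1, 1, joiner_dim).
current_encoder_out = torch.randn(4, 1, 1, 512)
attn_encoder_out = torch.randn(4, 1, 1, 512)

for t in range(3):  # t is a non-negative frame index in the decoding loop
    attn_arg = attn_encoder_out if t < 0 else torch.zeros_like(current_encoder_out)
    assert torch.count_nonzero(attn_arg) == 0  # always the zero placeholder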
@@ -14,6 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging
+
 import torch
 import torch.nn as nn
 from alignment_attention_module import AlignmentAttentionModule
@@ -34,6 +36,7 @@ class Joiner(nn.Module):
         self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim, initial_scale=0.25)
         self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim, initial_scale=0.25)
         self.output_linear = nn.Linear(joiner_dim, vocab_size)
+        self.enable_attn = False
 
     def forward(
         self,
@@ -64,7 +67,10 @@ class Joiner(nn.Module):
             decoder_out.shape,
         )
 
-        if apply_attn and lengths is not None:
+        if apply_attn:
+            if not self.enable_attn:
+                self.enable_attn = True
+                logging.info("enabling ATTN!")
             attn_encoder_out = self.label_level_am_attention(
                 encoder_out, decoder_out, lengths
             )
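The enable_attn flag introduced in the two hunks above acts as a log-once guard: the first forward pass that runs with apply_attn=True flips the flag and writes a single log line, so the training log records when attention actually kicks in. A minimal sketch of the same pattern, detached from the Joiner class (names here are illustrative):

import logging

logging.basicConfig(level=logging.INFO)

class AttnGate:
    def __init__(self) -> None:
        self.enable_attn = False  # flipped the first time attention is used

    def forward(self, apply_attn: bool) -> None:
        if apply_attn and not self.enable_attn:
            self.enable_attn = True
            logging.info("enabling ATTN!")  # emitted exactly once

gate = AttnGate()
for use_attn in (False, False, True, True):
    gate.forward(use_attn)  # logs only on the first True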
@@ -72,7 +78,11 @@ class Joiner(nn.Module):
         if project_input:
             logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
         else:
-            logit = encoder_out + decoder_out + attn_encoder_out
+            if apply_attn:
+                logit = encoder_out + decoder_out + attn_encoder_out
+            else:
+                # logging.info("disabling cross attn mdl")
+                logit = encoder_out + decoder_out
 
         logit = self.output_linear(torch.tanh(logit))
 
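With the hunk above, the un-projected branch of Joiner.forward only adds the attention term when apply_attn is set; otherwise it falls back to the plain encoder + decoder sum before the final output_linear projection. A small sketch of the resulting combine step (tensor shapes are assumed):

import torch

def combine(encoder_out, decoder_out, attn_encoder_out, apply_attn: bool):
    # Mirrors the else-branch of Joiner.forward after this change,
    # up to (but not including) the output_linear projection.
    if apply_attn:
        logit = encoder_out + decoder_out + attn_encoder_out
    else:
        logit = encoder_out + decoder_out
    return torch.tanh(logit)

x = torch.randn(2, 3, 4, 512)  # illustrative (batch, T, s_range, joiner_dim)
assert combine(x, x, x, apply_attn=False).shape == x.shape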
@@ -24,12 +24,13 @@ import torch.nn as nn
 from encoder_interface import EncoderInterface
 from scaling import ScaledLinear
 
-from icefall.utils import add_sos, make_pad_mask
+from icefall.utils import add_sos, make_pad_mask, AttributeDict
 
 
 class AsrModel(nn.Module):
     def __init__(
         self,
+        params: AttributeDict,
         encoder_embed: nn.Module,
         encoder: EncoderInterface,
         decoder: Optional[nn.Module] = None,
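AsrModel now receives the full params object. AttributeDict (from icefall.utils) is a dict subclass with attribute-style access, which is what later allows the self.params.warm_step lookup. A short usage sketch (the warm_step value is illustrative):

from icefall.utils import AttributeDict

params = AttributeDict({"warm_step": 2000})
print(params.warm_step)     # attribute access -> 2000
print(params["warm_step"])  # ordinary dict access still works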
@@ -79,6 +80,8 @@ class AsrModel(nn.Module):
 
         assert isinstance(encoder, EncoderInterface), type(encoder)
 
+        self.params = params
+
         self.encoder_embed = encoder_embed
         self.encoder = encoder
 
@@ -180,6 +183,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        batch_idx_train: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute Transducer loss.
         Args:
@@ -264,12 +268,13 @@ class AsrModel(nn.Module):
 
         # project_input=False since we applied the decoder's input projections
         # prior to do_rnnt_pruning (this is an optimization for speed).
+        # print(batch_idx_train)
         logits = self.joiner(
             am_pruned,
             lm_pruned,
             None,
             encoder_out_lens,
-            apply_attn=True,
+            apply_attn=batch_idx_train > self.params.warm_step,  # True, # batch_idx_train > self.params.warm_step,
             project_input=False,
         )
 
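The apply_attn=batch_idx_train > self.params.warm_step change gates the joiner's cross-attention on training progress: attention stays disabled until warm_step batches have been processed. A minimal sketch of the condition (the warm_step value is illustrative; in the recipe it comes from params):

warm_step = 2000  # illustrative value

def attn_active(batch_idx_train: int) -> bool:
    # Same comparison as in the joiner call above.
    return batch_idx_train > warm_step

assert not attn_active(0)
assert not attn_active(2000)
assert attn_active(2001)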
@@ -293,6 +298,7 @@ class AsrModel(nn.Module):
         prune_range: int = 5,
         am_scale: float = 0.0,
         lm_scale: float = 0.0,
+        batch_idx_train: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -345,6 +351,7 @@ class AsrModel(nn.Module):
                 prune_range=prune_range,
                 am_scale=am_scale,
                 lm_scale=lm_scale,
+                batch_idx_train=batch_idx_train,
             )
         else:
             simple_loss = torch.empty(0)
@@ -622,6 +622,7 @@ def get_model(params: AttributeDict) -> nn.Module:
         joiner = None
 
     model = AsrModel(
+        params=params,
         encoder_embed=encoder_embed,
         encoder=encoder,
         decoder=decoder,
@@ -800,6 +801,7 @@ def compute_loss(
             prune_range=params.prune_range,
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
+            batch_idx_train=batch_idx_train,
         )
 
     loss = 0.0
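Taken together, the remaining hunks thread batch_idx_train from the training loop down to the joiner: compute_loss passes it to the model's forward, which hands it to forward_transducer, where it decides apply_attn. A heavily reduced, hypothetical sketch of that call chain (only the argument passing is shown, not the real icefall code):

from types import SimpleNamespace

class TinyModel:
    def __init__(self, params):
        self.params = params  # mirrors the new `self.params = params` assignment

    def forward_transducer(self, batch_idx_train: int = 0) -> bool:
        # In the real model this value is passed as apply_attn= to the joiner.
        return batch_idx_train > self.params.warm_step

    def forward(self, batch_idx_train: int = 0) -> bool:
        return self.forward_transducer(batch_idx_train=batch_idx_train)

def compute_loss(model, batch_idx_train: int) -> bool:
    return model.forward(batch_idx_train=batch_idx_train)

params = SimpleNamespace(warm_step=2000)  # stand-in for the AttributeDict
model = TinyModel(params)
assert compute_loss(model, batch_idx_train=0) is False
assert compute_loss(model, batch_idx_train=5000) is True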