mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
update docs
This commit is contained in:
parent
08729f88b1
commit
3e3c1a6aee
@ -38,7 +38,7 @@ class Decoder(nn.Module):
|
|||||||
embedding_dim: int,
|
embedding_dim: int,
|
||||||
blank_id: int,
|
blank_id: int,
|
||||||
context_size: int,
|
context_size: int,
|
||||||
backward: bool = False,
|
use_right_context: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -51,6 +51,9 @@ class Decoder(nn.Module):
|
|||||||
context_size:
|
context_size:
|
||||||
Number of previous words to use to predict the next word.
|
Number of previous words to use to predict the next word.
|
||||||
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
1 means bigram; 2 means trigram. n means (n+1)-gram.
|
||||||
|
use_right_context:
|
||||||
|
True to use right context, which is useful to implement a
|
||||||
|
backward decoder, only used for training.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.embedding = nn.Embedding(
|
self.embedding = nn.Embedding(
|
||||||
@ -62,7 +65,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
assert context_size >= 1, context_size
|
assert context_size >= 1, context_size
|
||||||
self.context_size = context_size
|
self.context_size = context_size
|
||||||
self.backward = backward
|
self.use_right_context = use_right_context
|
||||||
if context_size > 1:
|
if context_size > 1:
|
||||||
self.conv = nn.Conv1d(
|
self.conv = nn.Conv1d(
|
||||||
in_channels=embedding_dim,
|
in_channels=embedding_dim,
|
||||||
@ -88,14 +91,20 @@ class Decoder(nn.Module):
|
|||||||
if self.context_size > 1:
|
if self.context_size > 1:
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
if need_pad is True:
|
if need_pad is True:
|
||||||
# If the input is [sos, a, b, c, d] and output is
|
# Regarding the left or right context we are using,
|
||||||
# [a, b, c, d, eos], padding left and using kernel-size 2,
|
# if we feed sequence [sos, a, b, c, d] to this decoder, and
|
||||||
# it uses left context.
|
# want to predict the sequence [a, b, c, d]. After padding to
|
||||||
# If the input is [a, b, c, d, eos] and output is
|
# the left with context_size==2, the fed in sequence changes to
|
||||||
# [sos, a, b, c, d], padding right and using kernel-size 2,
|
# [pad, sos, a, b, c, d], and we use `pad,sos` to predict `a`,
|
||||||
# it uses right context.
|
# `sos,a` to predict `b` ..., that is left context.
|
||||||
if self.backward:
|
# if we feed sequence [b, c, d, blk, blk] to this decoder,
|
||||||
assert self.context_size == 2
|
# and want to predict the sequence [a, b, c, d]. After padding
|
||||||
|
# to the right with context_size==2, the fed in sequence changes
|
||||||
|
# to [b, c, d, blk, blk, pad], and we use `b, c` to predict `a`,
|
||||||
|
# `c,d` to predict `b` ..., that is right context.
|
||||||
|
# This is tricky and not so straightforward, will find better
|
||||||
|
# implementation later.
|
||||||
|
if self.use_right_context:
|
||||||
embedding_out = F.pad(
|
embedding_out = F.pad(
|
||||||
embedding_out, pad=(0, self.context_size - 1)
|
embedding_out, pad=(0, self.context_size - 1)
|
||||||
)
|
)
|
||||||
@ -107,7 +116,7 @@ class Decoder(nn.Module):
|
|||||||
# During inference time, there is no need to do extra padding
|
# During inference time, there is no need to do extra padding
|
||||||
# as we only need one output
|
# as we only need one output
|
||||||
assert embedding_out.size(-1) == self.context_size
|
assert embedding_out.size(-1) == self.context_size
|
||||||
assert self.backward is False
|
assert self.use_right_context is False
|
||||||
embedding_out = self.conv(embedding_out)
|
embedding_out = self.conv(embedding_out)
|
||||||
embedding_out = embedding_out.permute(0, 2, 1)
|
embedding_out = embedding_out.permute(0, 2, 1)
|
||||||
return embedding_out
|
return embedding_out
|
||||||
|
@ -21,7 +21,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
|
||||||
from icefall.utils import add_eos, add_sos
|
from icefall.utils import add_sos
|
||||||
|
|
||||||
|
|
||||||
class Transducer(nn.Module):
|
class Transducer(nn.Module):
|
||||||
@ -124,11 +124,14 @@ class Transducer(nn.Module):
|
|||||||
blank_id = self.decoder.blank_id
|
blank_id = self.decoder.blank_id
|
||||||
sos_y = add_sos(y, sos_id=blank_id)
|
sos_y = add_sos(y, sos_id=blank_id)
|
||||||
|
|
||||||
|
# sos_y_padded: [B, S + 1], start with SOS.
|
||||||
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)
|
||||||
|
|
||||||
|
# decoder_out: [B, S + 1, C]
|
||||||
decoder_out = self.decoder(sos_y_padded)
|
decoder_out = self.decoder(sos_y_padded)
|
||||||
|
|
||||||
# Note: y does not start with SOS
|
# Note: y does not start with SOS
|
||||||
|
# y_padded : [B, S]
|
||||||
y_padded = y.pad(mode="constant", padding_value=0)
|
y_padded = y.pad(mode="constant", padding_value=0)
|
||||||
|
|
||||||
boundary = torch.zeros(
|
boundary = torch.zeros(
|
||||||
@ -148,33 +151,49 @@ class Transducer(nn.Module):
|
|||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
return_grad=True,
|
return_grad=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ranges : [B, T, prune_range]
|
||||||
ranges = k2.get_rnnt_prune_ranges(
|
ranges = k2.get_rnnt_prune_ranges(
|
||||||
px_grad, py_grad, boundary, self.prune_range
|
px_grad, py_grad, boundary, self.prune_range
|
||||||
)
|
)
|
||||||
|
|
||||||
# forward loss
|
# forward loss
|
||||||
|
# am_pruned : [B, T, prune_range, C]
|
||||||
|
# lm_pruned : [B, T, prune_range, C]
|
||||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||||
encoder_out, decoder_out, ranges
|
encoder_out, decoder_out, ranges
|
||||||
)
|
)
|
||||||
|
# logits : [B, T, prune_range, C]
|
||||||
logits = self.joiner(am_pruned, lm_pruned)
|
logits = self.joiner(am_pruned, lm_pruned)
|
||||||
|
|
||||||
pruned_loss = k2.rnnt_loss_pruned(
|
pruned_loss = k2.rnnt_loss_pruned(
|
||||||
logits, y_padded.to(torch.int64), ranges, blank_id, boundary
|
logits, y_padded.to(torch.int64), ranges, blank_id, boundary
|
||||||
)
|
)
|
||||||
|
|
||||||
eos_y = add_eos(y, eos_id=blank_id)
|
# y_padded shape : [B, S]
|
||||||
eos_y_padded = eos_y.pad(mode="constant", padding_value=blank_id)
|
# we skip the first symbol (a shift trick for right context),
|
||||||
eos_y_padded = F.pad(eos_y_padded[:, 1:], pad=(0, 1), value=blank_id)
|
# so we have to pad 2 blanks to the right to make the output shape of
|
||||||
|
# decoder to be [B, S + 1, C],
|
||||||
|
# backward_y shape : [B, S + 1]
|
||||||
|
backward_y = F.pad(y_padded[:, 1:], pad=(0, 2), value=blank_id)
|
||||||
|
|
||||||
# backward loss
|
# backward loss
|
||||||
assert self.backward_decoder is not None
|
assert self.backward_decoder is not None
|
||||||
assert self.backward_joiner is not None
|
assert self.backward_joiner is not None
|
||||||
backward_decoder_out = self.backward_decoder(eos_y_padded)
|
# backward_decoder_out : [B, S + 1, C]
|
||||||
|
backward_decoder_out = self.backward_decoder(backward_y)
|
||||||
|
|
||||||
|
# backward_am_pruned : [B, T, prune_range, C]
|
||||||
|
# backward_lm_pruned : [B, T, prune_range, C]
|
||||||
backward_am_pruned, backward_lm_pruned = k2.do_rnnt_pruning(
|
backward_am_pruned, backward_lm_pruned = k2.do_rnnt_pruning(
|
||||||
encoder_out, backward_decoder_out, ranges
|
encoder_out, backward_decoder_out, ranges
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# backward_logits : [B, T, prune_range, C]
|
||||||
backward_logits = self.backward_joiner(
|
backward_logits = self.backward_joiner(
|
||||||
backward_am_pruned, backward_lm_pruned
|
backward_am_pruned, backward_lm_pruned
|
||||||
)
|
)
|
||||||
|
|
||||||
backward_pruned_loss = k2.rnnt_loss_pruned(
|
backward_pruned_loss = k2.rnnt_loss_pruned(
|
||||||
backward_logits,
|
backward_logits,
|
||||||
y_padded.to(torch.int64),
|
y_padded.to(torch.int64),
|
||||||
|
Loading…
x
Reference in New Issue
Block a user