mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Merge b49510e2bf7064f4f60650e6787288db1bad2941 into 6caff5fd38a09c231fbd728a2c4d3f3ac14e4455
This commit is contained in:
commit
0210d58965
@ -22,33 +22,51 @@ class Joiner(nn.Module):
|
|||||||
def __init__(self, input_dim: int, output_dim: int):
|
def __init__(self, input_dim: int, output_dim: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
self.output_linear = nn.Linear(input_dim, output_dim)
|
self.output_linear = nn.Linear(input_dim, output_dim)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self, encoder_out: torch.Tensor, decoder_out: torch.Tensor
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
decoder_out: torch.Tensor,
|
||||||
|
encoder_out_len: torch.Tensor,
|
||||||
|
decoder_out_len: torch.Tensor,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
encoder_out:
|
encoder_out:
|
||||||
Output from the encoder. Its shape is (N, T, C).
|
Output from the encoder. Its shape is (N, T, self.input_dim).
|
||||||
decoder_out:
|
decoder_out:
|
||||||
Output from the decoder. Its shape is (N, U, C).
|
Output from the decoder. Its shape is (N, U, self.input_dim).
|
||||||
Returns:
|
Returns:
|
||||||
Return a tensor of shape (N, T, U, C).
|
Return a tensor of shape (sum_all_TU, self.output_dim).
|
||||||
"""
|
"""
|
||||||
assert encoder_out.ndim == decoder_out.ndim == 3
|
assert encoder_out.ndim == decoder_out.ndim == 3
|
||||||
assert encoder_out.size(0) == decoder_out.size(0)
|
assert encoder_out.size(0) == decoder_out.size(0)
|
||||||
assert encoder_out.size(2) == decoder_out.size(2)
|
assert encoder_out.size(2) == self.input_dim
|
||||||
|
assert decoder_out.size(2) == self.input_dim
|
||||||
|
|
||||||
encoder_out = encoder_out.unsqueeze(2)
|
N = encoder_out.size(0)
|
||||||
# Now encoder_out is (N, T, 1, C)
|
|
||||||
|
|
||||||
decoder_out = decoder_out.unsqueeze(1)
|
encoder_out_list = [
|
||||||
# Now decoder_out is (N, 1, U, C)
|
encoder_out[i, : encoder_out_len[i], :] for i in range(N)
|
||||||
|
]
|
||||||
|
|
||||||
logit = encoder_out + decoder_out
|
decoder_out_list = [
|
||||||
logit = torch.tanh(logit)
|
decoder_out[i, : decoder_out_len[i], :] for i in range(N)
|
||||||
|
]
|
||||||
|
|
||||||
output = self.output_linear(logit)
|
x = [
|
||||||
|
e.unsqueeze(1) + d.unsqueeze(0)
|
||||||
|
for e, d in zip(encoder_out_list, decoder_out_list)
|
||||||
|
]
|
||||||
|
|
||||||
return output
|
x = [p.reshape(-1, self.input_dim) for p in x]
|
||||||
|
x = torch.cat(x)
|
||||||
|
|
||||||
|
activations = torch.tanh(x)
|
||||||
|
|
||||||
|
logits = self.output_linear(activations)
|
||||||
|
|
||||||
|
return logits
|
||||||
|
@ -14,20 +14,71 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
import math
|
||||||
Note we use `rnnt_loss` from torchaudio, which exists only in
|
|
||||||
torchaudio >= v0.10.0. It also means you have to use torch >= v1.10.0
|
|
||||||
"""
|
|
||||||
import k2
|
import k2
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torchaudio
|
|
||||||
import torchaudio.functional
|
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
|
||||||
from icefall.utils import add_sos
|
from icefall.utils import add_sos
|
||||||
|
|
||||||
|
|
||||||
|
def reverse_label_smoothing(
|
||||||
|
logprobs: torch.Tensor, alpha: float
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""
|
||||||
|
This function is written by Dan.
|
||||||
|
|
||||||
|
Modifies `logprobs` in such a way that if you compute a data probability
|
||||||
|
using `logprobs`, it will be equivalent to a label-smoothed data probability
|
||||||
|
with the supplied label-smoothing constant alpha (e.g. alpha=0.1).
|
||||||
|
This allows us to use `logprobs` in things like RNN-T and CTC and
|
||||||
|
get a kind of label-smoothed version of those sequence objectives.
|
||||||
|
|
||||||
|
Label smoothing means that if the reference label is i, we convert it
|
||||||
|
into a distribution with weight (1-alpha) on i, and alpha distributed
|
||||||
|
equally to all labels (including i itself).
|
||||||
|
|
||||||
|
Note: the output logprobs can be interpreted as cross-entropies, meaning
|
||||||
|
we correct for the entropy of the smoothed distribution.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
logprobs:
|
||||||
|
A Tensor of shape (*, num_classes), containing logprobs that sum
|
||||||
|
to one: e.g. the output of log_softmax.
|
||||||
|
alpha:
|
||||||
|
A constant that defines the extent of label smoothing, e.g. 0.1.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
modified_logprobs, a Tensor of shape (*, num_classes), containing
|
||||||
|
"fake" logprobs that will give you label-smoothed probabilities.
|
||||||
|
"""
|
||||||
|
assert alpha >= 0.0 and alpha < 1
|
||||||
|
if alpha == 0.0:
|
||||||
|
return logprobs
|
||||||
|
num_classes = logprobs.shape[-1]
|
||||||
|
|
||||||
|
# We correct for the entropy of the label-smoothed target distribution, so
|
||||||
|
# the resulting logprobs can be thought of as cross-entropies, which are
|
||||||
|
# more interpretable.
|
||||||
|
#
|
||||||
|
# The expression for entropy below is not quite correct -- it treats
|
||||||
|
# the target label and the smoothed version of the target label as being
|
||||||
|
# separate classes -- but this can be thought of as an adjustment
|
||||||
|
# for the way we compute the likelihood below, which also treats the
|
||||||
|
# target label and its smoothed version as being separate.
|
||||||
|
target_entropy = -(
|
||||||
|
(1 - alpha) * math.log(1 - alpha)
|
||||||
|
+ alpha * math.log(alpha / num_classes)
|
||||||
|
)
|
||||||
|
sum_logprob = logprobs.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
return (
|
||||||
|
logprobs * (1 - alpha) + sum_logprob * (alpha / num_classes)
|
||||||
|
) + target_entropy
|
||||||
|
|
||||||
|
|
||||||
class Transducer(nn.Module):
|
class Transducer(nn.Module):
|
||||||
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
"""It implements https://arxiv.org/pdf/1211.3711.pdf
|
||||||
"Sequence Transduction with Recurrent Neural Networks"
|
"Sequence Transduction with Recurrent Neural Networks"
|
||||||
@ -68,6 +119,7 @@ class Transducer(nn.Module):
|
|||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
x_lens: torch.Tensor,
|
x_lens: torch.Tensor,
|
||||||
y: k2.RaggedTensor,
|
y: k2.RaggedTensor,
|
||||||
|
label_smoothing_factor: float,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@ -79,6 +131,8 @@ class Transducer(nn.Module):
|
|||||||
y:
|
y:
|
||||||
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
A ragged tensor with 2 axes [utt][label]. It contains labels of each
|
||||||
utterance.
|
utterance.
|
||||||
|
label_smoothing_factor:
|
||||||
|
The factor for label smoothing. Should be in the range [0, 1).
|
||||||
Returns:
|
Returns:
|
||||||
Return the transducer loss.
|
Return the transducer loss.
|
||||||
"""
|
"""
|
||||||
@ -102,24 +156,35 @@ class Transducer(nn.Module):
|
|||||||
|
|
||||||
decoder_out = self.decoder(sos_y_padded)
|
decoder_out = self.decoder(sos_y_padded)
|
||||||
|
|
||||||
logits = self.joiner(encoder_out, decoder_out)
|
# +1 here since a blank is prepended to each utterance.
|
||||||
|
logits = self.joiner(
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
decoder_out=decoder_out,
|
||||||
|
encoder_out_len=x_lens,
|
||||||
|
decoder_out_len=y_lens + 1,
|
||||||
|
)
|
||||||
|
# logits is of shape (sum_all_TU, vocab_size)
|
||||||
|
|
||||||
|
log_probs = logits.log_softmax(dim=-1)
|
||||||
|
log_probs = reverse_label_smoothing(log_probs, label_smoothing_factor)
|
||||||
|
|
||||||
# rnnt_loss requires 0 padded targets
|
# rnnt_loss requires 0 padded targets
|
||||||
# Note: y does not start with SOS
|
# Note: y does not start with SOS
|
||||||
y_padded = y.pad(mode="constant", padding_value=0)
|
y_padded = y.pad(mode="constant", padding_value=0)
|
||||||
|
|
||||||
assert hasattr(torchaudio.functional, "rnnt_loss"), (
|
# We don't put this `import` at the beginning of the file
|
||||||
f"Current torchaudio version: {torchaudio.__version__}\n"
|
# as it is required only in the training, not during the
|
||||||
"Please install a version >= 0.10.0"
|
# reference stage
|
||||||
)
|
import optimized_transducer
|
||||||
|
|
||||||
loss = torchaudio.functional.rnnt_loss(
|
loss = optimized_transducer.transducer_loss(
|
||||||
logits=logits,
|
logits=log_probs,
|
||||||
targets=y_padded,
|
targets=y_padded,
|
||||||
logit_lengths=x_lens,
|
logit_lengths=x_lens,
|
||||||
target_lengths=y_lens,
|
target_lengths=y_lens,
|
||||||
blank=blank_id,
|
blank=blank_id,
|
||||||
reduction="sum",
|
reduction="sum",
|
||||||
|
from_log_softmax=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
@ -138,6 +138,13 @@ def get_parser():
|
|||||||
"2 means tri-gram",
|
"2 means tri-gram",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--label-smoothing-factor",
|
||||||
|
type=float,
|
||||||
|
default=0.1,
|
||||||
|
help="The factor for label smoothing",
|
||||||
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
|
||||||
@ -383,7 +390,12 @@ def compute_loss(
|
|||||||
y = k2.RaggedTensor(y).to(device)
|
y = k2.RaggedTensor(y).to(device)
|
||||||
|
|
||||||
with torch.set_grad_enabled(is_training):
|
with torch.set_grad_enabled(is_training):
|
||||||
loss = model(x=feature, x_lens=feature_lens, y=y)
|
loss = model(
|
||||||
|
x=feature,
|
||||||
|
x_lens=feature_lens,
|
||||||
|
y=y,
|
||||||
|
label_smoothing_factor=params.label_smoothing_factor,
|
||||||
|
)
|
||||||
|
|
||||||
assert loss.requires_grad == is_training
|
assert loss.requires_grad == is_training
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user