From 78b8792d1d3b15008378b0e38d533a77b456bbbd Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 8 Apr 2022 13:41:33 +0800 Subject: [PATCH] Fix potential bugs in PyTorch that exist in label_smoothing. (#300) --- .../ASR/conformer_ctc/label_smoothing.py | 99 +------------------ .../ASR/conformer_mmi/label_smoothing.py | 99 +------------------ .../ASR/conformer_ctc/label_smoothing.py | 17 +++- 3 files changed, 17 insertions(+), 198 deletions(-) mode change 100644 => 120000 egs/aishell/ASR/conformer_ctc/label_smoothing.py mode change 100644 => 120000 egs/aishell/ASR/conformer_mmi/label_smoothing.py diff --git a/egs/aishell/ASR/conformer_ctc/label_smoothing.py b/egs/aishell/ASR/conformer_ctc/label_smoothing.py deleted file mode 100644 index cdc85ce9a..000000000 --- a/egs/aishell/ASR/conformer_ctc/label_smoothing.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - - -class LabelSmoothingLoss(torch.nn.Module): - """ - Implement the LabelSmoothingLoss proposed in the following paper - https://arxiv.org/pdf/1512.00567.pdf - (Rethinking the Inception Architecture for Computer Vision) - - """ - - def __init__( - self, - ignore_index: int = -1, - label_smoothing: float = 0.1, - reduction: str = "sum", - ) -> None: - """ - Args: - ignore_index: - ignored class id - label_smoothing: - smoothing rate (0.0 means the conventional cross entropy loss) - reduction: - It has the same meaning as the reduction in - `torch.nn.CrossEntropyLoss`. It can be one of the following three - values: (1) "none": No reduction will be applied. (2) "mean": the - mean of the output is taken. (3) "sum": the output will be summed. - """ - super().__init__() - assert 0.0 <= label_smoothing < 1.0 - self.ignore_index = ignore_index - self.label_smoothing = label_smoothing - self.reduction = reduction - - def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute loss between x and target. - - Args: - x: - prediction of dimension - (batch_size, input_length, number_of_classes). - target: - target masked with self.ignore_index of - dimension (batch_size, input_length). - - Returns: - A scalar tensor containing the loss without normalization. - """ - assert x.ndim == 3 - assert target.ndim == 2 - assert x.shape[:2] == target.shape - num_classes = x.size(-1) - x = x.reshape(-1, num_classes) - # Now x is of shape (N*T, C) - - # We don't want to change target in-place below, - # so we make a copy of it here - target = target.clone().reshape(-1) - - ignored = target == self.ignore_index - target[ignored] = 0 - - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) - - true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes - ) - # Set the value of ignored indexes to 0 - true_dist[ignored] = 0 - - loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) - if self.reduction == "sum": - return loss.sum() - elif self.reduction == "mean": - return loss.sum() / (~ignored).sum() - else: - return loss.sum(dim=-1) diff --git a/egs/aishell/ASR/conformer_ctc/label_smoothing.py b/egs/aishell/ASR/conformer_ctc/label_smoothing.py new file mode 120000 index 000000000..e9d239fff --- /dev/null +++ b/egs/aishell/ASR/conformer_ctc/label_smoothing.py @@ -0,0 +1 @@ +../../../librispeech/ASR/conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/aishell/ASR/conformer_mmi/label_smoothing.py b/egs/aishell/ASR/conformer_mmi/label_smoothing.py deleted file mode 100644 index cdc85ce9a..000000000 --- a/egs/aishell/ASR/conformer_mmi/label_smoothing.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) -# -# See ../../../../LICENSE for clarification regarding multiple authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch - - -class LabelSmoothingLoss(torch.nn.Module): - """ - Implement the LabelSmoothingLoss proposed in the following paper - https://arxiv.org/pdf/1512.00567.pdf - (Rethinking the Inception Architecture for Computer Vision) - - """ - - def __init__( - self, - ignore_index: int = -1, - label_smoothing: float = 0.1, - reduction: str = "sum", - ) -> None: - """ - Args: - ignore_index: - ignored class id - label_smoothing: - smoothing rate (0.0 means the conventional cross entropy loss) - reduction: - It has the same meaning as the reduction in - `torch.nn.CrossEntropyLoss`. It can be one of the following three - values: (1) "none": No reduction will be applied. (2) "mean": the - mean of the output is taken. (3) "sum": the output will be summed. - """ - super().__init__() - assert 0.0 <= label_smoothing < 1.0 - self.ignore_index = ignore_index - self.label_smoothing = label_smoothing - self.reduction = reduction - - def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute loss between x and target. - - Args: - x: - prediction of dimension - (batch_size, input_length, number_of_classes). - target: - target masked with self.ignore_index of - dimension (batch_size, input_length). - - Returns: - A scalar tensor containing the loss without normalization. - """ - assert x.ndim == 3 - assert target.ndim == 2 - assert x.shape[:2] == target.shape - num_classes = x.size(-1) - x = x.reshape(-1, num_classes) - # Now x is of shape (N*T, C) - - # We don't want to change target in-place below, - # so we make a copy of it here - target = target.clone().reshape(-1) - - ignored = target == self.ignore_index - target[ignored] = 0 - - true_dist = torch.nn.functional.one_hot( - target, num_classes=num_classes - ).to(x) - - true_dist = ( - true_dist * (1 - self.label_smoothing) - + self.label_smoothing / num_classes - ) - # Set the value of ignored indexes to 0 - true_dist[ignored] = 0 - - loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) - if self.reduction == "sum": - return loss.sum() - elif self.reduction == "mean": - return loss.sum() / (~ignored).sum() - else: - return loss.sum(dim=-1) diff --git a/egs/aishell/ASR/conformer_mmi/label_smoothing.py b/egs/aishell/ASR/conformer_mmi/label_smoothing.py new file mode 120000 index 000000000..08734abd7 --- /dev/null +++ b/egs/aishell/ASR/conformer_mmi/label_smoothing.py @@ -0,0 +1 @@ +../conformer_ctc/label_smoothing.py \ No newline at end of file diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index cdc85ce9a..1f2f3b137 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -76,7 +76,11 @@ class LabelSmoothingLoss(torch.nn.Module): target = target.clone().reshape(-1) ignored = target == self.ignore_index - target[ignored] = 0 + + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use target[ignored] = 0 here + target = torch.where(ignored, torch.zeros_like(target), target) true_dist = torch.nn.functional.one_hot( target, num_classes=num_classes @@ -86,8 +90,17 @@ class LabelSmoothingLoss(torch.nn.Module): true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes ) + # Set the value of ignored indexes to 0 - true_dist[ignored] = 0 + # + # See https://github.com/k2-fsa/icefall/issues/240 + # and https://github.com/k2-fsa/icefall/issues/297 + # for why we don't use true_dist[ignored] = 0 here + true_dist = torch.where( + ignored.unsqueeze(1).repeat(1, true_dist.shape[1]), + torch.zeros_like(true_dist), + true_dist, + ) loss = -1 * (torch.log_softmax(x, dim=1) * true_dist) if self.reduction == "sum":