From 336283f8721c4f38a14b139b0fbbe2cf5893f4f7 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 17 Nov 2021 19:24:07 +0800 Subject: [PATCH] New label smoothing (#109) * Modify label smoothing to match the one implemented in PyTorch. * Enable CI for torch 1.10 * Fix CI errors. * Fix CI installation errors. * Fix CI installation errors. * Minor fixes. * Minor fixes. * Minor fixes. * Minor fixes. * Minor fixes. * Fix CI errors. --- .github/workflows/run-pretrained.yml | 9 +- .github/workflows/run-yesno-recipe.yml | 6 +- .github/workflows/test.yml | 28 +++++- .../ASR/conformer_ctc/label_smoothing.py | 98 +++++++++++++++++++ .../ASR/conformer_ctc/test_label_smoothing.py | 52 ++++++++++ .../ASR/conformer_ctc/transformer.py | 70 +------------ 6 files changed, 185 insertions(+), 78 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_ctc/label_smoothing.py create mode 100755 egs/librispeech/ASR/conformer_ctc/test_label_smoothing.py diff --git a/.github/workflows/run-pretrained.yml b/.github/workflows/run-pretrained.yml index 97d3c32d2..710ca2603 100644 --- a/.github/workflows/run-pretrained.yml +++ b/.github/workflows/run-pretrained.yml @@ -31,8 +31,9 @@ jobs: matrix: os: [ubuntu-18.04] python-version: [3.6, 3.7, 3.8, 3.9] - torch: ["1.8.1"] - k2-version: ["1.9.dev20210919"] + torch: ["1.10.0"] + torchaudio: ["0.10.0"] + k2-version: ["1.9.dev20211101"] fail-fast: false @@ -49,7 +50,9 @@ jobs: - name: Install Python dependencies run: | python3 -m pip install --upgrade pip pytest - pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html + # numpy 1.20.x does not support python 3.6 + pip install numpy==1.19 + pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ python3 -m pip install git+https://github.com/lhotse-speech/lhotse diff --git a/.github/workflows/run-yesno-recipe.yml b/.github/workflows/run-yesno-recipe.yml index 876b95e71..98b2e4ebd 100644 --- a/.github/workflows/run-yesno-recipe.yml +++ b/.github/workflows/run-yesno-recipe.yml @@ -33,8 +33,9 @@ jobs: # TODO: enable macOS for CPU testing os: [ubuntu-18.04] python-version: [3.8] - torch: ["1.8.1"] - k2-version: ["1.9.dev20210919"] + torch: ["1.10.0"] + torchaudio: ["0.10.0"] + k2-version: ["1.9.dev20211101"] fail-fast: false steps: @@ -57,6 +58,7 @@ jobs: - name: Install Python dependencies run: | python3 -m pip install -U pip + pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ python3 -m pip install git+https://github.com/lhotse-speech/lhotse diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b5c8cfcfa..e897c3fb5 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,8 +33,14 @@ jobs: # disable macOS test for now. 
        os: [ubuntu-18.04]
        python-version: [3.6, 3.7, 3.8, 3.9]
-        torch: ["1.8.1"]
-        k2-version: ["1.9.dev20210919"]
+        torch: ["1.8.0", "1.10.0"]
+        torchaudio: ["0.8.0", "0.10.0"]
+        k2-version: ["1.9.dev20211101"]
+        exclude:
+          - torch: "1.8.0"
+            torchaudio: "0.10.0"
+          - torch: "1.10.0"
+            torchaudio: "0.8.0"

      fail-fast: false

@@ -58,6 +64,15 @@ jobs:
      - name: Install Python dependencies
        run: |
          python3 -m pip install --upgrade pip pytest
+          # numpy 1.20.x does not support python 3.6
+          pip install numpy==1.19
+          pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+          if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
+            pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+          else
+            pip install torchaudio==${{ matrix.torchaudio }}
+          fi
+
          pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
          pip install git+https://github.com/lhotse-speech/lhotse
          # icefall requirements
@@ -83,7 +98,10 @@ jobs:
          ls -lh
          export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH
          echo $PYTHONPATH
-          pytest ./test
+          pytest -v -s ./test
+          # run tests for conformer ctc
+          cd egs/librispeech/ASR/conformer_ctc
+          pytest -v -s

      - name: Run tests
        if: startsWith(matrix.os, 'macos')
        run: |
          lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
          echo "lib_path: $lib_path"
          export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
-          pytest ./test
+          pytest -v -s ./test
          # run tests for conformer ctc
          cd egs/librispeech/ASR/conformer_ctc
-          pytest
+          pytest -v -s

diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py
new file mode 100644
index 000000000..cdc85ce9a
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py
@@ -0,0 +1,98 @@
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+class LabelSmoothingLoss(torch.nn.Module):
+    """
+    Implements the label smoothing loss proposed in the following paper:
+    https://arxiv.org/pdf/1512.00567.pdf
+    (Rethinking the Inception Architecture for Computer Vision)
+
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -1,
+        label_smoothing: float = 0.1,
+        reduction: str = "sum",
+    ) -> None:
+        """
+        Args:
+          ignore_index:
+            The ID of the class to be ignored, e.g., padding positions.
+          label_smoothing:
+            The smoothing rate (0.0 means conventional cross-entropy loss).
+          reduction:
+            It has the same meaning as the reduction in
+            `torch.nn.CrossEntropyLoss`. It can be one of the following three
+            values: (1) "none": No reduction will be applied. (2) "mean": the
+            mean of the output is taken. (3) "sum": the output will be summed.
+        """
+        super().__init__()
+        assert 0.0 <= label_smoothing < 1.0
+        self.ignore_index = ignore_index
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
+
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the loss between x and target.
+
+        Args:
+          x:
+            Prediction of dimension
+            (batch_size, input_length, number_of_classes).
+          target:
+            Target masked with self.ignore_index of
+            dimension (batch_size, input_length).
+
+        Returns:
+          A scalar tensor containing the loss without normalization if
+          reduction is "sum" or "mean"; a 1-D tensor of per-token losses
+          if reduction is "none".
+        """
+        assert x.ndim == 3
+        assert target.ndim == 2
+        assert x.shape[:2] == target.shape
+        num_classes = x.size(-1)
+        x = x.reshape(-1, num_classes)
+        # Now x is of shape (N*T, C)
+
+        # We don't want to change target in-place below,
+        # so we make a copy of it here
+        target = target.clone().reshape(-1)
+
+        ignored = target == self.ignore_index
+        target[ignored] = 0
+
+        true_dist = torch.nn.functional.one_hot(
+            target, num_classes=num_classes
+        ).to(x)
+
+        true_dist = (
+            true_dist * (1 - self.label_smoothing)
+            + self.label_smoothing / num_classes
+        )
+        # Set the distributions of ignored positions to all zeros
+        true_dist[ignored] = 0
+
+        loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
+        if self.reduction == "sum":
+            return loss.sum()
+        elif self.reduction == "mean":
+            return loss.sum() / (~ignored).sum()
+        else:
+            return loss.sum(dim=-1)
diff --git a/egs/librispeech/ASR/conformer_ctc/test_label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/test_label_smoothing.py
new file mode 100755
index 000000000..5d4438fd1
--- /dev/null
+++ b/egs/librispeech/ASR/conformer_ctc/test_label_smoothing.py
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +from distutils.version import LooseVersion + +import torch +from label_smoothing import LabelSmoothingLoss + +torch_ver = LooseVersion(torch.__version__) + + +def test_with_torch_label_smoothing_loss(): + if torch_ver < LooseVersion("1.10.0"): + print(f"Current torch version: {torch_ver}") + print("Please use torch >= 1.10 to run this test - skipping") + return + torch.manual_seed(20211105) + x = torch.rand(20, 30, 5000) + tgt = torch.randint(low=-1, high=x.size(-1), size=x.shape[:2]) + for reduction in ["none", "sum", "mean"]: + custom_loss_func = LabelSmoothingLoss( + ignore_index=-1, label_smoothing=0.1, reduction=reduction + ) + custom_loss = custom_loss_func(x, tgt) + + torch_loss_func = torch.nn.CrossEntropyLoss( + ignore_index=-1, reduction=reduction, label_smoothing=0.1 + ) + torch_loss = torch_loss_func(x.reshape(-1, x.size(-1)), tgt.reshape(-1)) + assert torch.allclose(custom_loss, torch_loss) + + +def main(): + test_with_torch_label_smoothing_loss() + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index c9666362f..f93914aaa 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn +from label_smoothing import LabelSmoothingLoss from subsampling import Conv2dSubsampling, VggSubsampling from torch.nn.utils.rnn import pad_sequence @@ -152,7 +153,7 @@ class Transformer(nn.Module): d_model, self.decoder_num_class ) - self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) + self.decoder_criterion = LabelSmoothingLoss() else: self.decoder_criterion = None @@ -799,73 +800,6 @@ class Noam(object): setattr(self, key, value) -class LabelSmoothingLoss(nn.Module): - """ - Label-smoothing loss. KL-divergence between - q_{smoothed ground truth prob.}(w) - and p_{prob. computed by model}(w) is minimized. - Modified from - https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa - - Args: - size: the number of class - padding_idx: padding_idx: ignored class id - smoothing: smoothing rate (0.0 means the conventional CE) - normalize_length: normalize loss by sequence length if True - criterion: loss function to be smoothed - """ - - def __init__( - self, - size: int, - padding_idx: int = -1, - smoothing: float = 0.1, - normalize_length: bool = False, - criterion: nn.Module = nn.KLDivLoss(reduction="none"), - ) -> None: - """Construct an LabelSmoothingLoss object.""" - super(LabelSmoothingLoss, self).__init__() - self.criterion = criterion - self.padding_idx = padding_idx - assert 0.0 < smoothing <= 1.0 - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - self.size = size - self.true_dist = None - self.normalize_length = normalize_length - - def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - """ - Compute loss between x and target. - - Args: - x: - prediction of dimension - (batch_size, input_length, number_of_classes). - target: - target masked with self.padding_id of - dimension (batch_size, input_length). - - Returns: - A scalar tensor containing the loss without normalization. 
- """ - assert x.size(2) == self.size - # batch_size = x.size(0) - x = x.view(-1, self.size) - target = target.view(-1) - with torch.no_grad(): - true_dist = x.clone() - true_dist.fill_(self.smoothing / (self.size - 1)) - ignore = target == self.padding_idx # (B,) - total = len(target) - ignore.sum().item() - target = target.masked_fill(ignore, 0) # avoid -1 index - true_dist.scatter_(1, target.unsqueeze(1), self.confidence) - kl = self.criterion(torch.log_softmax(x, dim=1), true_dist) - # denom = total if self.normalize_length else batch_size - denom = total if self.normalize_length else 1 - return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom - - def encoder_padding_mask( max_len: int, supervisions: Optional[Supervisions] = None ) -> Optional[torch.Tensor]:
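
To illustrate what the new `LabelSmoothingLoss` computes, here is a minimal standalone sketch. It is not part of the patch: the toy sizes are illustrative assumptions, and epsilon/ignore_index use the class defaults (0.1 and -1). Each one-hot target row is scaled by 1 - epsilon, epsilon / num_classes is added to every class, and rows at ignored positions are zeroed so that padding contributes nothing to the loss.

    # Illustrative sketch only: toy sizes; epsilon=0.1 and ignore_index=-1
    # match the defaults of the LabelSmoothingLoss added in this patch.
    import torch

    from label_smoothing import LabelSmoothingLoss

    num_classes = 5
    epsilon = 0.1  # label_smoothing

    # Three target positions; the last one is padding (ignore_index == -1).
    target = torch.tensor([2, 0, -1])
    ignored = target == -1

    # Build the smoothed distribution, mirroring the rule in
    # LabelSmoothingLoss.forward (clamp stands in for its zeroing of
    # ignored targets before one_hot).
    one_hot = torch.nn.functional.one_hot(
        target.clamp(min=0), num_classes=num_classes
    ).float()
    true_dist = one_hot * (1 - epsilon) + epsilon / num_classes
    true_dist[ignored] = 0
    print(true_dist)
    # tensor([[0.0200, 0.0200, 0.9200, 0.0200, 0.0200],
    #         [0.9200, 0.0200, 0.0200, 0.0200, 0.0200],
    #         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])

    # The loss itself expects (batch, time, classes) logits and
    # (batch, time) targets, matching the decoder output in transformer.py.
    logits = torch.randn(1, 3, num_classes)
    loss_fn = LabelSmoothingLoss(
        ignore_index=-1, label_smoothing=epsilon, reduction="sum"
    )
    print(loss_fn(logits, target.unsqueeze(0)))

On torch >= 1.10, the same values are produced by `torch.nn.CrossEntropyLoss(ignore_index=-1, label_smoothing=0.1, reduction="sum")`; that equivalence is exactly what the new test_label_smoothing.py asserts.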