New label smoothing (#109)

* Modify label smoothing to match the one implemented in PyTorch.

* Enable CI for torch 1.10

* Fix CI errors.

* Fix CI installation errors.

* Fix CI installation errors.

* Minor fixes.

* Minor fixes.

* Minor fixes.

* Minor fixes.

* Minor fixes.

* Fix CI errors.
This commit is contained in:
Fangjun Kuang 2021-11-17 19:24:07 +08:00 committed by GitHub
parent 10e46f3e1d
commit 336283f872
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 185 additions and 78 deletions

View File

@ -31,8 +31,9 @@ jobs:
matrix: matrix:
os: [ubuntu-18.04] os: [ubuntu-18.04]
python-version: [3.6, 3.7, 3.8, 3.9] python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"] torch: ["1.10.0"]
k2-version: ["1.9.dev20210919"] torchaudio: ["0.10.0"]
k2-version: ["1.9.dev20211101"]
fail-fast: false fail-fast: false
@ -49,7 +50,9 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
python3 -m pip install --upgrade pip pytest python3 -m pip install --upgrade pip pytest
pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html # numpy 1.20.x does not support python 3.6
pip install numpy==1.19
pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse python3 -m pip install git+https://github.com/lhotse-speech/lhotse

View File

@ -33,8 +33,9 @@ jobs:
# TODO: enable macOS for CPU testing # TODO: enable macOS for CPU testing
os: [ubuntu-18.04] os: [ubuntu-18.04]
python-version: [3.8] python-version: [3.8]
torch: ["1.8.1"] torch: ["1.10.0"]
k2-version: ["1.9.dev20210919"] torchaudio: ["0.10.0"]
k2-version: ["1.9.dev20211101"]
fail-fast: false fail-fast: false
steps: steps:
@ -57,6 +58,7 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
python3 -m pip install -U pip python3 -m pip install -U pip
pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
python3 -m pip install git+https://github.com/lhotse-speech/lhotse python3 -m pip install git+https://github.com/lhotse-speech/lhotse

View File

@ -33,8 +33,14 @@ jobs:
# disable macOS test for now. # disable macOS test for now.
os: [ubuntu-18.04] os: [ubuntu-18.04]
python-version: [3.6, 3.7, 3.8, 3.9] python-version: [3.6, 3.7, 3.8, 3.9]
torch: ["1.8.1"] torch: ["1.8.0", "1.10.0"]
k2-version: ["1.9.dev20210919"] torchaudio: ["0.8.0", "0.10.0"]
k2-version: ["1.9.dev20211101"]
exclude:
- torch: "1.8.0"
torchaudio: "0.10.0"
- torch: "1.10.0"
torchaudio: "0.8.0"
fail-fast: false fail-fast: false
@ -58,6 +64,15 @@ jobs:
- name: Install Python dependencies - name: Install Python dependencies
run: | run: |
python3 -m pip install --upgrade pip pytest python3 -m pip install --upgrade pip pytest
# numpy 1.20.x does not support python 3.6
pip install numpy==1.19
pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
else
pip install torchaudio==${{ matrix.torchaudio }}
fi
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/ pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
pip install git+https://github.com/lhotse-speech/lhotse pip install git+https://github.com/lhotse-speech/lhotse
# icefall requirements # icefall requirements
@ -83,7 +98,10 @@ jobs:
ls -lh ls -lh
export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH
echo $PYTHONPATH echo $PYTHONPATH
pytest ./test pytest -v -s ./test
# run tests for conformer ctc
cd egs/librispeech/ASR/conformer_ctc
pytest -v -s
- name: Run tests - name: Run tests
if: startsWith(matrix.os, 'macos') if: startsWith(matrix.os, 'macos')
@ -93,8 +111,8 @@ jobs:
lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())") lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
echo "lib_path: $lib_path" echo "lib_path: $lib_path"
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
pytest ./test pytest -v -s ./test
# run tests for conformer ctc # run tests for conformer ctc
cd egs/librispeech/ASR/conformer_ctc cd egs/librispeech/ASR/conformer_ctc
pytest pytest -v -s

View File

@ -0,0 +1,98 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
class LabelSmoothingLoss(torch.nn.Module):
    """
    Implement the LabelSmoothingLoss proposed in the following paper
    https://arxiv.org/pdf/1512.00567.pdf
    (Rethinking the Inception Architecture for Computer Vision)

    For every position that is not ignored, the smoothed target distribution
    puts ``1 - label_smoothing + label_smoothing / C`` on the true class and
    ``label_smoothing / C`` on each of the other classes, where ``C`` is the
    number of classes. The loss is the cross entropy between that
    distribution and ``log_softmax`` of the prediction.
    """

    def __init__(
        self,
        ignore_index: int = -1,
        label_smoothing: float = 0.1,
        reduction: str = "sum",
    ) -> None:
        """
        Args:
          ignore_index:
            ignored class id
          label_smoothing:
            smoothing rate (0.0 means the conventional cross entropy loss)
          reduction:
            It has the same meaning as the reduction in
            `torch.nn.CrossEntropyLoss`. It can be one of the following three
            values: (1) "none": No reduction will be applied. (2) "mean": the
            mean of the output is taken. (3) "sum": the output will be summed.
        Raises:
          ValueError:
            If `reduction` is not one of "none", "mean" or "sum".
        """
        super().__init__()
        assert 0.0 <= label_smoothing < 1.0
        # Reject unknown values up front; otherwise a misspelled reduction
        # would silently behave like "none" in forward().
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(
                "reduction must be one of 'none', 'mean' or 'sum', "
                f"got {reduction!r}"
            )
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute loss between x and target.

        Args:
          x:
            prediction of dimension
            (batch_size, input_length, number_of_classes).
          target:
            target masked with self.ignore_index of
            dimension (batch_size, input_length). It is not modified.
        Returns:
          Depends on `self.reduction`:
            - "sum": a scalar tensor, the loss summed over all positions.
            - "mean": a scalar tensor, the summed loss divided by the number
              of positions whose target is not `self.ignore_index` (if every
              position is ignored this is a division by zero).
            - "none": a tensor of shape (batch_size * input_length,) holding
              the loss of each position; ignored positions contribute 0.
        """
        assert x.ndim == 3
        assert target.ndim == 2
        assert x.shape[:2] == target.shape

        num_classes = x.size(-1)
        x = x.reshape(-1, num_classes)
        # Now x is of shape (N*T, C)

        target = target.reshape(-1)
        ignored = target == self.ignore_index

        # ignore_index may be negative (the default is -1), which is not a
        # valid class id for one_hot. masked_fill returns a new tensor, so
        # the caller's target is left untouched.
        target = target.masked_fill(ignored, 0)

        true_dist = torch.nn.functional.one_hot(
            target, num_classes=num_classes
        ).to(x)

        true_dist = (
            true_dist * (1 - self.label_smoothing)
            + self.label_smoothing / num_classes
        )

        # Set the value of ignored indexes to 0
        true_dist[ignored] = 0

        loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)

        if self.reduction == "sum":
            return loss.sum()
        if self.reduction == "mean":
            return loss.sum() / (~ignored).sum()
        # reduction == "none": keep one value per position
        return loss.sum(dim=-1)

View File

@ -0,0 +1,52 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from distutils.version import LooseVersion
import torch
from label_smoothing import LabelSmoothingLoss
# Parsed version of the installed torch; the test below compares it
# against "1.10.0" to decide whether it can run.
torch_ver = LooseVersion(torch.__version__)
def test_with_torch_label_smoothing_loss():
    """Compare LabelSmoothingLoss with torch.nn.CrossEntropyLoss.

    Both losses are evaluated on the same random input with
    ignore_index=-1 and label_smoothing=0.1, once for each reduction
    ("none", "sum", "mean"), and the results must be close.

    When the installed torch is older than 1.10.0 the test prints a notice
    and returns without asserting anything.
    """
    if torch_ver < LooseVersion("1.10.0"):
        print(f"Current torch version: {torch_ver}")
        print("Please use torch >= 1.10 to run this test - skipping")
        return

    torch.manual_seed(20211105)

    logits = torch.rand(20, 30, 5000)
    num_classes = logits.size(-1)
    # low=-1 so that some positions carry the ignored label
    labels = torch.randint(low=-1, high=num_classes, size=logits.shape[:2])

    # torch.nn.CrossEntropyLoss is fed the flattened (N*T, C) / (N*T,) view
    flat_logits = logits.reshape(-1, num_classes)
    flat_labels = labels.reshape(-1)

    for reduction in ("none", "sum", "mean"):
        ours = LabelSmoothingLoss(
            ignore_index=-1, label_smoothing=0.1, reduction=reduction
        )
        reference = torch.nn.CrossEntropyLoss(
            ignore_index=-1, reduction=reduction, label_smoothing=0.1
        )

        actual = ours(logits, labels)
        wanted = reference(flat_logits, flat_labels)
        assert torch.allclose(actual, wanted)
def main():
    """Run the label-smoothing comparison test when executed as a script."""
    test_with_torch_label_smoothing_loss()


if __name__ == "__main__":
    main()

View File

@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from label_smoothing import LabelSmoothingLoss
from subsampling import Conv2dSubsampling, VggSubsampling from subsampling import Conv2dSubsampling, VggSubsampling
from torch.nn.utils.rnn import pad_sequence from torch.nn.utils.rnn import pad_sequence
@ -152,7 +153,7 @@ class Transformer(nn.Module):
d_model, self.decoder_num_class d_model, self.decoder_num_class
) )
self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) self.decoder_criterion = LabelSmoothingLoss()
else: else:
self.decoder_criterion = None self.decoder_criterion = None
@ -799,73 +800,6 @@ class Noam(object):
setattr(self, key, value) setattr(self, key, value)
class LabelSmoothingLoss(nn.Module):
    """
    Label-smoothing loss. KL-divergence between
    q_{smoothed ground truth prob.}(w)
    and p_{prob. computed by model}(w) is minimized.

    Modified from
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py  # noqa

    The smoothed target puts ``1 - smoothing`` on the true class and
    ``smoothing / (size - 1)`` on each of the other classes.

    Args:
        size: the number of class
        padding_idx: padding_idx: ignored class id
        smoothing: smoothing rate (0.0 means the conventional CE)
        normalize_length: normalize loss by sequence length if True
        criterion: loss function to be smoothed
    """

    def __init__(
        self,
        size: int,
        padding_idx: int = -1,
        smoothing: float = 0.1,
        normalize_length: bool = False,
        criterion: nn.Module = nn.KLDivLoss(reduction="none"),
    ) -> None:
        """Construct an LabelSmoothingLoss object."""
        super(LabelSmoothingLoss, self).__init__()
        # NOTE: the default criterion is created once, when the class is
        # defined, and is shared by every instance that uses the default.
        self.criterion = criterion
        self.padding_idx = padding_idx
        # 0.0 is allowed: the class docstring documents it as the
        # conventional cross entropy.
        assert 0.0 <= smoothing <= 1.0
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.size = size
        self.true_dist = None
        self.normalize_length = normalize_length

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute loss between x and target.

        Args:
          x:
            prediction of dimension
            (batch_size, input_length, number_of_classes).
          target:
            target masked with self.padding_id of
            dimension (batch_size, input_length).
        Returns:
          A scalar tensor containing the loss summed over all positions whose
          target is not `self.padding_idx`. If `self.normalize_length` is
          True the sum is divided by the number of such positions; otherwise
          it is returned without normalization.
        """
        assert x.size(2) == self.size
        # reshape (rather than view) so that non-contiguous inputs work too
        x = x.reshape(-1, self.size)
        target = target.reshape(-1)
        with torch.no_grad():
            # The target distribution is a constant; build it without
            # recording gradients.
            true_dist = x.clone()
            true_dist.fill_(self.smoothing / (self.size - 1))
            ignore = target == self.padding_idx  # (B,)
            total = len(target) - ignore.sum().item()
            target = target.masked_fill(ignore, 0)  # avoid -1 index
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
        denom = total if self.normalize_length else 1
        # Zero the rows of padded positions before summing
        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
def encoder_padding_mask( def encoder_padding_mask(
max_len: int, supervisions: Optional[Supervisions] = None max_len: int, supervisions: Optional[Supervisions] = None
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]: