mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-08 09:32:20 +00:00
New label smoothing (#109)
* Modify label smoothing to match the one implemented in PyTorch. * Enable CI for torch 1.10 * Fix CI errors. * Fix CI installation errors. * Fix CI installation errors. * Minor fixes. * Minor fixes. * Minor fixes. * Minor fixes. * Minor fixes. * Fix CI errors.
This commit is contained in:
parent
10e46f3e1d
commit
336283f872
9
.github/workflows/run-pretrained.yml
vendored
9
.github/workflows/run-pretrained.yml
vendored
@ -31,8 +31,9 @@ jobs:
|
|||||||
matrix:
|
matrix:
|
||||||
os: [ubuntu-18.04]
|
os: [ubuntu-18.04]
|
||||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||||
torch: ["1.8.1"]
|
torch: ["1.10.0"]
|
||||||
k2-version: ["1.9.dev20210919"]
|
torchaudio: ["0.10.0"]
|
||||||
|
k2-version: ["1.9.dev20211101"]
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
@ -49,7 +50,9 @@ jobs:
|
|||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install --upgrade pip pytest
|
python3 -m pip install --upgrade pip pytest
|
||||||
pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
|
# numpy 1.20.x does not support python 3.6
|
||||||
|
pip install numpy==1.19
|
||||||
|
pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||||
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
|
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
|
||||||
|
|
||||||
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
|
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
|
||||||
|
6
.github/workflows/run-yesno-recipe.yml
vendored
6
.github/workflows/run-yesno-recipe.yml
vendored
@ -33,8 +33,9 @@ jobs:
|
|||||||
# TODO: enable macOS for CPU testing
|
# TODO: enable macOS for CPU testing
|
||||||
os: [ubuntu-18.04]
|
os: [ubuntu-18.04]
|
||||||
python-version: [3.8]
|
python-version: [3.8]
|
||||||
torch: ["1.8.1"]
|
torch: ["1.10.0"]
|
||||||
k2-version: ["1.9.dev20210919"]
|
torchaudio: ["0.10.0"]
|
||||||
|
k2-version: ["1.9.dev20211101"]
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
@ -57,6 +58,7 @@ jobs:
|
|||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install -U pip
|
python3 -m pip install -U pip
|
||||||
|
pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||||
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
|
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
|
||||||
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
|
python3 -m pip install git+https://github.com/lhotse-speech/lhotse
|
||||||
|
|
||||||
|
28
.github/workflows/test.yml
vendored
28
.github/workflows/test.yml
vendored
@ -33,8 +33,14 @@ jobs:
|
|||||||
# disable macOS test for now.
|
# disable macOS test for now.
|
||||||
os: [ubuntu-18.04]
|
os: [ubuntu-18.04]
|
||||||
python-version: [3.6, 3.7, 3.8, 3.9]
|
python-version: [3.6, 3.7, 3.8, 3.9]
|
||||||
torch: ["1.8.1"]
|
torch: ["1.8.0", "1.10.0"]
|
||||||
k2-version: ["1.9.dev20210919"]
|
torchaudio: ["0.8.0", "0.10.0"]
|
||||||
|
k2-version: ["1.9.dev20211101"]
|
||||||
|
exclude:
|
||||||
|
- torch: "1.8.0"
|
||||||
|
torchaudio: "0.10.0"
|
||||||
|
- torch: "1.10.0"
|
||||||
|
torchaudio: "0.8.0"
|
||||||
|
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|
||||||
@ -58,6 +64,15 @@ jobs:
|
|||||||
- name: Install Python dependencies
|
- name: Install Python dependencies
|
||||||
run: |
|
run: |
|
||||||
python3 -m pip install --upgrade pip pytest
|
python3 -m pip install --upgrade pip pytest
|
||||||
|
# numpy 1.20.x does not support python 3.6
|
||||||
|
pip install numpy==1.19
|
||||||
|
pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||||
|
if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
|
||||||
|
pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
|
||||||
|
else
|
||||||
|
pip install torchaudio==${{ matrix.torchaudio }}
|
||||||
|
fi
|
||||||
|
|
||||||
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
|
pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
|
||||||
pip install git+https://github.com/lhotse-speech/lhotse
|
pip install git+https://github.com/lhotse-speech/lhotse
|
||||||
# icefall requirements
|
# icefall requirements
|
||||||
@ -83,7 +98,10 @@ jobs:
|
|||||||
ls -lh
|
ls -lh
|
||||||
export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH
|
export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH
|
||||||
echo $PYTHONPATH
|
echo $PYTHONPATH
|
||||||
pytest ./test
|
pytest -v -s ./test
|
||||||
|
# run tests for conformer ctc
|
||||||
|
cd egs/librispeech/ASR/conformer_ctc
|
||||||
|
pytest -v -s
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
if: startsWith(matrix.os, 'macos')
|
if: startsWith(matrix.os, 'macos')
|
||||||
@ -93,8 +111,8 @@ jobs:
|
|||||||
lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
|
lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
|
||||||
echo "lib_path: $lib_path"
|
echo "lib_path: $lib_path"
|
||||||
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
|
export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
|
||||||
pytest ./test
|
pytest -v -s ./test
|
||||||
|
|
||||||
# run tests for conformer ctc
|
# run tests for conformer ctc
|
||||||
cd egs/librispeech/ASR/conformer_ctc
|
cd egs/librispeech/ASR/conformer_ctc
|
||||||
pytest
|
pytest -v -s
|
||||||
|
98
egs/librispeech/ASR/conformer_ctc/label_smoothing.py
Normal file
98
egs/librispeech/ASR/conformer_ctc/label_smoothing.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
|
class LabelSmoothingLoss(torch.nn.Module):
    """
    Cross-entropy loss with label smoothing.

    Implement the LabelSmoothingLoss proposed in the following paper
    https://arxiv.org/pdf/1512.00567.pdf
    (Rethinking the Inception Architecture for Computer Vision)

    For every non-ignored position the target distribution is

        (1 - label_smoothing) * one_hot(target) + label_smoothing / num_classes

    i.e. the smoothing mass is spread uniformly over *all* classes,
    the reference class included.
    """

    def __init__(
        self,
        ignore_index: int = -1,
        label_smoothing: float = 0.1,
        reduction: str = "sum",
    ) -> None:
        """
        Args:
          ignore_index:
            ignored class id
          label_smoothing:
            smoothing rate (0.0 means the conventional cross entropy loss).
            Must satisfy 0.0 <= label_smoothing < 1.0.
          reduction:
            It has the same meaning as the reduction in
            `torch.nn.CrossEntropyLoss`. It can be one of the following three
            values: (1) "none": No reduction will be applied. (2) "mean": the
            mean of the output is taken. (3) "sum": the output will be summed.

        Raises:
          ValueError:
            If `reduction` is not one of "none", "mean" or "sum".
        """
        super().__init__()
        assert (
            0.0 <= label_smoothing < 1.0
        ), f"label_smoothing must be in [0, 1), got {label_smoothing}"
        # Reject unknown values up front. Otherwise a typo such as "avg"
        # would silently behave like "none" in forward().
        if reduction not in ("none", "mean", "sum"):
            raise ValueError(
                f"{reduction} is not a valid value for reduction; "
                'expected one of "none", "mean", "sum"'
            )
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        self.reduction = reduction

    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute loss between x and target.

        Args:
          x:
            prediction of dimension
            (batch_size, input_length, number_of_classes).
          target:
            target masked with self.ignore_index of
            dimension (batch_size, input_length). It is not modified.

        Returns:
          Depending on `self.reduction`:

            - "sum": a scalar tensor, the loss summed over all
              non-ignored positions (no normalization).
            - "mean": a scalar tensor, the summed loss divided by the
              number of non-ignored positions. If every position is
              ignored this is 0 / 0, i.e. NaN.
            - "none": a 1-D tensor of shape (batch_size * input_length,)
              with the per-position loss; ignored positions are 0.
        """
        assert x.ndim == 3
        assert target.ndim == 2
        assert x.shape[:2] == target.shape
        num_classes = x.size(-1)
        x = x.reshape(-1, num_classes)
        # Now x is of shape (N*T, C)

        # We don't want to change target in-place below,
        # so we make a copy of it here
        target = target.clone().reshape(-1)

        ignored = target == self.ignore_index
        # one_hot() cannot handle the (possibly negative) ignore_index, so
        # temporarily map ignored positions to class 0; their rows are
        # zeroed out again below.
        target[ignored] = 0

        true_dist = torch.nn.functional.one_hot(
            target, num_classes=num_classes
        ).to(x)

        true_dist = (
            true_dist * (1 - self.label_smoothing)
            + self.label_smoothing / num_classes
        )
        # Set the value of ignored indexes to 0
        true_dist[ignored] = 0

        loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
        if self.reduction == "sum":
            return loss.sum()
        elif self.reduction == "mean":
            return loss.sum() / (~ignored).sum()
        else:
            # "none": keep one value per (flattened) position.
            return loss.sum(dim=-1)
|
52
egs/librispeech/ASR/conformer_ctc/test_label_smoothing.py
Executable file
52
egs/librispeech/ASR/conformer_ctc/test_label_smoothing.py
Executable file
@ -0,0 +1,52 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from distutils.version import LooseVersion
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from label_smoothing import LabelSmoothingLoss
|
||||||
|
|
||||||
|
torch_ver = LooseVersion(torch.__version__)
|
||||||
|
|
||||||
|
|
||||||
|
def test_with_torch_label_smoothing_loss():
    """Compare LabelSmoothingLoss with torch.nn.CrossEntropyLoss.

    Both losses are evaluated on the same random input for every
    supported reduction and must agree. The comparison is skipped (the
    function just returns) when the installed torch is older than 1.10.0.
    """
    if torch_ver < LooseVersion("1.10.0"):
        print(f"Current torch version: {torch_ver}")
        print("Please use torch >= 1.10 to run this test - skipping")
        return

    torch.manual_seed(20211105)
    logits = torch.rand(20, 30, 5000)
    num_classes = logits.size(-1)
    # low=-1 so that some positions carry the ignore index.
    labels = torch.randint(low=-1, high=num_classes, size=logits.shape[:2])

    # torch's criterion takes (N, C) inputs and (N,) targets.
    flat_logits = logits.reshape(-1, num_classes)
    flat_labels = labels.reshape(-1)

    for reduction in ["none", "sum", "mean"]:
        ours = LabelSmoothingLoss(
            ignore_index=-1, label_smoothing=0.1, reduction=reduction
        )
        reference = torch.nn.CrossEntropyLoss(
            ignore_index=-1, reduction=reduction, label_smoothing=0.1
        )

        custom_loss = ours(logits, labels)
        torch_loss = reference(flat_logits, flat_labels)
        assert torch.allclose(custom_loss, torch_loss)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Run the label smoothing comparison directly, without pytest."""
    test_with_torch_label_smoothing_loss()


# Allow running this file as a script.
if __name__ == "__main__":
    main()
|
@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
from label_smoothing import LabelSmoothingLoss
|
||||||
from subsampling import Conv2dSubsampling, VggSubsampling
|
from subsampling import Conv2dSubsampling, VggSubsampling
|
||||||
from torch.nn.utils.rnn import pad_sequence
|
from torch.nn.utils.rnn import pad_sequence
|
||||||
|
|
||||||
@ -152,7 +153,7 @@ class Transformer(nn.Module):
|
|||||||
d_model, self.decoder_num_class
|
d_model, self.decoder_num_class
|
||||||
)
|
)
|
||||||
|
|
||||||
self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
|
self.decoder_criterion = LabelSmoothingLoss()
|
||||||
else:
|
else:
|
||||||
self.decoder_criterion = None
|
self.decoder_criterion = None
|
||||||
|
|
||||||
@ -799,73 +800,6 @@ class Noam(object):
|
|||||||
setattr(self, key, value)
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
|
||||||
class LabelSmoothingLoss(nn.Module):
|
|
||||||
"""
|
|
||||||
Label-smoothing loss. KL-divergence between
|
|
||||||
q_{smoothed ground truth prob.}(w)
|
|
||||||
and p_{prob. computed by model}(w) is minimized.
|
|
||||||
Modified from
|
|
||||||
https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa
|
|
||||||
|
|
||||||
Args:
|
|
||||||
size: the number of class
|
|
||||||
padding_idx: padding_idx: ignored class id
|
|
||||||
smoothing: smoothing rate (0.0 means the conventional CE)
|
|
||||||
normalize_length: normalize loss by sequence length if True
|
|
||||||
criterion: loss function to be smoothed
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
size: int,
|
|
||||||
padding_idx: int = -1,
|
|
||||||
smoothing: float = 0.1,
|
|
||||||
normalize_length: bool = False,
|
|
||||||
criterion: nn.Module = nn.KLDivLoss(reduction="none"),
|
|
||||||
) -> None:
|
|
||||||
"""Construct an LabelSmoothingLoss object."""
|
|
||||||
super(LabelSmoothingLoss, self).__init__()
|
|
||||||
self.criterion = criterion
|
|
||||||
self.padding_idx = padding_idx
|
|
||||||
assert 0.0 < smoothing <= 1.0
|
|
||||||
self.confidence = 1.0 - smoothing
|
|
||||||
self.smoothing = smoothing
|
|
||||||
self.size = size
|
|
||||||
self.true_dist = None
|
|
||||||
self.normalize_length = normalize_length
|
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
Compute loss between x and target.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
x:
|
|
||||||
prediction of dimension
|
|
||||||
(batch_size, input_length, number_of_classes).
|
|
||||||
target:
|
|
||||||
target masked with self.padding_id of
|
|
||||||
dimension (batch_size, input_length).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A scalar tensor containing the loss without normalization.
|
|
||||||
"""
|
|
||||||
assert x.size(2) == self.size
|
|
||||||
# batch_size = x.size(0)
|
|
||||||
x = x.view(-1, self.size)
|
|
||||||
target = target.view(-1)
|
|
||||||
with torch.no_grad():
|
|
||||||
true_dist = x.clone()
|
|
||||||
true_dist.fill_(self.smoothing / (self.size - 1))
|
|
||||||
ignore = target == self.padding_idx # (B,)
|
|
||||||
total = len(target) - ignore.sum().item()
|
|
||||||
target = target.masked_fill(ignore, 0) # avoid -1 index
|
|
||||||
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
|
|
||||||
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
|
|
||||||
# denom = total if self.normalize_length else batch_size
|
|
||||||
denom = total if self.normalize_length else 1
|
|
||||||
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
|
|
||||||
|
|
||||||
|
|
||||||
def encoder_padding_mask(
|
def encoder_padding_mask(
|
||||||
max_len: int, supervisions: Optional[Supervisions] = None
|
max_len: int, supervisions: Optional[Supervisions] = None
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user