Mirror of https://github.com/k2-fsa/icefall.git
New label smoothing (#109)
* Modify label smoothing to match the one implemented in PyTorch.
* Enable CI for torch 1.10.
* Fix CI errors.
* Fix CI installation errors.
* Fix CI installation errors.
* Minor fixes.
* Minor fixes.
* Minor fixes.
* Minor fixes.
* Minor fixes.
* Fix CI errors.
This commit is contained in:
parent 10e46f3e1d
commit 336283f872
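For context: torch 1.10 added a label_smoothing argument to torch.nn.CrossEntropyLoss, and the new loss introduced below is written (and tested) to be numerically interchangeable with it. A minimal sketch of the built-in form; the shapes here are illustrative, not taken from the recipe:

    import torch

    # Requires torch >= 1.10: `label_smoothing` was added in that release.
    criterion = torch.nn.CrossEntropyLoss(
        ignore_index=-1, reduction="sum", label_smoothing=0.1
    )
    logits = torch.randn(8, 500)            # (num_positions, num_classes)
    targets = torch.randint(-1, 500, (8,))  # -1 marks positions to ignore
    loss = criterion(logits, targets)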
.github/workflows/run-pretrained.yml (vendored, 9 changed lines)
@@ -31,8 +31,9 @@ jobs:
       matrix:
         os: [ubuntu-18.04]
         python-version: [3.6, 3.7, 3.8, 3.9]
-        torch: ["1.8.1"]
-        k2-version: ["1.9.dev20210919"]
+        torch: ["1.10.0"]
+        torchaudio: ["0.10.0"]
+        k2-version: ["1.9.dev20211101"]

       fail-fast: false

@@ -49,7 +50,9 @@ jobs:
     - name: Install Python dependencies
       run: |
         python3 -m pip install --upgrade pip pytest
-        pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
+        # numpy 1.20.x does not support python 3.6
+        pip install numpy==1.19
+        pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
         pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/

         python3 -m pip install git+https://github.com/lhotse-speech/lhotse
.github/workflows/run-yesno-recipe.yml (vendored, 6 changed lines)
@@ -33,8 +33,9 @@ jobs:
       # TODO: enable macOS for CPU testing
       os: [ubuntu-18.04]
       python-version: [3.8]
-      torch: ["1.8.1"]
-      k2-version: ["1.9.dev20210919"]
+      torch: ["1.10.0"]
+      torchaudio: ["0.10.0"]
+      k2-version: ["1.9.dev20211101"]
       fail-fast: false

     steps:

@@ -57,6 +58,7 @@ jobs:
     - name: Install Python dependencies
       run: |
         python3 -m pip install -U pip
+        pip install torch==${{ matrix.torch }}+cpu torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
         pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
         python3 -m pip install git+https://github.com/lhotse-speech/lhotse
.github/workflows/test.yml (vendored, 28 changed lines)
@@ -33,8 +33,14 @@ jobs:
       # disable macOS test for now.
       os: [ubuntu-18.04]
       python-version: [3.6, 3.7, 3.8, 3.9]
-      torch: ["1.8.1"]
-      k2-version: ["1.9.dev20210919"]
+      torch: ["1.8.0", "1.10.0"]
+      torchaudio: ["0.8.0", "0.10.0"]
+      k2-version: ["1.9.dev20211101"]
+      exclude:
+        - torch: "1.8.0"
+          torchaudio: "0.10.0"
+        - torch: "1.10.0"
+          torchaudio: "0.8.0"

      fail-fast: false

@@ -58,6 +64,15 @@ jobs:
     - name: Install Python dependencies
       run: |
         python3 -m pip install --upgrade pip pytest
+        # numpy 1.20.x does not support python 3.6
+        pip install numpy==1.19
+        pip install torch==${{ matrix.torch }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+        if [[ ${{ matrix.torchaudio }} == "0.10.0" ]]; then
+          pip install torchaudio==${{ matrix.torchaudio }}+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html
+        else
+          pip install torchaudio==${{ matrix.torchaudio }}
+        fi
+
         pip install k2==${{ matrix.k2-version }}+cpu.torch${{ matrix.torch }} -f https://k2-fsa.org/nightly/
         pip install git+https://github.com/lhotse-speech/lhotse
         # icefall requirements
@@ -83,7 +98,10 @@ jobs:
         ls -lh
         export PYTHONPATH=$PWD:$PWD/lhotse:$PYTHONPATH
         echo $PYTHONPATH
-        pytest ./test
+        pytest -v -s ./test
+        # run tests for conformer ctc
+        cd egs/librispeech/ASR/conformer_ctc
+        pytest -v -s

     - name: Run tests
       if: startsWith(matrix.os, 'macos')
@@ -93,8 +111,8 @@ jobs:
         lib_path=$(python -c "from distutils.sysconfig import get_python_lib; print(get_python_lib())")
         echo "lib_path: $lib_path"
         export DYLD_LIBRARY_PATH=$lib_path:$DYLD_LIBRARY_PATH
-        pytest ./test
+        pytest -v -s ./test

         # run tests for conformer ctc
         cd egs/librispeech/ASR/conformer_ctc
-        pytest
+        pytest -v -s
egs/librispeech/ASR/conformer_ctc/label_smoothing.py (new file, 98 lines)
@@ -0,0 +1,98 @@
+# Copyright  2021  Xiaomi Corp.  (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+class LabelSmoothingLoss(torch.nn.Module):
+    """
+    Implement the LabelSmoothingLoss proposed in the following paper:
+    https://arxiv.org/pdf/1512.00567.pdf
+    (Rethinking the Inception Architecture for Computer Vision)
+    """
+
+    def __init__(
+        self,
+        ignore_index: int = -1,
+        label_smoothing: float = 0.1,
+        reduction: str = "sum",
+    ) -> None:
+        """
+        Args:
+          ignore_index:
+            ignored class id
+          label_smoothing:
+            smoothing rate (0.0 means the conventional cross entropy loss)
+          reduction:
+            It has the same meaning as the reduction in
+            `torch.nn.CrossEntropyLoss`. It can be one of the following three
+            values: (1) "none": no reduction will be applied. (2) "mean": the
+            mean of the output is taken. (3) "sum": the output will be summed.
+        """
+        super().__init__()
+        assert 0.0 <= label_smoothing < 1.0
+        self.ignore_index = ignore_index
+        self.label_smoothing = label_smoothing
+        self.reduction = reduction
+
+    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Compute the loss between x and target.
+
+        Args:
+          x:
+            prediction of dimension
+            (batch_size, input_length, number_of_classes).
+          target:
+            target masked with self.ignore_index of
+            dimension (batch_size, input_length).
+
+        Returns:
+          The loss reduced according to self.reduction: a scalar tensor for
+          "sum" and "mean", or a tensor of shape (batch_size * input_length,)
+          for "none".
+        """
+        assert x.ndim == 3
+        assert target.ndim == 2
+        assert x.shape[:2] == target.shape
+        num_classes = x.size(-1)
+        x = x.reshape(-1, num_classes)
+        # Now x is of shape (N*T, C)
+
+        # We don't want to change target in-place below,
+        # so we make a copy of it here
+        target = target.clone().reshape(-1)
+
+        ignored = target == self.ignore_index
+        target[ignored] = 0
+
+        true_dist = torch.nn.functional.one_hot(
+            target, num_classes=num_classes
+        ).to(x)
+
+        true_dist = (
+            true_dist * (1 - self.label_smoothing)
+            + self.label_smoothing / num_classes
+        )
+        # Set the value of ignored indexes to 0
+        true_dist[ignored] = 0
+
+        loss = -1 * (torch.log_softmax(x, dim=1) * true_dist)
+        if self.reduction == "sum":
+            return loss.sum()
+        elif self.reduction == "mean":
+            return loss.sum() / (~ignored).sum()
+        else:
+            return loss.sum(dim=-1)
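A minimal usage sketch of the new module; the shapes are illustrative, and the import assumes it is run from this directory:

    import torch
    from label_smoothing import LabelSmoothingLoss

    loss_func = LabelSmoothingLoss(
        ignore_index=-1, label_smoothing=0.1, reduction="sum"
    )
    x = torch.randn(2, 5, 100)            # (batch_size, input_length, num_classes) logits
    tgt = torch.randint(-1, 100, (2, 5))  # -1 marks positions to ignore
    loss = loss_func(x, tgt)              # scalar: summed over non-ignored positions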
egs/librispeech/ASR/conformer_ctc/test_label_smoothing.py (new executable file, 52 lines)
@@ -0,0 +1,52 @@
+#!/usr/bin/env python3
+# Copyright  2021  Xiaomi Corp.  (authors: Fangjun Kuang)
+#
+# See ../../../../LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distutils.version import LooseVersion
+
+import torch
+from label_smoothing import LabelSmoothingLoss
+
+torch_ver = LooseVersion(torch.__version__)
+
+
+def test_with_torch_label_smoothing_loss():
+    if torch_ver < LooseVersion("1.10.0"):
+        print(f"Current torch version: {torch_ver}")
+        print("Please use torch >= 1.10 to run this test - skipping")
+        return
+    torch.manual_seed(20211105)
+    x = torch.rand(20, 30, 5000)
+    tgt = torch.randint(low=-1, high=x.size(-1), size=x.shape[:2])
+    for reduction in ["none", "sum", "mean"]:
+        custom_loss_func = LabelSmoothingLoss(
+            ignore_index=-1, label_smoothing=0.1, reduction=reduction
+        )
+        custom_loss = custom_loss_func(x, tgt)
+
+        torch_loss_func = torch.nn.CrossEntropyLoss(
+            ignore_index=-1, reduction=reduction, label_smoothing=0.1
+        )
+        torch_loss = torch_loss_func(x.reshape(-1, x.size(-1)), tgt.reshape(-1))
+        assert torch.allclose(custom_loss, torch_loss)
+
+
+def main():
+    test_with_torch_label_smoothing_loss()
+
+
+if __name__ == "__main__":
+    main()
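The test exercises all three reductions against torch.nn.CrossEntropyLoss, whose label_smoothing argument is new in torch 1.10 (hence the version guard). The logits are flattened to (batch * time, num_classes) for the built-in loss, which expects 2-D input in that form, while the custom loss consumes the 3-D layout directly.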
egs/librispeech/ASR/conformer_ctc/transformer.py
@@ -20,6 +20,7 @@ from typing import Dict, List, Optional, Tuple

 import torch
 import torch.nn as nn
+from label_smoothing import LabelSmoothingLoss
 from subsampling import Conv2dSubsampling, VggSubsampling
 from torch.nn.utils.rnn import pad_sequence

@@ -152,7 +153,7 @@ class Transformer(nn.Module):
                 d_model, self.decoder_num_class
             )

-            self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class)
+            self.decoder_criterion = LabelSmoothingLoss()
         else:
             self.decoder_criterion = None

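Note that the new loss no longer takes the class count at construction: LabelSmoothingLoss infers num_classes from x.size(-1) at call time, so decoder_num_class is not needed here.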
@@ -799,73 +800,6 @@ class Noam(object):
             setattr(self, key, value)


-class LabelSmoothingLoss(nn.Module):
-    """
-    Label-smoothing loss. KL-divergence between
-    q_{smoothed ground truth prob.}(w)
-    and p_{prob. computed by model}(w) is minimized.
-    Modified from
-    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py  # noqa
-
-    Args:
-        size: the number of classes
-        padding_idx: ignored class id
-        smoothing: smoothing rate (0.0 means the conventional CE)
-        normalize_length: normalize loss by sequence length if True
-        criterion: loss function to be smoothed
-    """
-
-    def __init__(
-        self,
-        size: int,
-        padding_idx: int = -1,
-        smoothing: float = 0.1,
-        normalize_length: bool = False,
-        criterion: nn.Module = nn.KLDivLoss(reduction="none"),
-    ) -> None:
-        """Construct a LabelSmoothingLoss object."""
-        super(LabelSmoothingLoss, self).__init__()
-        self.criterion = criterion
-        self.padding_idx = padding_idx
-        assert 0.0 < smoothing <= 1.0
-        self.confidence = 1.0 - smoothing
-        self.smoothing = smoothing
-        self.size = size
-        self.true_dist = None
-        self.normalize_length = normalize_length
-
-    def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
-        """
-        Compute loss between x and target.
-
-        Args:
-            x:
-                prediction of dimension
-                (batch_size, input_length, number_of_classes).
-            target:
-                target masked with self.padding_idx of
-                dimension (batch_size, input_length).
-
-        Returns:
-            A scalar tensor containing the loss without normalization.
-        """
-        assert x.size(2) == self.size
-        # batch_size = x.size(0)
-        x = x.view(-1, self.size)
-        target = target.view(-1)
-        with torch.no_grad():
-            true_dist = x.clone()
-            true_dist.fill_(self.smoothing / (self.size - 1))
-            ignore = target == self.padding_idx  # (B,)
-            total = len(target) - ignore.sum().item()
-            target = target.masked_fill(ignore, 0)  # avoid -1 index
-            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
-        kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
-        # denom = total if self.normalize_length else batch_size
-        denom = total if self.normalize_length else 1
-        return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
-
-
 def encoder_padding_mask(
     max_len: int, supervisions: Optional[Supervisions] = None
 ) -> Optional[torch.Tensor]:
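For comparison: the removed ESPnet-style class puts confidence = 1 - smoothing on the target class and spreads smoothing / (size - 1) over the remaining classes only, while the new label_smoothing.py (matching torch.nn.CrossEntropyLoss with label_smoothing set) mixes the one-hot target with a uniform distribution over all classes, i.e. (1 - label_smoothing) * one_hot + label_smoothing / num_classes. The two target distributions therefore differ slightly, including in the mass assigned to the true class, which is presumably why the old class is removed outright rather than reparameterized.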