Add Consistency-Regularized CTC (#1766)

* support consistency-regularized CTC

* update arguments of cr-ctc

* set default value of cr_loss_masked_scale to 1.0

* minor fix

* refactor codes

* update RESULTS.md
Zengwei Yao 2024-10-21 10:35:26 +08:00 committed by GitHub
parent f84270c935
commit 693d84a301
5 changed files with 556 additions and 20 deletions


@ -50,7 +50,7 @@ We place an additional Conv1d layer right after the input embedding layer.
| `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head |
| `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty |
| `zipformer-ctc` | Zipformer | Use auxiliary attention head |
| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head (the latest recipe) |

# MMI
@ -58,3 +58,9 @@ We place an additional Conv1d layer right after the input embedding layer.
|------------------------------|-----------|---------------------------------------------------|
| `conformer-mmi` | Conformer | |
| `zipformer-mmi` | Zipformer | CTC warmup + use HP as decoding graph for decoding |

# CR-CTC

|                              | Encoder            | Comment                      |
|------------------------------|--------------------|------------------------------|
| `zipformer`                  | Upgraded Zipformer | Can also be used as an auxiliary loss to improve transducer or CTC/AED (the latest recipe) |


@ -1,5 +1,315 @@
## Results

### zipformer (zipformer + pruned-transducer w/ CR-CTC)

See <https://github.com/k2-fsa/icefall/pull/1766> for more details.

[zipformer](./zipformer)

#### Non-streaming

##### large-scale model, number of model parameters: 148824074, i.e., 148.8 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-transducer-with-CR-CTC-20241019>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method      | test-clean | test-other | comment             |
|----------------------|------------|------------|---------------------|
| greedy_search        | 1.9        | 3.96       | --epoch 50 --avg 26 |
| modified_beam_search | 1.88       | 3.95       | --epoch 50 --avg 26 |

The training command using 2 80G-A100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# for non-streaming model training:
./zipformer/train.py \
--world-size 2 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-large-cr-ctc-rnnt \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 1 \
--use-attention-decoder 0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--ctc-loss-scale 0.1 \
--enable-spec-aug 0 \
--cr-loss-scale 0.02 \
--time-mask-ratio 2.5 \
--full-libri 1 \
--max-duration 1400 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search; do
./zipformer/decode.py \
--epoch 50 \
--avg 26 \
--exp-dir zipformer/exp-large-cr-ctc-rnnt \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 1 \
--use-attention-decoder 0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--max-duration 300 \
--decoding-method $m
done
```

### zipformer (zipformer + CR-CTC-AED)

See <https://github.com/k2-fsa/icefall/pull/1766> for more details.

[zipformer](./zipformer)

#### Non-streaming

##### large-scale model, number of model parameters: 174319650, i.e., 174.3 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-cr-ctc-aed-20241020>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method                       | test-clean | test-other | comment             |
|---------------------------------------|------------|------------|---------------------|
| attention-decoder-rescoring-no-ngram  | 1.96       | 4.08       | --epoch 50 --avg 20 |

The training command using 2 80G-A100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# for non-streaming model training:
./zipformer/train.py \
--world-size 2 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-large-cr-ctc-aed \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 1 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--ctc-loss-scale 0.1 \
--attention-decoder-loss-scale 0.9 \
--enable-spec-aug 0 \
--cr-loss-scale 0.02 \
--time-mask-ratio 2.5 \
--full-libri 1 \
--max-duration 1200 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 20 \
--exp-dir zipformer/exp-large-cr-ctc-aed/ \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 1 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--max-duration 200 \
--decoding-method attention-decoder-rescoring-no-ngram
```

### zipformer (zipformer + CR-CTC)

See <https://github.com/k2-fsa/icefall/pull/1766> for more details.

[zipformer](./zipformer)

#### Non-streaming

##### small-scale model, number of model parameters: 22118279, i.e., 22.1 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-small-cr-ctc-20241018>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method    | test-clean | test-other | comment             |
|--------------------|------------|------------|---------------------|
| ctc-greedy-search  | 2.57       | 5.95       | --epoch 50 --avg 25 |

The training command using 2 32G-V100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# for non-streaming model training:
./zipformer/train.py \
--world-size 2 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-small/ \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 0 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--base-lr 0.04 \
--enable-spec-aug 0 \
--cr-loss-scale 0.2 \
--time-mask-ratio 2.5 \
--full-libri 1 \
--max-duration 850 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search; do
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 25 \
--exp-dir zipformer/exp-small \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 0 \
--num-encoder-layers 2,2,2,2,2,2 \
--feedforward-dim 512,768,768,768,768,768 \
--encoder-dim 192,256,256,256,256,256 \
--encoder-unmasked-dim 192,192,192,192,192,192 \
--max-duration 600 \
--decoding-method $m
done
```

##### medium-scale model, number of model parameters: 64250603, i.e., 64.3 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-medium-cr-ctc-20241018>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method    | test-clean | test-other | comment             |
|--------------------|------------|------------|---------------------|
| ctc-greedy-search  | 2.12       | 4.62       | --epoch 50 --avg 24 |

The training command using 4 32G-V100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
--world-size 4 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 0 \
--enable-spec-aug 0 \
--cr-loss-scale 0.2 \
--time-mask-ratio 2.5 \
--full-libri 1 \
--max-duration 700 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search; do
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 24 \
--exp-dir zipformer/exp \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 0 \
--max-duration 600 \
--decoding-method $m
done
```

##### large-scale model, number of model parameters: 147010094, i.e., 147.0 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-cr-ctc-20241018>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method    | test-clean | test-other | comment             |
|--------------------|------------|------------|---------------------|
| ctc-greedy-search  | 2.03       | 4.37       | --epoch 50 --avg 26 |

The training command using 2 80G-A100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# For non-streaming model training:
./zipformer/train.py \
--world-size 2 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir zipformer/exp-large \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--enable-spec-aug 0 \
--cr-loss-scale 0.2 \
--time-mask-ratio 2.5 \
--full-libri 1 \
--max-duration 1400 \
--master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search; do
./zipformer/ctc_decode.py \
--epoch 50 \
--avg 26 \
--exp-dir zipformer/exp-large \
--use-cr-ctc 1 \
--use-ctc 1 \
--use-transducer 0 \
--use-attention-decoder 0 \
--num-encoder-layers 2,2,4,5,4,2 \
--feedforward-dim 512,768,1536,2048,1536,768 \
--encoder-dim 192,256,512,768,512,256 \
--encoder-unmasked-dim 192,192,256,320,256,192 \
--max-duration 600 \
--decoding-method $m
done
```

### zipformer (zipformer + CTC/AED)

See <https://github.com/k2-fsa/icefall/pull/1389> for more details.


@ -24,7 +24,8 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import add_sos, make_pad_mask, time_warp
from lhotse.dataset import SpecAugment


class AsrModel(nn.Module):
@ -181,6 +182,49 @@ class AsrModel(nn.Module):
        )
        return ctc_loss

    def forward_cr_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute CTC loss with consistency regularization loss.
        Args:
          encoder_out:
            Encoder output, of shape (2 * N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (2 * N,).
          targets:
            Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
        """
        # Compute CTC loss
        ctc_output = self.ctc_output(encoder_out)  # (2 * N, T, C)
        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, 2 * N, C)
            targets=targets.cpu(),
            input_lengths=encoder_out_lens.cpu(),
            target_lengths=target_lengths.cpu(),
            reduction="sum",
        )

        # Compute consistency regularization loss
        exchanged_targets = ctc_output.detach().chunk(2, dim=0)
        exchanged_targets = torch.cat(
            [exchanged_targets[1], exchanged_targets[0]], dim=0
        )  # exchange: [x1, x2] -> [x2, x1]
        cr_loss = nn.functional.kl_div(
            input=ctc_output,
            target=exchanged_targets,
            reduction="none",
            log_target=True,
        )  # (2 * N, T, C)
        length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
        cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()

        return ctc_loss, cr_loss
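For intuition, here is a minimal, self-contained sketch of the consistency term alone (toy tensors, no padding mask, names invented for the example). It mirrors the exchange-and-KL step above: each copy of the duplicated batch is pulled towards the other copy's detached output distribution.

```python
import torch
import torch.nn as nn

# Toy setup: N = 2 utterances duplicated -> batch of 2 * N = 4, T = 5 frames, C = 7 tokens.
log_probs = torch.randn(4, 5, 7).log_softmax(dim=-1)  # stands in for ctc_output

# Swap the two halves so copy 1 is supervised by copy 2's (detached) distribution and vice versa.
first, second = log_probs.detach().chunk(2, dim=0)
exchanged = torch.cat([second, first], dim=0)

# Element-wise KL(target || input), with both arguments given in log space.
cr_loss = nn.functional.kl_div(
    input=log_probs,
    target=exchanged,
    reduction="none",
    log_target=True,
).sum()
print(cr_loss.item())
```

In the actual method, the element-wise loss is additionally masked with `make_pad_mask(encoder_out_lens)` before summation so padded frames do not contribute.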

    def forward_transducer(
        self,
        encoder_out: torch.Tensor,
@ -296,7 +340,12 @@ class AsrModel(nn.Module):
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        use_cr_ctc: bool = False,
        use_spec_aug: bool = False,
        spec_augment: Optional[SpecAugment] = None,
        supervision_segments: Optional[torch.Tensor] = None,
        time_warp_factor: Optional[int] = 80,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
@ -316,9 +365,26 @@ class AsrModel(nn.Module):
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
          use_cr_ctc:
            Whether to use consistency-regularized CTC.
          use_spec_aug:
            Whether to apply SpecAugment manually; used only if use_cr_ctc is True.
          spec_augment:
            The SpecAugment instance that returns time masks,
            used only if use_cr_ctc is True.
          supervision_segments:
            An int tensor of shape ``(S, 3)``. ``S`` is the number of
            supervision segments that exist in ``features``.
            Used only if use_cr_ctc is True.
          time_warp_factor:
            Parameter for the time warping; larger values mean more warping.
            Set to ``None``, or less than ``1``, to disable.
            Used only if use_cr_ctc is True.

        Returns:
          Return the transducer losses, CTC loss, AED loss,
          and consistency-regularization loss in form of
          (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss)

        Note:
          Regarding am_scale & lm_scale, it will make the loss-function one of
@ -334,6 +400,24 @@ class AsrModel(nn.Module):
        device = x.device

        if use_cr_ctc:
            assert self.use_ctc
            if use_spec_aug:
                assert spec_augment is not None and spec_augment.time_warp_factor < 1
                # Apply time warping before input duplicating
                assert supervision_segments is not None
                x = time_warp(
                    x,
                    time_warp_factor=time_warp_factor,
                    supervision_segments=supervision_segments,
                )
                # Independently apply frequency masking and time masking to the two copies
                x = spec_augment(x.repeat(2, 1, 1))
            else:
                x = x.repeat(2, 1, 1)
            x_lens = x_lens.repeat(2)
            y = k2.ragged.cat([y, y], axis=0)
        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
@ -351,6 +435,9 @@ class AsrModel(nn.Module):
                am_scale=am_scale,
                lm_scale=lm_scale,
            )
            if use_cr_ctc:
                # The input batch was duplicated, so halve these losses to keep
                # their magnitude comparable to regular (non-CR) training.
                simple_loss = simple_loss * 0.5
                pruned_loss = pruned_loss * 0.5
        else:
            simple_loss = torch.empty(0)
            pruned_loss = torch.empty(0)
@ -358,14 +445,26 @@ class AsrModel(nn.Module):
        if self.use_ctc:
            # Compute CTC loss
            targets = y.values
            if not use_cr_ctc:
                ctc_loss = self.forward_ctc(
                    encoder_out=encoder_out,
                    encoder_out_lens=encoder_out_lens,
                    targets=targets,
                    target_lengths=y_lens,
                )
                cr_loss = torch.empty(0)
            else:
                ctc_loss, cr_loss = self.forward_cr_ctc(
                    encoder_out=encoder_out,
                    encoder_out_lens=encoder_out_lens,
                    targets=targets,
                    target_lengths=y_lens,
                )
                ctc_loss = ctc_loss * 0.5
                cr_loss = cr_loss * 0.5
        else:
            ctc_loss = torch.empty(0)
            cr_loss = torch.empty(0)

        if self.use_attention_decoder:
            attention_decoder_loss = self.attention_decoder.calc_att_loss(
@ -374,7 +473,9 @@ class AsrModel(nn.Module):
                ys=y.to(device),
                ys_lens=y_lens.to(device),
            )
            if use_cr_ctc:
                attention_decoder_loss = attention_decoder_loss * 0.5
        else:
            attention_decoder_loss = torch.empty(0)

        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss


@ -45,11 +45,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
  --max-duration 1000

It supports training with:
- transducer loss (default)
- ctc loss
- attention decoder loss
- cr-ctc loss (should use half the max-duration compared to regular ctc)
"""
@ -72,6 +71,7 @@ from attention_decoder import AttentionDecoderModel
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset import SpecAugment
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import AsrModel
@ -304,6 +304,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
help="If True, use attention-decoder head.", help="If True, use attention-decoder head.",
) )
parser.add_argument(
"--use-cr-ctc",
type=str2bool,
default=False,
help="If True, use consistency-regularized CTC.",
)
def get_parser(): def get_parser():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
@ -449,6 +456,20 @@ def get_parser():
help="Scale for CTC loss.", help="Scale for CTC loss.",
) )
parser.add_argument(
"--cr-loss-scale",
type=float,
default=0.2,
help="Scale for consistency-regularization loss.",
)
parser.add_argument(
"--time-mask-ratio",
type=float,
default=2.5,
help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
)
parser.add_argument( parser.add_argument(
"--attention-decoder-loss-scale", "--attention-decoder-loss-scale",
type=float, type=float,
@ -717,6 +738,24 @@ def get_model(params: AttributeDict) -> nn.Module:
    return model


def get_spec_augment(params: AttributeDict) -> SpecAugment:
    num_frame_masks = int(10 * params.time_mask_ratio)
    max_frames_mask_fraction = 0.15 * params.time_mask_ratio
    logging.info(
        f"num_frame_masks: {num_frame_masks}, "
        f"max_frames_mask_fraction: {max_frames_mask_fraction}"
    )
    spec_augment = SpecAugment(
        time_warp_factor=0,  # Do time warping in model.py
        num_frame_masks=num_frame_masks,  # default: 10
        features_mask_size=27,
        num_feature_masks=2,
        frames_mask_size=100,
        max_frames_mask_fraction=max_frames_mask_fraction,  # default: 0.15
    )
    return spec_augment
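For example, with the default `--time-mask-ratio 2.5` this works out to `num_frame_masks = int(10 * 2.5) = 25` and `max_frames_mask_fraction = 0.15 * 2.5 = 0.375`, i.e. 2.5x the usual amount of time masking, while the frequency-masking settings and `frames_mask_size` are left at the recipe's usual SpecAugment values. Time warping is disabled here because it is applied manually in model.py before the input is duplicated.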
def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
@ -839,6 +878,7 @@ def compute_loss(
    sp: spm.SentencePieceProcessor,
    batch: dict,
    is_training: bool,
    spec_augment: Optional[SpecAugment] = None,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute loss given the model and its inputs.
@ -855,8 +895,8 @@ def compute_loss(
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
      spec_augment:
        The SpecAugment instance used only when use_cr_ctc is True.
    """
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device

    feature = batch["inputs"]
@ -874,14 +914,34 @@ def compute_loss(
    y = sp.encode(texts, out_type=int)
    y = k2.RaggedTensor(y)

    use_cr_ctc = params.use_cr_ctc
    use_spec_aug = use_cr_ctc and is_training
    if use_spec_aug:
        supervision_intervals = batch["supervisions"]
        supervision_segments = torch.stack(
            [
                supervision_intervals["sequence_idx"],
                supervision_intervals["start_frame"],
                supervision_intervals["num_frames"],
            ],
            dim=1,
        )  # shape: (S, 3)
    else:
        supervision_segments = None

    with torch.set_grad_enabled(is_training):
        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model(
            x=feature,
            x_lens=feature_lens,
            y=y,
            prune_range=params.prune_range,
            am_scale=params.am_scale,
            lm_scale=params.lm_scale,
            use_cr_ctc=use_cr_ctc,
            use_spec_aug=use_spec_aug,
            spec_augment=spec_augment,
            supervision_segments=supervision_segments,
            time_warp_factor=params.spec_aug_time_warp_factor,
        )

        loss = 0.0
@ -904,6 +964,8 @@ def compute_loss(
        if params.use_ctc:
            loss += params.ctc_loss_scale * ctc_loss
            if use_cr_ctc:
                loss += params.cr_loss_scale * cr_loss
        if params.use_attention_decoder:
            loss += params.attention_decoder_loss_scale * attention_decoder_loss
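Note that when `use_cr_ctc` is enabled, the `ctc_loss` and `cr_loss` values arriving here have already been multiplied by 0.5 in model.py (the batch was duplicated, so halving keeps their magnitude comparable to regular CTC training); the effective contribution is therefore roughly `ctc_loss_scale * 0.5 * (ctc_x1 + ctc_x2) + cr_loss_scale * 0.5 * (cr_x1 + cr_x2)`, still weighted by `--ctc-loss-scale` and `--cr-loss-scale`.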
@ -922,6 +984,8 @@ def compute_loss(
info["pruned_loss"] = pruned_loss.detach().cpu().item() info["pruned_loss"] = pruned_loss.detach().cpu().item()
if params.use_ctc: if params.use_ctc:
info["ctc_loss"] = ctc_loss.detach().cpu().item() info["ctc_loss"] = ctc_loss.detach().cpu().item()
if params.use_cr_ctc:
info["cr_loss"] = cr_loss.detach().cpu().item()
if params.use_attention_decoder: if params.use_attention_decoder:
info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item() info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
@ -971,6 +1035,7 @@ def train_one_epoch(
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    spec_augment: Optional[SpecAugment] = None,
    model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
@ -997,6 +1062,8 @@ def train_one_epoch(
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mixed precision training.
      spec_augment:
        The SpecAugment instance used only when use_cr_ctc is True.
      model_avg:
        The stored model averaged from the start of training.
      tb_writer:
@ -1043,6 +1110,7 @@ def train_one_epoch(
                    sp=sp,
                    batch=batch,
                    is_training=True,
                    spec_augment=spec_augment,
                )
            # summary stats
            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@ -1238,6 +1306,13 @@ def run(rank, world_size, args):
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    if params.use_cr_ctc:
        assert params.use_ctc
        assert not params.enable_spec_aug  # we will do spec_augment in model.py
        spec_augment = get_spec_augment(params)
    else:
        spec_augment = None

    assert params.save_every_n >= params.average_period
    model_avg: Optional[nn.Module] = None
    if rank == 0:
@ -1360,6 +1435,7 @@ def run(rank, world_size, args):
            optimizer=optimizer,
            sp=sp,
            params=params,
            spec_augment=spec_augment,
        )

    scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
@ -1387,6 +1463,7 @@ def run(rank, world_size, args):
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            spec_augment=spec_augment,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,
@ -1452,6 +1529,7 @@ def scan_pessimistic_batches_for_oom(
    optimizer: torch.optim.Optimizer,
    sp: spm.SentencePieceProcessor,
    params: AttributeDict,
    spec_augment: Optional[SpecAugment] = None,
):
    from lhotse.dataset import find_pessimistic_batches
@ -1471,6 +1549,7 @@ def scan_pessimistic_batches_for_oom(
                    sp=sp,
                    batch=batch,
                    is_training=True,
                    spec_augment=spec_augment,
                )
            loss.backward()
            optimizer.zero_grad()


@ -21,6 +21,7 @@ import argparse
import collections
import logging
import os
import random
import re
import subprocess
from collections import defaultdict
@ -38,6 +39,7 @@ import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
from pypinyin import lazy_pinyin, pinyin
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
from torch.utils.tensorboard import SummaryWriter
@ -2271,3 +2273,41 @@ def num_tokens(
    if 0 in ans:
        num_tokens -= 1
    return num_tokens


# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
def time_warp(
    features: torch.Tensor,
    p: float = 0.9,
    time_warp_factor: Optional[int] = 80,
    supervision_segments: Optional[torch.Tensor] = None,
):
    """Apply time warping on a batch of features"""
    if time_warp_factor is None or time_warp_factor < 1:
        return features
    assert len(features.shape) == 3, (
        "SpecAugment only supports batches of single-channel feature matrices."
    )
    features = features.clone()
    if supervision_segments is None:
        # No supervisions - apply spec augment to full feature matrices.
        for sequence_idx in range(features.size(0)):
            if random.random() > p:
                # Randomly choose whether this transform is applied
                continue
            features[sequence_idx] = time_warp_impl(
                features[sequence_idx], factor=time_warp_factor
            )
    else:
        # Supervisions provided - we will apply time warping only on the supervised areas.
        for sequence_idx, start_frame, num_frames in supervision_segments:
            if random.random() > p:
                # Randomly choose whether this transform is applied
                continue
            end_frame = start_frame + num_frames
            features[sequence_idx, start_frame:end_frame] = time_warp_impl(
                features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
            )
    return features
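For reference, a small usage sketch of this helper (toy tensors only; it assumes lhotse is installed and that the `time_warp` function above is in scope):

```python
import torch

# Toy batch: 2 utterances, 600 frames each, 80-dim fbank features.
features = torch.randn(2, 600, 80)

# One supervision per utterance: (sequence_idx, start_frame, num_frames).
supervision_segments = torch.tensor([[0, 0, 600], [1, 0, 480]], dtype=torch.int32)

warped = time_warp(
    features,
    p=0.9,  # warp each utterance with probability 0.9
    time_warp_factor=80,
    supervision_segments=supervision_segments,
)
print(warped.shape)  # torch.Size([2, 600, 80]); the number of frames is unchanged
```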