Add Consistency-Regularized CTC (#1766)
* support consistency-regularized CTC
* update arguments of cr-ctc
* set default value of cr_loss_masked_scale to 1.0
* minor fix
* refactor codes
* update RESULTS.md
This commit is contained in: parent f84270c935, commit 693d84a301
@@ -50,7 +50,7 @@ We place an additional Conv1d layer right after the input embedding layer.
| `conformer-ctc2` | Reworked Conformer | Use auxiliary attention head |
| `conformer-ctc3` | Reworked Conformer | Streaming version + delay penalty |
| `zipformer-ctc` | Zipformer | Use auxiliary attention head |
| `zipformer` | Upgraded Zipformer | Use auxiliary transducer head / attention-decoder head (the latest recipe) |

# MMI

@@ -58,3 +58,9 @@ We place an additional Conv1d layer right after the input embedding layer.
|------------------------------|-----------|---------------------------------------------------|
| `conformer-mmi` | Conformer | |
| `zipformer-mmi` | Zipformer | CTC warmup + use HP as decoding graph for decoding |

# CR-CTC

| | Encoder | Comment |
|------------------------------|--------------------|------------------------------|
| `zipformer` | Upgraded Zipformer | Could also be an auxiliary loss to improve transducer or CTC/AED (the latest recipe) |
@@ -1,5 +1,315 @@
## Results

### zipformer (zipformer + pruned-transducer w/ CR-CTC)

See <https://github.com/k2-fsa/icefall/pull/1766> for more details.

[zipformer](./zipformer)

#### Non-streaming

##### large-scale model, number of model parameters: 148824074, i.e., 148.8 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-transducer-with-CR-CTC-20241019>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.
| decoding method      | test-clean | test-other | comment             |
|----------------------|------------|------------|---------------------|
| greedy_search        | 1.90       | 3.96       | --epoch 50 --avg 26 |
| modified_beam_search | 1.88       | 3.95       | --epoch 50 --avg 26 |
The training command using 2 80G-A100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# for non-streaming model training:
./zipformer/train.py \
  --world-size 2 \
  --num-epochs 50 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp-large-cr-ctc-rnnt \
  --use-cr-ctc 1 \
  --use-ctc 1 \
  --use-transducer 1 \
  --use-attention-decoder 0 \
  --num-encoder-layers 2,2,4,5,4,2 \
  --feedforward-dim 512,768,1536,2048,1536,768 \
  --encoder-dim 192,256,512,768,512,256 \
  --encoder-unmasked-dim 192,192,256,320,256,192 \
  --ctc-loss-scale 0.1 \
  --enable-spec-aug 0 \
  --cr-loss-scale 0.02 \
  --time-mask-ratio 2.5 \
  --full-libri 1 \
  --max-duration 1400 \
  --master-port 12345
```

The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in greedy_search modified_beam_search; do
  ./zipformer/decode.py \
    --epoch 50 \
    --avg 26 \
    --exp-dir zipformer/exp-large-cr-ctc-rnnt \
    --use-cr-ctc 1 \
    --use-ctc 1 \
    --use-transducer 1 \
    --use-attention-decoder 0 \
    --num-encoder-layers 2,2,4,5,4,2 \
    --feedforward-dim 512,768,1536,2048,1536,768 \
    --encoder-dim 192,256,512,768,512,256 \
    --encoder-unmasked-dim 192,192,256,320,256,192 \
    --max-duration 300 \
    --decoding-method $m
done
```
### zipformer (zipformer + CR-CTC-AED)

See <https://github.com/k2-fsa/icefall/pull/1766> for more details.

[zipformer](./zipformer)

#### Non-streaming

##### large-scale model, number of model parameters: 174319650, i.e., 174.3 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-cr-ctc-aed-20241020>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method                       | test-clean | test-other | comment             |
|---------------------------------------|------------|------------|---------------------|
| attention-decoder-rescoring-no-ngram  | 1.96       | 4.08       | --epoch 50 --avg 20 |
The training command using 2 80G-A100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# for non-streaming model training:
./zipformer/train.py \
  --world-size 2 \
  --num-epochs 50 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp-large-cr-ctc-aed \
  --use-cr-ctc 1 \
  --use-ctc 1 \
  --use-transducer 0 \
  --use-attention-decoder 1 \
  --num-encoder-layers 2,2,4,5,4,2 \
  --feedforward-dim 512,768,1536,2048,1536,768 \
  --encoder-dim 192,256,512,768,512,256 \
  --encoder-unmasked-dim 192,192,256,320,256,192 \
  --ctc-loss-scale 0.1 \
  --attention-decoder-loss-scale 0.9 \
  --enable-spec-aug 0 \
  --cr-loss-scale 0.02 \
  --time-mask-ratio 2.5 \
  --full-libri 1 \
  --max-duration 1200 \
  --master-port 12345
```
The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
./zipformer/ctc_decode.py \
  --epoch 50 \
  --avg 20 \
  --exp-dir zipformer/exp-large-cr-ctc-aed/ \
  --use-cr-ctc 1 \
  --use-ctc 1 \
  --use-transducer 0 \
  --use-attention-decoder 1 \
  --num-encoder-layers 2,2,4,5,4,2 \
  --feedforward-dim 512,768,1536,2048,1536,768 \
  --encoder-dim 192,256,512,768,512,256 \
  --encoder-unmasked-dim 192,192,256,320,256,192 \
  --max-duration 200 \
  --decoding-method attention-decoder-rescoring-no-ngram
```
### zipformer (zipformer + CR-CTC)

See <https://github.com/k2-fsa/icefall/pull/1766> for more details.

[zipformer](./zipformer)

#### Non-streaming

##### small-scale model, number of model parameters: 22118279, i.e., 22.1 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-small-cr-ctc-20241018>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method     | test-clean | test-other | comment             |
|---------------------|------------|------------|---------------------|
| ctc-greedy-search   | 2.57       | 5.95       | --epoch 50 --avg 25 |
The training command using 2 32G-V100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# for non-streaming model training:
./zipformer/train.py \
  --world-size 2 \
  --num-epochs 50 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp-small/ \
  --use-cr-ctc 1 \
  --use-ctc 1 \
  --use-transducer 0 \
  --use-attention-decoder 0 \
  --num-encoder-layers 2,2,2,2,2,2 \
  --feedforward-dim 512,768,768,768,768,768 \
  --encoder-dim 192,256,256,256,256,256 \
  --encoder-unmasked-dim 192,192,192,192,192,192 \
  --base-lr 0.04 \
  --enable-spec-aug 0 \
  --cr-loss-scale 0.2 \
  --time-mask-ratio 2.5 \
  --full-libri 1 \
  --max-duration 850 \
  --master-port 12345
```

The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search; do
  ./zipformer/ctc_decode.py \
    --epoch 50 \
    --avg 25 \
    --exp-dir zipformer/exp-small \
    --use-cr-ctc 1 \
    --use-ctc 1 \
    --use-transducer 0 \
    --use-attention-decoder 0 \
    --num-encoder-layers 2,2,2,2,2,2 \
    --feedforward-dim 512,768,768,768,768,768 \
    --encoder-dim 192,256,256,256,256,256 \
    --encoder-unmasked-dim 192,192,192,192,192,192 \
    --max-duration 600 \
    --decoding-method $m
done
```
##### medium-scale model, number of model parameters: 64250603, i.e., 64.3 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-medium-cr-ctc-20241018>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method     | test-clean | test-other | comment             |
|---------------------|------------|------------|---------------------|
| ctc-greedy-search   | 2.12       | 4.62       | --epoch 50 --avg 24 |
The training command using 4 32G-V100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# For non-streaming model training:
./zipformer/train.py \
  --world-size 4 \
  --num-epochs 50 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp \
  --use-cr-ctc 1 \
  --use-ctc 1 \
  --use-transducer 0 \
  --use-attention-decoder 0 \
  --enable-spec-aug 0 \
  --cr-loss-scale 0.2 \
  --time-mask-ratio 2.5 \
  --full-libri 1 \
  --max-duration 700 \
  --master-port 12345
```

The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search; do
  ./zipformer/ctc_decode.py \
    --epoch 50 \
    --avg 24 \
    --exp-dir zipformer/exp \
    --use-cr-ctc 1 \
    --use-ctc 1 \
    --use-transducer 0 \
    --use-attention-decoder 0 \
    --max-duration 600 \
    --decoding-method $m
done
```
##### large-scale model, number of model parameters: 147010094, i.e., 147.0 M

You can find a pretrained model, training logs, decoding logs, and decoding results at:
<https://huggingface.co/Zengwei/icefall-asr-librispeech-zipformer-large-cr-ctc-20241018>

You can use <https://github.com/k2-fsa/sherpa> to deploy it.

| decoding method     | test-clean | test-other | comment             |
|---------------------|------------|------------|---------------------|
| ctc-greedy-search   | 2.03       | 4.37       | --epoch 50 --avg 26 |
The training command using 2 80G-A100 GPUs is:
```bash
export CUDA_VISIBLE_DEVICES="0,1"
# For non-streaming model training:
./zipformer/train.py \
  --world-size 2 \
  --num-epochs 50 \
  --start-epoch 1 \
  --use-fp16 1 \
  --exp-dir zipformer/exp-large \
  --use-cr-ctc 1 \
  --use-ctc 1 \
  --use-transducer 0 \
  --use-attention-decoder 0 \
  --num-encoder-layers 2,2,4,5,4,2 \
  --feedforward-dim 512,768,1536,2048,1536,768 \
  --encoder-dim 192,256,512,768,512,256 \
  --encoder-unmasked-dim 192,192,256,320,256,192 \
  --enable-spec-aug 0 \
  --cr-loss-scale 0.2 \
  --time-mask-ratio 2.5 \
  --full-libri 1 \
  --max-duration 1400 \
  --master-port 12345
```

The decoding command is:
```bash
export CUDA_VISIBLE_DEVICES="0"
for m in ctc-greedy-search; do
  ./zipformer/ctc_decode.py \
    --epoch 50 \
    --avg 26 \
    --exp-dir zipformer/exp-large \
    --use-cr-ctc 1 \
    --use-ctc 1 \
    --use-transducer 0 \
    --use-attention-decoder 0 \
    --num-encoder-layers 2,2,4,5,4,2 \
    --feedforward-dim 512,768,1536,2048,1536,768 \
    --encoder-dim 192,256,512,768,512,256 \
    --encoder-unmasked-dim 192,192,256,320,256,192 \
    --max-duration 600 \
    --decoding-method $m
done
```
### zipformer (zipformer + CTC/AED)

See <https://github.com/k2-fsa/icefall/pull/1389> for more details.
@@ -24,7 +24,8 @@ import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear

from icefall.utils import add_sos, make_pad_mask, time_warp
from lhotse.dataset import SpecAugment


class AsrModel(nn.Module):
@@ -181,6 +182,49 @@ class AsrModel(nn.Module):
        )
        return ctc_loss

    def forward_cr_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute CTC loss with consistency regularization loss.
        Args:
          encoder_out:
            Encoder output, of shape (2 * N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (2 * N,).
          targets:
            Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
          target_lengths:
            Target lengths, of shape (2 * N,).
        """
        # Compute CTC loss
        ctc_output = self.ctc_output(encoder_out)  # (2 * N, T, C)
        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, 2 * N, C)
            targets=targets.cpu(),
            input_lengths=encoder_out_lens.cpu(),
            target_lengths=target_lengths.cpu(),
            reduction="sum",
        )

        # Compute consistency regularization loss
        exchanged_targets = ctc_output.detach().chunk(2, dim=0)
        exchanged_targets = torch.cat(
            [exchanged_targets[1], exchanged_targets[0]], dim=0
        )  # exchange: [x1, x2] -> [x2, x1]
        cr_loss = nn.functional.kl_div(
            input=ctc_output,
            target=exchanged_targets,
            reduction="none",
            log_target=True,
        )  # (2 * N, T, C)
        length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
        cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()

        return ctc_loss, cr_loss
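For intuition, here is a self-contained sketch of the consistency-regularization term computed by `forward_cr_ctc` above. The tensors and shapes are made up for illustration; only the exchange-and-KL pattern mirrors the recipe code. Each augmented copy is pulled toward the detached CTC posteriors of its counterpart:

```python
import torch
import torch.nn.functional as F

# Illustrative setup: N = 2 utterances duplicated to 2 * N = 4 copies,
# T = 10 frames, C = 50 output tokens.
log_probs = torch.randn(4, 10, 50).log_softmax(dim=-1)  # (2 * N, T, C)
lengths = torch.tensor([10, 7, 10, 7])                  # (2 * N,)

# Swap the two halves so copy 1 is supervised by copy 2 and vice versa;
# detach() blocks gradients through the "teacher" branch.
first, second = log_probs.detach().chunk(2, dim=0)
exchanged = torch.cat([second, first], dim=0)           # [x1, x2] -> [x2, x1]

# Frame-level KL divergence between each copy and its exchanged counterpart.
cr = F.kl_div(input=log_probs, target=exchanged, reduction="none", log_target=True)

# Zero out padding frames before summing, as make_pad_mask does in the recipe.
pad_mask = torch.arange(10)[None, :] >= lengths[:, None]  # True on padding frames
cr_loss = cr.masked_fill(pad_mask.unsqueeze(-1), 0.0).sum()
print(cr_loss)
```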
    def forward_transducer(
        self,
        encoder_out: torch.Tensor,

@@ -296,7 +340,12 @@ class AsrModel(nn.Module):
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        use_cr_ctc: bool = False,
        use_spec_aug: bool = False,
        spec_augment: Optional[SpecAugment] = None,
        supervision_segments: Optional[torch.Tensor] = None,
        time_warp_factor: Optional[int] = 80,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
@@ -316,9 +365,26 @@ class AsrModel(nn.Module):
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
          use_cr_ctc:
            Whether to use consistency-regularized CTC.
          use_spec_aug:
            Whether to apply SpecAugment manually; used only if use_cr_ctc is True.
          spec_augment:
            The SpecAugment instance that returns time masks,
            used only if use_cr_ctc is True.
          supervision_segments:
            An int tensor of shape ``(S, 3)``. ``S`` is the number of
            supervision segments that exist in ``features``.
            Used only if use_cr_ctc is True.
          time_warp_factor:
            Parameter for the time warping; larger values mean more warping.
            Set to ``None``, or less than ``1``, to disable.
            Used only if use_cr_ctc is True.

        Returns:
          Return the transducer losses, CTC loss, AED loss,
          and consistency-regularization loss in form of
          (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss)

        Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
@@ -334,6 +400,24 @@ class AsrModel(nn.Module):

        device = x.device

        if use_cr_ctc:
            assert self.use_ctc
            if use_spec_aug:
                assert spec_augment is not None and spec_augment.time_warp_factor < 1
                # Apply time warping before duplicating the input
                assert supervision_segments is not None
                x = time_warp(
                    x,
                    time_warp_factor=time_warp_factor,
                    supervision_segments=supervision_segments,
                )
                # Independently apply frequency masking and time masking to the two copies
                x = spec_augment(x.repeat(2, 1, 1))
            else:
                x = x.repeat(2, 1, 1)
            x_lens = x_lens.repeat(2)
            y = k2.ragged.cat([y, y], axis=0)

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)
@@ -351,6 +435,9 @@ class AsrModel(nn.Module):
                am_scale=am_scale,
                lm_scale=lm_scale,
            )
            if use_cr_ctc:
                simple_loss = simple_loss * 0.5
                pruned_loss = pruned_loss * 0.5
        else:
            simple_loss = torch.empty(0)
            pruned_loss = torch.empty(0)
@@ -358,14 +445,26 @@ class AsrModel(nn.Module):
        if self.use_ctc:
            # Compute CTC loss
            targets = y.values
            if not use_cr_ctc:
                ctc_loss = self.forward_ctc(
                    encoder_out=encoder_out,
                    encoder_out_lens=encoder_out_lens,
                    targets=targets,
                    target_lengths=y_lens,
                )
                cr_loss = torch.empty(0)
            else:
                ctc_loss, cr_loss = self.forward_cr_ctc(
                    encoder_out=encoder_out,
                    encoder_out_lens=encoder_out_lens,
                    targets=targets,
                    target_lengths=y_lens,
                )
                ctc_loss = ctc_loss * 0.5
                cr_loss = cr_loss * 0.5
        else:
            ctc_loss = torch.empty(0)
            cr_loss = torch.empty(0)

        if self.use_attention_decoder:
            attention_decoder_loss = self.attention_decoder.calc_att_loss(

@@ -374,7 +473,9 @@ class AsrModel(nn.Module):
                ys=y.to(device),
                ys_lens=y_lens.to(device),
            )
            if use_cr_ctc:
                attention_decoder_loss = attention_decoder_loss * 0.5
        else:
            attention_decoder_loss = torch.empty(0)

        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss
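A note on the recurring `* 0.5` factors: with CR-CTC enabled the batch is duplicated to `2 * N` utterances, so every sum-reduced loss roughly doubles, and halving keeps the magnitudes comparable to single-copy training. A toy illustration (the numbers are made up):

```python
import torch

# Sum-reduced losses scale with batch size: duplicating the batch roughly
# doubles the total, which is why each loss above is multiplied by 0.5.
per_utt = torch.tensor([2.0, 3.0])    # per-utterance losses, N = 2
single = per_utt.sum()                # tensor(5.)
duplicated = per_utt.repeat(2).sum()  # tensor(10.) on the 2 * N batch
assert torch.isclose(0.5 * duplicated, single)
```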
@@ -45,11 +45,10 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
  --max-duration 1000

It supports training with:
- transducer loss (default)
- ctc loss
- attention decoder loss
- cr-ctc loss (should use half the max-duration compared to regular ctc)
"""
@@ -72,6 +71,7 @@ from attention_decoder import AttentionDecoderModel
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.dataset import SpecAugment
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import AsrModel

@@ -304,6 +304,13 @@ def add_model_arguments(parser: argparse.ArgumentParser):
        help="If True, use attention-decoder head.",
    )

    parser.add_argument(
        "--use-cr-ctc",
        type=str2bool,
        default=False,
        help="If True, use consistency-regularized CTC.",
    )


def get_parser():
    parser = argparse.ArgumentParser(
@@ -449,6 +456,20 @@ def get_parser():
        help="Scale for CTC loss.",
    )

    parser.add_argument(
        "--cr-loss-scale",
        type=float,
        default=0.2,
        help="Scale for consistency-regularization loss.",
    )

    parser.add_argument(
        "--time-mask-ratio",
        type=float,
        default=2.5,
        help="When using cr-ctc, we increase the amount of time-masking in SpecAugment.",
    )

    parser.add_argument(
        "--attention-decoder-loss-scale",
        type=float,
@@ -717,6 +738,24 @@ def get_model(params: AttributeDict) -> nn.Module:
    return model


def get_spec_augment(params: AttributeDict) -> SpecAugment:
    num_frame_masks = int(10 * params.time_mask_ratio)
    max_frames_mask_fraction = 0.15 * params.time_mask_ratio
    logging.info(
        f"num_frame_masks: {num_frame_masks}, "
        f"max_frames_mask_fraction: {max_frames_mask_fraction}"
    )
    spec_augment = SpecAugment(
        time_warp_factor=0,  # Do time warping in model.py
        num_frame_masks=num_frame_masks,  # default: 10
        features_mask_size=27,
        num_feature_masks=2,
        frames_mask_size=100,
        max_frames_mask_fraction=max_frames_mask_fraction,  # default: 0.15
    )
    return spec_augment
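To make the scaling concrete: with the default `--time-mask-ratio 2.5`, the function above applies 2.5x the usual amount of time masking, while time warping is disabled here (`time_warp_factor=0`) because it has to happen in `model.py` before the input is duplicated. A quick sanity check of the arithmetic:

```python
# Values follow directly from get_spec_augment with --time-mask-ratio 2.5.
time_mask_ratio = 2.5
num_frame_masks = int(10 * time_mask_ratio)        # 25 (default: 10)
max_frames_mask_fraction = 0.15 * time_mask_ratio  # 0.375 (default: 0.15)
print(num_frame_masks, max_frames_mask_fraction)   # 25 0.375
```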
def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
@@ -839,6 +878,7 @@ def compute_loss(
    sp: spm.SentencePieceProcessor,
    batch: dict,
    is_training: bool,
    spec_augment: Optional[SpecAugment] = None,
) -> Tuple[Tensor, MetricsTracker]:
    """
    Compute loss given the model and its inputs.

@@ -855,8 +895,8 @@ def compute_loss(
        True for training. False for validation. When it is True, this
        function enables autograd during computation; when it is False, it
        disables autograd.
      spec_augment:
        The SpecAugment instance used only when use_cr_ctc is True.
    """
    device = model.device if isinstance(model, DDP) else next(model.parameters()).device
    feature = batch["inputs"]
@@ -874,14 +914,34 @@ def compute_loss(
    y = sp.encode(texts, out_type=int)
    y = k2.RaggedTensor(y)

    use_cr_ctc = params.use_cr_ctc
    use_spec_aug = use_cr_ctc and is_training
    if use_spec_aug:
        supervision_intervals = batch["supervisions"]
        supervision_segments = torch.stack(
            [
                supervision_intervals["sequence_idx"],
                supervision_intervals["start_frame"],
                supervision_intervals["num_frames"],
            ],
            dim=1,
        )  # shape: (S, 3)
    else:
        supervision_segments = None

    with torch.set_grad_enabled(is_training):
        simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model(
            x=feature,
            x_lens=feature_lens,
            y=y,
            prune_range=params.prune_range,
            am_scale=params.am_scale,
            lm_scale=params.lm_scale,
            use_cr_ctc=use_cr_ctc,
            use_spec_aug=use_spec_aug,
            spec_augment=spec_augment,
            supervision_segments=supervision_segments,
            time_warp_factor=params.spec_aug_time_warp_factor,
        )

    loss = 0.0
@@ -904,6 +964,8 @@ def compute_loss(

    if params.use_ctc:
        loss += params.ctc_loss_scale * ctc_loss
        if use_cr_ctc:
            loss += params.cr_loss_scale * cr_loss

    if params.use_attention_decoder:
        loss += params.attention_decoder_loss_scale * attention_decoder_loss

@@ -922,6 +984,8 @@ def compute_loss(
        info["pruned_loss"] = pruned_loss.detach().cpu().item()
    if params.use_ctc:
        info["ctc_loss"] = ctc_loss.detach().cpu().item()
    if params.use_cr_ctc:
        info["cr_loss"] = cr_loss.detach().cpu().item()
    if params.use_attention_decoder:
        info["attn_decoder_loss"] = attention_decoder_loss.detach().cpu().item()
@@ -971,6 +1035,7 @@ def train_one_epoch(
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    scaler: GradScaler,
    spec_augment: Optional[SpecAugment] = None,
    model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,

@@ -997,6 +1062,8 @@ def train_one_epoch(
        Dataloader for the validation dataset.
      scaler:
        The scaler used for mixed precision training.
      spec_augment:
        The SpecAugment instance used only when use_cr_ctc is True.
      model_avg:
        The stored model averaged from the start of training.
      tb_writer:

@@ -1043,6 +1110,7 @@ def train_one_epoch(
                    sp=sp,
                    batch=batch,
                    is_training=True,
                    spec_augment=spec_augment,
                )
            # summary stats
            tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info
@@ -1238,6 +1306,13 @@ def run(rank, world_size, args):
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    if params.use_cr_ctc:
        assert params.use_ctc
        assert not params.enable_spec_aug  # we will do spec_augment in model.py
        spec_augment = get_spec_augment(params)
    else:
        spec_augment = None

    assert params.save_every_n >= params.average_period
    model_avg: Optional[nn.Module] = None
    if rank == 0:
@@ -1360,6 +1435,7 @@ def run(rank, world_size, args):
            optimizer=optimizer,
            sp=sp,
            params=params,
            spec_augment=spec_augment,
        )

    scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)

@@ -1387,6 +1463,7 @@ def run(rank, world_size, args):
            train_dl=train_dl,
            valid_dl=valid_dl,
            scaler=scaler,
            spec_augment=spec_augment,
            tb_writer=tb_writer,
            world_size=world_size,
            rank=rank,

@@ -1452,6 +1529,7 @@ def scan_pessimistic_batches_for_oom(
    optimizer: torch.optim.Optimizer,
    sp: spm.SentencePieceProcessor,
    params: AttributeDict,
    spec_augment: Optional[SpecAugment] = None,
):
    from lhotse.dataset import find_pessimistic_batches

@@ -1471,6 +1549,7 @@ def scan_pessimistic_batches_for_oom(
                    sp=sp,
                    batch=batch,
                    is_training=True,
                    spec_augment=spec_augment,
                )
            loss.backward()
            optimizer.zero_grad()
@@ -21,6 +21,7 @@ import argparse
import collections
import logging
import os
import random
import re
import subprocess
from collections import defaultdict

@@ -38,6 +39,7 @@ import sentencepiece as spm
import torch
import torch.distributed as dist
import torch.nn as nn
from lhotse.dataset.signal_transforms import time_warp as time_warp_impl
from pypinyin import lazy_pinyin, pinyin
from pypinyin.contrib.tone_convert import to_finals, to_finals_tone, to_initials
from torch.utils.tensorboard import SummaryWriter
@@ -2271,3 +2273,41 @@ def num_tokens(
    if 0 in ans:
        num_tokens -= 1
    return num_tokens


# Based on https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/signal_transforms.py
def time_warp(
    features: torch.Tensor,
    p: float = 0.9,
    time_warp_factor: Optional[int] = 80,
    supervision_segments: Optional[torch.Tensor] = None,
):
    """Apply time warping on a batch of features."""
    if time_warp_factor is None or time_warp_factor < 1:
        return features
    assert len(features.shape) == 3, (
        "SpecAugment only supports batches of single-channel feature matrices."
    )
    features = features.clone()
    if supervision_segments is None:
        # No supervisions - apply spec augment to full feature matrices.
        for sequence_idx in range(features.size(0)):
            if random.random() > p:
                # Randomly choose whether this transform is applied
                continue
            features[sequence_idx] = time_warp_impl(
                features[sequence_idx], factor=time_warp_factor
            )
    else:
        # Supervisions provided - we will apply time warping only on the supervised areas.
        for sequence_idx, start_frame, num_frames in supervision_segments:
            if random.random() > p:
                # Randomly choose whether this transform is applied
                continue
            end_frame = start_frame + num_frames
            features[sequence_idx, start_frame:end_frame] = time_warp_impl(
                features[sequence_idx, start_frame:end_frame], factor=time_warp_factor
            )

    return features
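A minimal usage sketch of this helper (the feature shapes and segment values are illustrative):

```python
import torch

from icefall.utils import time_warp

# A batch of 2 single-channel feature matrices: (N, T, F) = (2, 1000, 80).
features = torch.randn(2, 1000, 80)

# One supervision per sequence: (sequence_idx, start_frame, num_frames).
supervision_segments = torch.tensor(
    [[0, 0, 1000], [1, 0, 750]], dtype=torch.int32
)

# Warp only the supervised regions; each segment is warped with probability p.
warped = time_warp(
    features,
    p=0.9,
    time_warp_factor=80,
    supervision_segments=supervision_segments,
)
assert warped.shape == features.shape
```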