train/decode with codebook loss

Commit 9d48f1ce7d (parent 979f574259)
Author: Guo Liyong
Date: 2022-04-27 14:18:12 +08:00
2 changed files with 61 additions and 29 deletions

File: vq_pruned_transducer_stateless2/decode.py

@@ -18,36 +18,36 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method greedy_search
 (2) beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method beam_search \
     --beam-size 4
 (3) modified beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method modified_beam_search \
     --beam-size 4
 (4) fast beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 1500 \
     --decoding-method fast_beam_search \
     --beam 4 \
@@ -124,7 +124,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="vq_pruned_transducer_stateless2/exp",
         help="The experiment dir",
     )

File: vq_pruned_transducer_stateless2/train.py

@@ -19,27 +19,19 @@
 """
 Usage:
-export CUDA_VISIBLE_DEVICES="0,1,2,3"
+export CUDA_VISIBLE_DEVICES="0,1,2"
-./pruned_transducer_stateless2/train.py \
+./vq_pruned_transducer_stateless2/train.py \
-  --world-size 4 \
+  --world-size 3 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless2/exp \
-  --full-libri 1 \
+  --manifest-dir data/quantizer-finetuned_hubert_xtralarge-36layer-f0238f68-bytes_per_frame-8/ \
+  --codebook-loss-scale 0.1 \
+  --num-codebooks=8 \
+  --exp-dir vq_pruned_transducer_stateless2/exp \
+  --full-libri 0 \
   --max-duration 300
-
-# For mix precision training:
-./pruned_transducer_stateless2/train.py \
-  --world-size 4 \
-  --num-epochs 30 \
-  --start-epoch 0 \
-  --use_fp16 1 \
-  --exp-dir pruned_transducer_stateless2/exp \
-  --full-libri 1 \
-  --max-duration 550
 """
@@ -60,9 +52,10 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from lhotse.dataset.collation import collate_custom_field
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -138,7 +131,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="vq_pruned_transducer_stateless2/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -262,6 +255,23 @@ def get_parser():
         help="Whether to use half precision training.",
     )
 
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="""The weight of the codebook loss.
+        Note: currently the scales of ctc_loss and att_loss sum to 1.0;
+        codebook_loss is independent of them.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=8,
+        help="Number of codebooks.",
+    )
+
     return parser
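Taken together, the two new flags feed the loss combination in compute_loss further down: the codebook term is scaled independently and added on top of the existing transducer losses. A minimal sketch of that combination (the simple/pruned scales shown are illustrative defaults, not part of this commit):

def combine_losses(
    simple_loss,
    pruned_loss,
    codebook_loss,
    simple_loss_scale=0.5,
    pruned_loss_scale=1.0,
    codebook_loss_scale=0.1,
):
    # The codebook term does not rebalance the other scales;
    # it is simply an extra, independently weighted summand.
    return (
        simple_loss_scale * simple_loss
        + pruned_loss_scale * pruned_loss
        + codebook_loss_scale * codebook_loss
    )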
@@ -333,6 +343,8 @@ def get_params() -> AttributeDict:
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
+            "extra_output_layer": 5,  # 0-based index
+            "num_codebooks": 8,
         }
     )
@@ -348,6 +360,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         nhead=params.nhead,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
+        extra_output_layer=params.extra_output_layer,
     )
     return encoder
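The commit does not show conformer.py, but extra_output_layer presumably makes the encoder also return the hidden states of one intermediate layer (here layer 5, 0-based) so the codebook predictor can be trained on them. A toy illustration of the idea, with invented names and a plain TransformerEncoderLayer standing in for the Conformer block:

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self, d_model: int, num_layers: int, extra_output_layer: int):
        super().__init__()
        assert 0 <= extra_output_layer < num_layers
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(d_model, nhead=4)
            for _ in range(num_layers)
        )
        self.extra_output_layer = extra_output_layer

    def forward(self, x: torch.Tensor):
        middle_out = None
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == self.extra_output_layer:  # 0-based, as in the diff
                middle_out = x  # tapped hidden states for the codebook loss
        return x, middle_out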
@@ -385,6 +398,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
+        num_codebooks=params.num_codebooks,
     )
     return model
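Likewise, num_codebooks presumably sizes a per-frame prediction head inside the Transducer: for each frame of the tapped middle-layer output, the model predicts one entry from each of the num_codebooks codebooks. A hypothetical sketch (codebook_size=256 is an assumption, not from this commit):

import torch.nn as nn

class CodebookPredictor(nn.Module):
    def __init__(self, d_model: int, num_codebooks: int, codebook_size: int = 256):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        # One joint projection producing logits for all codebooks at once.
        self.proj = nn.Linear(d_model, num_codebooks * codebook_size)

    def forward(self, middle_out):  # middle_out: (N, T, d_model)
        logits = self.proj(middle_out)
        # Split the joint logits into per-codebook distributions:
        # (N, T, num_codebooks, codebook_size)
        return logits.reshape(
            *middle_out.shape[:2], self.num_codebooks, self.codebook_size
        )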
@@ -539,8 +553,20 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)
 
+    if is_training:
+        cuts = batch["supervisions"]["cut"]
+        # -100 matches the ignore_index used in the CE loss computation.
+        cuts_pre_mixed = [
+            c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+        ]
+        codebook_indices, codebook_indices_lens = collate_custom_field(
+            cuts_pre_mixed, "codebook_indices", pad_value=-100
+        )
+        codebook_indices = codebook_indices.to(device)
+    else:
+        codebook_indices = None
+
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -548,6 +574,7 @@
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            codebook_indices=codebook_indices,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -562,6 +589,9 @@
             params.simple_loss_scale * simple_loss
             + pruned_loss_scale * pruned_loss
         )
+        if is_training:
+            assert codebook_loss is not None
+            loss += params.codebook_loss_scale * codebook_loss
 
     assert loss.requires_grad == is_training
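A quick, self-contained illustration of why pad_value=-100 is used when collating above: torch.nn.functional.cross_entropy skips targets equal to its default ignore_index (-100), so frames that collate_custom_field padded contribute nothing to codebook_loss. The shapes below are illustrative:

import torch
import torch.nn.functional as F

logits = torch.randn(2, 5, 8, 256)           # (N, T, num_codebooks, codebook_size)
targets = torch.randint(0, 256, (2, 5, 8))   # collated codebook_indices
targets[1, 3:] = -100                        # padding from collate_custom_field

codebook_loss = F.cross_entropy(
    logits.reshape(-1, 256),
    targets.reshape(-1),
    ignore_index=-100,  # padded positions are excluded from the loss
    reduction="sum",
)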
@@ -576,6 +606,8 @@
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if is_training:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()
 
     return loss, info