train/decode with codebook loss

Guo Liyong 2022-04-27 14:18:12 +08:00
parent 979f574259
commit 9d48f1ce7d
2 changed files with 61 additions and 29 deletions

vq_pruned_transducer_stateless2/decode.py

@@ -18,36 +18,36 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method greedy_search
 (2) beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method beam_search \
     --beam-size 4
 (3) modified beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method modified_beam_search \
     --beam-size 4
 (4) fast beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 1500 \
     --decoding-method fast_beam_search \
     --beam 4 \
@@ -124,7 +124,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="vq_pruned_transducer_stateless2/exp",
         help="The experiment dir",
     )

vq_pruned_transducer_stateless2/train.py

@@ -19,27 +19,19 @@
 """
 Usage:
-export CUDA_VISIBLE_DEVICES="0,1,2,3"
+export CUDA_VISIBLE_DEVICES="0,1,2"
-./pruned_transducer_stateless2/train.py \
-  --world-size 4 \
+./vq_pruned_transducer_stateless2/train.py \
+  --world-size 3 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless2/exp \
-  --full-libri 1 \
+  --manifest-dir data/quantizer-finetuned_hubert_xtralarge-36layer-f0238f68-bytes_per_frame-8/ \
+  --codebook-loss-scale 0.1 \
+  --num-codebooks=8 \
+  --exp-dir vq_pruned_transducer_stateless2/exp \
+  --full-libri 0 \
   --max-duration 300
-# For mix precision training:
-./pruned_transducer_stateless2/train.py \
-  --world-size 4 \
-  --num-epochs 30 \
-  --start-epoch 0 \
-  --use_fp16 1 \
-  --exp-dir pruned_transducer_stateless2/exp \
-  --full-libri 1 \
-  --max-duration 550
 """
@@ -60,9 +52,10 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from lhotse.dataset.collation import collate_custom_field
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -138,7 +131,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="vq_pruned_transducer_stateless2/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -262,6 +255,23 @@ def get_parser():
         help="Whether to use half precision training.",
     )
+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="""The scale of the codebook loss.
+        Note: it is added on top of the transducer losses
+        (simple_loss and pruned_loss) and is independent of
+        their relative scales.""",
+    )
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=8,
+        help="Number of codebooks.",
+    )
     return parser
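
To make the new flag concrete, this is how it enters the total objective assembled in compute_loss below; the scale values are recipe defaults (simple_loss_scale assumed to be 0.5) and the numbers are illustrative only:

    # Illustrative arithmetic, not from a real run.
    simple_loss, pruned_loss, codebook_loss = 0.50, 0.30, 2.00
    simple_loss_scale = 0.5    # existing --simple-loss-scale default (assumed)
    pruned_loss_scale = 1.0    # value after warmup
    codebook_loss_scale = 0.1  # the new flag

    loss = simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss
    loss += codebook_loss_scale * codebook_loss  # training only
    print(loss)  # 0.25 + 0.30 + 0.20 = 0.75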
@@ -333,6 +343,8 @@ def get_params() -> AttributeDict:
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
+            "extra_output_layer": 5,  # 0-based index
+            "num_codebooks": 8,
         }
     )
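
extra_output_layer tells the Conformer to also return the hidden states of a middle encoder layer (index 5, 0-based), which is what the codebook predictor is trained on. The Conformer change itself is not part of this diff; the toy module below only illustrates the pattern:

    import torch
    import torch.nn as nn

    class TinyEncoder(nn.Module):
        """A stack of layers that also exposes one intermediate output."""

        def __init__(self, d_model=256, num_layers=12, extra_output_layer=5):
            super().__init__()
            self.layers = nn.ModuleList(
                nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
                for _ in range(num_layers)
            )
            self.extra_output_layer = extra_output_layer

        def forward(self, x):
            extra_out = None
            for i, layer in enumerate(self.layers):
                x = layer(x)
                if i == self.extra_output_layer:  # 0-based, as in get_params()
                    extra_out = x
            return x, extra_out

    final, middle = TinyEncoder()(torch.randn(2, 50, 256))
    print(final.shape, middle.shape)  # both torch.Size([2, 50, 256])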
@@ -348,6 +360,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         nhead=params.nhead,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
+        extra_output_layer=params.extra_output_layer,
     )
     return encoder
@@ -385,6 +398,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
+        num_codebooks=params.num_codebooks,
     )
     return model
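
num_codebooks sizes the distillation head inside the Transducer; the model.py side is not shown in this diff. A plausible minimal head, assuming one classifier per codebook over the middle-layer output and a codebook size of 256:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class CodebookHead(nn.Module):
        """Per frame, predicts the index each teacher codebook selected."""

        def __init__(self, embedding_dim=512, num_codebooks=8, codebook_size=256):
            super().__init__()
            self.codebook_size = codebook_size
            self.proj = nn.Linear(embedding_dim, num_codebooks * codebook_size)

        def forward(self, middle_layer_out, codebook_indices):
            # middle_layer_out: (N, T, embedding_dim)
            # codebook_indices: (N, T, num_codebooks), padded with -100
            logits = self.proj(middle_layer_out).reshape(-1, self.codebook_size)
            targets = codebook_indices.reshape(-1).long()
            # ignore_index=-100 matches pad_value in compute_loss below.
            return F.cross_entropy(logits, targets, ignore_index=-100)

    head = CodebookHead()
    loss = head(torch.randn(2, 50, 512), torch.randint(0, 256, (2, 50, 8)))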
@@ -539,8 +553,20 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)
+    if is_training:
+        cuts = batch["supervisions"]["cut"]
+        # -100 is identical to the ignore_index in CE loss computation.
+        cuts_pre_mixed = [
+            c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+        ]
+        codebook_indices, codebook_indices_lens = collate_custom_field(
+            cuts_pre_mixed, "codebook_indices", pad_value=-100
+        )
+        codebook_indices = codebook_indices.to(device)
+    else:
+        codebook_indices = None
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
@@ -548,6 +574,7 @@
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            codebook_indices=codebook_indices,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
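
Two details above are easy to miss: augmented MixedCuts are unwrapped to the underlying MonoCut in tracks[0] because the custom field lives on the original cut, and collate_custom_field pads the per-cut arrays into one (N, T, num_codebooks) tensor. A small sketch of the padding it is expected to produce:

    import torch

    # Two cuts with 4 and 2 frames, 3 codebooks each; the shorter one is
    # padded with -100, which cross-entropy later skips via ignore_index.
    cb_a = torch.randint(0, 256, (4, 3))
    cb_b = torch.randint(0, 256, (2, 3))
    lens = torch.tensor([4, 2])
    batch = torch.full((2, 4, 3), -100)
    batch[0, : lens[0]] = cb_a
    batch[1, : lens[1]] = cb_b
    # collate_custom_field(cuts, "codebook_indices", pad_value=-100) is
    # expected to return (batch, lens) shaped like the tensors built here.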
@@ -562,6 +589,9 @@
             params.simple_loss_scale * simple_loss
             + pruned_loss_scale * pruned_loss
         )
+        if is_training:
+            assert codebook_loss is not None
+            loss += params.codebook_loss_scale * codebook_loss
     assert loss.requires_grad == is_training
@@ -576,6 +606,8 @@
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if is_training:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()
     return loss, info
@@ -835,7 +867,7 @@ def run(rank, world_size, args):
     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)