train/decode with codebook loss

commit 9d48f1ce7d
parent 979f574259
vq_pruned_transducer_stateless2/decode.py

@@ -18,36 +18,36 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
         --epoch 28 \
         --avg 15 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
+        --exp-dir ./vq_pruned_transducer_stateless2/exp \
         --max-duration 100 \
         --decoding-method greedy_search

 (2) beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
         --epoch 28 \
         --avg 15 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
+        --exp-dir ./vq_pruned_transducer_stateless2/exp \
         --max-duration 100 \
         --decoding-method beam_search \
         --beam-size 4

 (3) modified beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
         --epoch 28 \
         --avg 15 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
+        --exp-dir ./vq_pruned_transducer_stateless2/exp \
         --max-duration 100 \
         --decoding-method modified_beam_search \
         --beam-size 4

 (4) fast beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
         --epoch 28 \
         --avg 15 \
-        --exp-dir ./pruned_transducer_stateless2/exp \
+        --exp-dir ./vq_pruned_transducer_stateless2/exp \
         --max-duration 1500 \
         --decoding-method fast_beam_search \
         --beam 4 \
@@ -124,7 +124,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="vq_pruned_transducer_stateless2/exp",
         help="The experiment dir",
     )
vq_pruned_transducer_stateless2/train.py

@@ -19,27 +19,19 @@
 """
 Usage:

-export CUDA_VISIBLE_DEVICES="0,1,2,3"
+export CUDA_VISIBLE_DEVICES="0,1,2"

-./pruned_transducer_stateless2/train.py \
-  --world-size 4 \
+./vq_pruned_transducer_stateless2/train.py \
+  --world-size 3 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --exp-dir pruned_transducer_stateless2/exp \
-  --full-libri 1 \
+  --manifest-dir data/quantizer-finetuned_hubert_xtralarge-36layer-f0238f68-bytes_per_frame-8/ \
+  --codebook-loss-scale 0.1 \
+  --num-codebooks=8 \
+  --exp-dir vq_pruned_transducer_stateless2/exp \
+  --full-libri 0 \
   --max-duration 300

-# For mix precision training:
-
-./pruned_transducer_stateless2/train.py \
-  --world-size 4 \
-  --num-epochs 30 \
-  --start-epoch 0 \
-  --use_fp16 1 \
-  --exp-dir pruned_transducer_stateless2/exp \
-  --full-libri 1 \
-  --max-duration 550
-
 """
@@ -60,9 +52,10 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from lhotse.dataset.collation import collate_custom_field
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -138,7 +131,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="vq_pruned_transducer_stateless2/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -262,6 +255,23 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="""The weight of the codebook loss.
+        Note: currently rate of ctc_loss + rate of att_loss = 1.0;
+        codebook_loss is independent of them.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=8,
+        help="Number of codebooks.",
+    )
+
     return parser
@@ -333,6 +343,8 @@ def get_params() -> AttributeDict:
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
+            "extra_output_layer": 5,  # 0-based index
+            "num_codebooks": 8,
         }
     )
@@ -348,6 +360,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         nhead=params.nhead,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
+        extra_output_layer=params.extra_output_layer,
     )
     return encoder
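Note: the new extra_output_layer argument is only threaded through here; the matching Conformer change is not part of this diff. As a rough, purely illustrative sketch of the idea (class and variable names are assumptions, not the commit's code), the encoder returns the hidden states of one intermediate layer, selected by a 0-based index, alongside its final output, so that a codebook predictor can be attached to that layer:

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    """Toy stand-in for a Conformer with an `extra_output_layer` tap."""

    def __init__(self, d_model: int = 16, num_layers: int = 12, extra_output_layer: int = 5):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_layers))
        self.extra_output_layer = extra_output_layer

    def forward(self, x: torch.Tensor):
        extra_out = None
        for i, layer in enumerate(self.layers):
            x = torch.relu(layer(x))
            if i == self.extra_output_layer:
                extra_out = x  # intermediate representation, kept for the codebook loss
        return x, extra_out

encoder = TinyEncoder()
final_out, middle_out = encoder(torch.randn(2, 100, 16))  # middle_out feeds the codebook predictor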
@@ -385,6 +398,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
+        num_codebooks=params.num_codebooks,
     )
     return model
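Similarly, num_codebooks is forwarded to the Transducer, whose codebook-loss head lives in model.py and is not visible in this diff. The following is a minimal sketch of one plausible way such a loss can be computed from the tapped encoder layer, per frame and per codebook, using the same -100 padding that compute_loss applies below; every name here (CodebookPredictor, middle_out, codebook_size) is hypothetical:

import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

class CodebookPredictor(nn.Module):
    """One classifier per codebook, fused into a single linear projection."""

    def __init__(self, encoder_dim: int, num_codebooks: int = 8, codebook_size: int = 256):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebook_size = codebook_size
        self.proj = nn.Linear(encoder_dim, num_codebooks * codebook_size)

    def forward(self, middle_out: Tensor, codebook_indices: Tensor) -> Tensor:
        # middle_out: (N, T, encoder_dim); codebook_indices: (N, T, num_codebooks),
        # padded with -100 where there is no target frame.
        logits = self.proj(middle_out).reshape(-1, self.codebook_size)
        targets = codebook_indices.reshape(-1).long()
        # ignore_index=-100 makes padded positions contribute nothing to the loss
        return F.cross_entropy(logits, targets, ignore_index=-100, reduction="sum")

The (N, T, num_codebooks) target layout matches the (codebook_indices, codebook_indices_lens) pair returned by collate_custom_field in the compute_loss hunk below, assuming the stored indices are aligned with the encoder's frame rate.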
@@ -539,8 +553,20 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)

+    if is_training:
+        cuts = batch["supervisions"]["cut"]
+        # -100 is identical to ignore_index in CE loss computation.
+        cuts_pre_mixed = [
+            c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+        ]
+        codebook_indices, codebook_indices_lens = collate_custom_field(
+            cuts_pre_mixed, "codebook_indices", pad_value=-100
+        )
+        codebook_indices = codebook_indices.to(device)
+    else:
+        codebook_indices = None
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
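collate_custom_field can only return "codebook_indices" if the cuts in the manifests under --manifest-dir (see the train.py usage string above) already carry that custom field. Below is a sketch of how such a manifest might be prepared with lhotse's array-storage API; the file names, frame shift, and zero-filled targets are placeholders, since the real indices would come from the fine-tuned HuBERT quantizer named in the manifest directory:

import numpy as np
from lhotse import CutSet
from lhotse.features.io import NumpyHdf5Writer

FRAME_SHIFT = 0.04  # assumed frame rate of the quantized targets (placeholder)
NUM_CODEBOOKS = 8

cuts = CutSet.from_file("cuts_train.jsonl.gz")  # hypothetical input manifest

with NumpyHdf5Writer("codebook_indices.h5") as writer:
    cuts_with_targets = []
    for cut in cuts:
        num_frames = int(cut.duration / FRAME_SHIFT)
        # placeholder targets; in practice these come from the trained quantizer
        indices = np.zeros((num_frames, NUM_CODEBOOKS), dtype=np.int16)
        cut.codebook_indices = writer.store_array(
            key=cut.id, value=indices, frame_shift=FRAME_SHIFT, temporal_dim=0
        )
        cuts_with_targets.append(cut)
    CutSet.from_cuts(cuts_with_targets).to_file(
        "cuts_train_with_codebook_indices.jsonl.gz"
    )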
@@ -548,6 +574,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            codebook_indices=codebook_indices,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -562,6 +589,9 @@ def compute_loss(
             params.simple_loss_scale * simple_loss
             + pruned_loss_scale * pruned_loss
         )
+        if is_training:
+            assert codebook_loss is not None
+            loss += params.codebook_loss_scale * codebook_loss

     assert loss.requires_grad == is_training
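Taken together with the scaling already applied above, the training objective effectively becomes loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss + params.codebook_loss_scale * codebook_loss, where pruned_loss_scale still follows the warmup schedule and the codebook term is only added during training; for validation batches codebook_indices is None and no codebook loss is included.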
@@ -576,6 +606,8 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if is_training:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()

     return loss, info
@@ -835,7 +867,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)