Mirror of https://github.com/k2-fsa/icefall.git, synced 2025-09-07 08:04:18 +00:00.
train/decode with codebook loss

commit 9d48f1ce7d, parent 979f574259
--- a/vq_pruned_transducer_stateless2/decode.py
+++ b/vq_pruned_transducer_stateless2/decode.py
@@ -18,36 +18,36 @@
 """
 Usage:
 (1) greedy search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method greedy_search

 (2) beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method beam_search \
     --beam-size 4

 (3) modified beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 100 \
     --decoding-method modified_beam_search \
     --beam-size 4

 (4) fast beam search
-./pruned_transducer_stateless2/decode.py \
+./vq_pruned_transducer_stateless2/decode.py \
     --epoch 28 \
     --avg 15 \
-    --exp-dir ./pruned_transducer_stateless2/exp \
+    --exp-dir ./vq_pruned_transducer_stateless2/exp \
     --max-duration 1500 \
     --decoding-method fast_beam_search \
     --beam 4 \
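Note: the four methods above differ mainly in how many hypotheses they track. Greedy search keeps one, beam_search and modified_beam_search keep --beam-size, and fast_beam_search prunes an FSA-based search controlled by --beam (the hunk is truncated after "--beam 4 \", so its remaining flags are not shown). As a toy illustration of the modified-beam-search idea only (not the recipe's implementation, which conditions each frame's distribution on the hypothesis through the decoder and joiner), the sketch below uses fixed per-frame log-probs:

from typing import List, Tuple

def modified_beam_search_toy(
    logprobs_per_frame: List[List[float]], beam: int = 4, blank: int = 0
) -> List[int]:
    """Each frame extends every hypothesis by exactly one symbol
    (blank emits nothing); only the `beam` best total scores survive."""
    hyps: List[Tuple[List[int], float]] = [([], 0.0)]
    for frame in logprobs_per_frame:
        candidates = []
        for tokens, score in hyps:
            for sym, lp in enumerate(frame):
                new_tokens = tokens if sym == blank else tokens + [sym]
                candidates.append((new_tokens, score + lp))
        candidates.sort(key=lambda c: c[1], reverse=True)
        hyps = candidates[:beam]
    return hyps[0][0]

# modified_beam_search_toy([[-0.4, -1.1], [-1.1, -0.4]]) -> [1]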
@@ -124,7 +124,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="vq_pruned_transducer_stateless2/exp",
         help="The experiment dir",
     )

--- a/vq_pruned_transducer_stateless2/train.py
+++ b/vq_pruned_transducer_stateless2/train.py
@@ -19,27 +19,19 @@
 """
 Usage:

-export CUDA_VISIBLE_DEVICES="0,1,2,3"
+export CUDA_VISIBLE_DEVICES="0,1,2"

-./pruned_transducer_stateless2/train.py \
-    --world-size 4 \
+./vq_pruned_transducer_stateless2/train.py \
+    --world-size 3 \
     --num-epochs 30 \
     --start-epoch 0 \
-    --exp-dir pruned_transducer_stateless2/exp \
-    --full-libri 1 \
+    --manifest-dir data/quantizer-finetuned_hubert_xtralarge-36layer-f0238f68-bytes_per_frame-8/ \
+    --codebook-loss-scale 0.1 \
+    --num-codebooks=8 \
+    --exp-dir vq_pruned_transducer_stateless2/exp \
+    --full-libri 0 \
     --max-duration 300

-# For mix precision training:
-
-./pruned_transducer_stateless2/train.py \
-    --world-size 4 \
-    --num-epochs 30 \
-    --start-epoch 0 \
-    --use_fp16 1 \
-    --exp-dir pruned_transducer_stateless2/exp \
-    --full-libri 1 \
-    --max-duration 550
-
 """

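Note: the new --manifest-dir points at cuts carrying precomputed codebook indices (the manifest name suggests embeddings tapped from layer 36 of a fine-tuned HuBERT XtraLarge, quantized to 8 bytes per frame), the GPU count drops to 3 to match CUDA_VISIBLE_DEVICES, and --full-libri 0 trains on the 100h subset. A rough sketch of how such indices can be attached to lhotse cuts as a custom field so that collate_custom_field (used later in compute_loss) can load them; the writer path, frame shift, and helper names are illustrative, not the recipe's actual extraction script:

import numpy as np
from lhotse import CutSet
from lhotse.features.io import NumpyHdf5Writer

def attach_codebook_indices(cuts: CutSet, compute_indices) -> CutSet:
    # compute_indices(cut) -> np.ndarray of shape (num_frames, num_codebooks)
    with NumpyHdf5Writer("codebook_indices.h5") as writer:
        for cut in cuts:
            indices = compute_indices(cut)
            # Stored as a temporal custom field; frame_shift of 0.02 s is an
            # assumption (HuBERT's 20 ms hop).
            cut.codebook_indices = writer.store_array(
                key=cut.id,
                value=indices.astype(np.int16),
                frame_shift=0.02,
                temporal_dim=0,
            )
    return cuts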
@@ -60,9 +52,10 @@ from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
-from lhotse.cut import Cut
+from lhotse.cut import Cut, MonoCut
 from lhotse.dataset.sampling.base import CutSampler
 from lhotse.utils import fix_random_seed
+from lhotse.dataset.collation import collate_custom_field
 from model import Transducer
 from optim import Eden, Eve
 from torch import Tensor
@@ -138,7 +131,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="pruned_transducer_stateless2/exp",
+        default="vq_pruned_transducer_stateless2/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -262,6 +255,23 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--codebook-loss-scale",
+        type=float,
+        default=0.1,
+        help="""The weight of the codebook loss.
+        Note: currently rate of ctc_loss + rate of att_loss = 1.0;
+        codebook_loss is independent of them.
+        """,
+    )
+
+    parser.add_argument(
+        "--num-codebooks",
+        type=int,
+        default=8,
+        help="Number of codebooks.",
+    )
+
     return parser

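Note: for orientation, a hedged sketch of what a codebook loss of this shape typically computes; the recipe's actual implementation lives in model.py and is not part of this diff. A prediction head over the tapped encoder layer emits one distribution per codebook, and the loss is cross-entropy against the teacher's codebook indices, with -100 padding ignored. codebook_size=256 and the "sum" reduction are assumptions (256 is suggested by "bytes_per_frame" in the manifest name):

import torch
import torch.nn as nn
import torch.nn.functional as F

def codebook_loss_sketch(
    middle_layer_output: torch.Tensor,  # (N, T, encoder_dim)
    head: nn.Linear,  # maps encoder_dim -> num_codebooks * codebook_size
    codebook_indices: torch.Tensor,  # (N, T, num_codebooks), pad = -100
    codebook_size: int = 256,  # assumed: 1 byte per frame per codebook
) -> torch.Tensor:
    logits = head(middle_layer_output)  # (N, T, num_codebooks * codebook_size)
    logits = logits.reshape(-1, codebook_size)  # one row per (frame, codebook)
    targets = codebook_indices.reshape(-1).long()
    # ignore_index=-100 skips the padded positions introduced by collation
    return F.cross_entropy(logits, targets, ignore_index=-100, reduction="sum")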
@@ -333,6 +343,8 @@ def get_params() -> AttributeDict:
             # parameters for Noam
             "model_warm_step": 3000,  # arg given to model, not for lrate
             "env_info": get_env_info(),
+            "extra_output_layer": 5,  # 0-based index
+            "num_codebooks": 8,
         }
     )

@@ -348,6 +360,7 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
         nhead=params.nhead,
         dim_feedforward=params.dim_feedforward,
         num_encoder_layers=params.num_encoder_layers,
+        extra_output_layer=params.extra_output_layer,
     )
     return encoder

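Note: extra_output_layer asks the Conformer to return the hidden states of one intermediate layer (0-based index 5 by default) alongside its final output, so the codebook head can be trained on them. A minimal sketch of the mechanism, assumed from the parameter name (the actual change is in conformer.py, not shown in this diff):

import torch
import torch.nn as nn

class EncoderWithTap(nn.Module):
    """Run a stack of layers and also return the output of the layer
    at `extra_output_layer` (0-based) for a distillation head."""

    def __init__(self, layers: nn.ModuleList, extra_output_layer: int):
        super().__init__()
        assert 0 <= extra_output_layer < len(layers)
        self.layers = layers
        self.extra_output_layer = extra_output_layer

    def forward(self, x: torch.Tensor):
        extra = None
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i == self.extra_output_layer:
                extra = x  # tapped hidden states
        return x, extra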
@@ -385,6 +398,7 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
         decoder_dim=params.decoder_dim,
         joiner_dim=params.joiner_dim,
         vocab_size=params.vocab_size,
+        num_codebooks=params.num_codebooks,
     )
     return model

@@ -539,8 +553,20 @@ def compute_loss(
     y = sp.encode(texts, out_type=int)
     y = k2.RaggedTensor(y).to(device)

+    if is_training:
+        cuts = batch["supervisions"]["cut"]
+        # -100 matches the default ignore_index of the CE loss computation.
+        cuts_pre_mixed = [
+            c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts
+        ]
+        codebook_indices, codebook_indices_lens = collate_custom_field(
+            cuts_pre_mixed, "codebook_indices", pad_value=-100
+        )
+        codebook_indices = codebook_indices.to(device)
+    else:
+        codebook_indices = None
     with torch.set_grad_enabled(is_training):
-        simple_loss, pruned_loss = model(
+        simple_loss, pruned_loss, codebook_loss = model(
             x=feature,
             x_lens=feature_lens,
             y=y,
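Note: a MixedCut (e.g. after noise mixing) does not carry the custom field, so the original speech cut is recovered from tracks[0] before collation. A rough sketch of the padding that collate_custom_field performs here, assuming each cut stores an int array of shape (num_frames, num_codebooks):

import torch
from typing import List, Tuple

def collate_codebook_indices_sketch(
    arrays: List[torch.Tensor], pad_value: int = -100
) -> Tuple[torch.Tensor, torch.Tensor]:
    # arrays[i]: (T_i, num_codebooks); output: (N, T_max, num_codebooks)
    lens = torch.tensor([a.shape[0] for a in arrays])
    out = arrays[0].new_full(
        (len(arrays), int(lens.max()), arrays[0].shape[1]), pad_value
    )
    for i, a in enumerate(arrays):
        out[i, : a.shape[0]] = a  # shorter cuts keep pad_value on the right
    return out, lens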
@@ -548,6 +574,7 @@ def compute_loss(
             am_scale=params.am_scale,
             lm_scale=params.lm_scale,
             warmup=warmup,
+            codebook_indices=codebook_indices,
         )
         # after the main warmup step, we keep pruned_loss_scale small
         # for the same amount of time (model_warm_step), to avoid
@@ -562,6 +589,9 @@ def compute_loss(
             params.simple_loss_scale * simple_loss
             + pruned_loss_scale * pruned_loss
         )
+        if is_training:
+            assert codebook_loss is not None
+            loss += params.codebook_loss_scale * codebook_loss

     assert loss.requires_grad == is_training

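Note: after warmup the training objective is therefore

loss = simple_loss_scale * simple_loss
       + pruned_loss_scale * pruned_loss
       + codebook_loss_scale * codebook_loss

with codebook_loss_scale taken from the new flag; validation skips the codebook term because no codebook indices are collated there (codebook_indices is None outside training).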
@@ -576,6 +606,8 @@ def compute_loss(
     info["loss"] = loss.detach().cpu().item()
     info["simple_loss"] = simple_loss.detach().cpu().item()
     info["pruned_loss"] = pruned_loss.detach().cpu().item()
+    if is_training:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()

     return loss, info

@@ -835,7 +867,7 @@ def run(rank, world_size, args):

     if params.print_diagnostics:
         opts = diagnostics.TensorDiagnosticOptions(
-            2 ** 22
+            2**22
         )  # allow 4 megabytes per sub-module
         diagnostic = diagnostics.attach_diagnostics(model, opts)
