mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)
training with codebook loss
This commit is contained in:
parent 0e541f5b5d
commit a4722dd7c0
@@ -25,6 +25,64 @@ from torch import Tensor, nn
 from transformer import Supervisions, Transformer, encoder_padding_mask
 
 
+class CodeIndicesNet(nn.Module):
+    """Used to compute codebook indices and codebook loss."""
+
+    def __init__(
+        self,
+        d_model=512,
+        quantizer_dim=512,
+        num_codebooks=4,
+    ):
+        """
+        Args:
+          d_model:
+            The dimension of the memory embeddings (input).
+          quantizer_dim:
+            The dimension of the quantizer, i.e. the number of classes of each CE loss.
+          num_codebooks:
+            The number of codebooks used, i.e. the number of CE losses actually used.
+        """
+        super().__init__()
+        self.linear1 = nn.Linear(d_model, num_codebooks * quantizer_dim)
+        # Default reduction is 'mean'
+        self.ce = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
+        self.num_codebooks = num_codebooks
+        self.quantizer_dim = quantizer_dim
+
+    def forward(self, memory):
+        """
+        Args:
+          memory:
+            Memory embeddings, with shape [N, T, C]
+            (transposed from [T, N, C] in loss() before this call).
+        Returns:
+          A tensor of shape [N, T, num_codebooks * quantizer_dim].
+        """
+        x = self.linear1(memory)
+        return x
+
+    def loss(self, memory: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            Memory embeddings, with shape [T, N, C].
+          target:
+            Codebook indices, with shape [N, T, num_codebooks].
+        Returns:
+          The codebook loss, i.e. the sum of num_codebooks CE losses.
+        """
+        memory = memory.transpose(0, 1)  # T, N, C --> N, T, C
+        x = self.forward(memory)
+        x = x.reshape(-1, self.quantizer_dim)
+        target = target.reshape(-1)
+        ret = self.ce(x, target)
+        return ret
+
+
 class Conformer(Transformer):
     """
     Args:
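For orientation, a minimal usage sketch of CodeIndicesNet (not part of the commit; the sizes and tensors below are illustrative placeholders), showing the shape conventions from the docstrings above and the effect of the -100 padding index:

import torch

# Assumes CodeIndicesNet as defined in the diff above is in scope.
T, N, C, num_codebooks, quantizer_dim = 10, 2, 512, 4, 512
net = CodeIndicesNet(d_model=C, quantizer_dim=quantizer_dim, num_codebooks=num_codebooks)

memory = torch.randn(T, N, C)  # encoder memory, [T, N, C]
target = torch.randint(0, quantizer_dim, (N, T, num_codebooks))
target[:, -2:, :] = -100  # pretend the last two frames of each utterance are padding

loss = net.loss(memory, target)  # scalar: sum of num_codebooks CE losses over valid frames
print(loss.item())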
@@ -92,6 +150,8 @@ class Conformer(Transformer):
             # and throws an error without this change.
             self.after_norm = identity
 
+        self.cdidxnet = CodeIndicesNet()
+
     def run_encoder(
         self, x: Tensor, supervisions: Optional[Supervisions] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -499,14 +499,20 @@ def save_results(
     enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        recog_path = (
+            params.exp_dir
+            / f"epoch-{params.epoch}-avg-{params.avg}-recogs-{test_set_name}-{key}.txt"
+        )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
+        errs_filename = (
+            params.exp_dir
+            / f"epoch-{params.epoch}-avg-{params.avg}-errs-{test_set_name}-{key}.txt"
+        )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=enable_log
@@ -519,7 +525,10 @@ def save_results(
             )
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
+    errs_info = (
+        params.exp_dir
+        / f"epoch-{params.epoch}-avg-{params.avg}-wer-summary-{test_set_name}.txt"
+    )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
@@ -31,6 +31,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
+from lhotse.dataset.collation import collate_custom_field
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
@@ -124,6 +125,16 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--codebook-weight",
+        type=float,
+        default=0.1,
+        help="""The weight of the codebook loss.
+        Note: currently the rate of ctc_loss and the rate of att_loss sum to 1.0;
+        codebook_weight is independent of those two.
+        """,
+    )
+
     parser.add_argument(
         "--lr-factor",
         type=float,
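For reference, once the flag is non-zero the three loss terms combine as wired up in compute_loss in the next hunk; a toy calculation with placeholder numbers:

# Illustrative only: placeholder loss values, not real training numbers.
att_rate, codebook_weight = 0.7, 0.1
ctc_loss, att_loss, codebook_loss = 12.3, 8.9, 4.5
loss = (1.0 - att_rate) * ctc_loss + att_rate * att_loss + codebook_weight * codebook_loss
print(loss)  # 0.3 * 12.3 + 0.7 * 8.9 + 0.1 * 4.5, roughly 10.37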
@@ -394,6 +405,27 @@ def compute_loss(
             eos_id=graph_compiler.eos_id,
         )
         loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+
+        if params.codebook_weight != 0.0:
+
+            cuts = batch["supervisions"]["cut"]
+            # -100 matches the ignore_index used in the CE loss computation.
+            codebook_indices, codebook_indices_lens = collate_custom_field(
+                cuts, "codebook_indices", pad_value=-100
+            )
+
+            assert (
+                codebook_indices.shape[0] == encoder_memory.shape[1]
+            )  # N: batch_size
+            assert (
+                codebook_indices.shape[1] == encoder_memory.shape[0]
+            )  # T: num frames
+            codebook_indices = codebook_indices.to(encoder_memory.device).long()
+            codebook_loss = mmodel.cdidxnet.loss(
+                encoder_memory, target=codebook_indices
+            )
+
+            loss += params.codebook_weight * codebook_loss
     else:
         loss = ctc_loss
         att_loss = torch.tensor([0])
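The pad_value=-100 passed to collate_custom_field matches the ignore_index=-100 used in CodeIndicesNet's cross-entropy, so padded frames drop out of the summed loss. A small self-contained illustration of that behaviour (not part of the commit):

import torch
import torch.nn as nn

# With reduction="sum" and ignore_index=-100, targets equal to -100 (the padding
# value used above) contribute nothing to the loss.
ce = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
logits = torch.randn(6, 512)                         # 6 (frame, codebook) pairs, 512 classes
target = torch.tensor([3, 41, 7, -100, -100, -100])  # last three entries are padding
full = ce(logits, target)
first3 = ce(logits[:3], target[:3])
assert torch.allclose(full, first3)                  # padded entries were ignored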
@@ -406,6 +438,9 @@ def compute_loss(
     if params.att_rate != 0.0:
         info["att_loss"] = att_loss.detach().cpu().item()
 
+    if params.codebook_weight != 0.0:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()
+
     info["loss"] = loss.detach().cpu().item()
 
     return loss, info
@@ -126,7 +126,11 @@ def setup_logger(
         level = logging.CRITICAL
 
     logging.basicConfig(
-        filename=log_filename, format=formatter, level=level, filemode="w"
+        filename=log_filename,
+        format=formatter,
+        level=level,
+        filemode="w",
+        force=True,
     )
     if use_console:
         console = logging.StreamHandler()