training with codebook loss

This commit is contained in:
Guo Liyong 2021-12-02 17:16:48 +08:00
parent 0e541f5b5d
commit a4722dd7c0
4 changed files with 112 additions and 4 deletions

View File

@@ -25,6 +25,64 @@ from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
class CodeIndicesNet(nn.Module):
    """Used to compute codebook indices and codebook loss."""

    def __init__(
        self,
        d_model=512,
        quantizer_dim=512,
        num_codebooks=4,
    ):
        """
        Args:
          d_model:
            The dimension of the memory embeddings (input).
          quantizer_dim:
            The dimension of the quantizer, i.e. the number of classes
            of each CE loss.
          num_codebooks:
            Number of codebooks used, i.e. the number of CE losses
            actually used.
        """
        super().__init__()
        self.linear1 = nn.Linear(d_model, num_codebooks * quantizer_dim)
        # The default reduction is "mean"; "sum" is used here so the result
        # is the sum of the per-codebook CE losses.
        self.ce = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
        self.num_codebooks = num_codebooks
        self.quantizer_dim = quantizer_dim

    def forward(self, memory):
        """
        Args:
          memory:
            Memory embeddings, of shape [N, T, C]
            (already transposed from [T, N, C] by `loss`).
        Returns:
          A tensor of shape [N, T, num_codebooks * quantizer_dim].
        """
        x = self.linear1(memory)
        return x

    def loss(self, memory: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Args:
          memory:
            Memory embeddings, of shape [T, N, C].
          target:
            Codebook indices, of shape [N, T, num_codebooks].
        Returns:
          The codebook loss, i.e. the sum of num_codebooks CE losses.
        """
        memory = memory.transpose(0, 1)  # [T, N, C] -> [N, T, C]
        x = self.forward(memory)
        x = x.reshape(-1, self.quantizer_dim)
        target = target.reshape(-1)
        ret = self.ce(x, target)
        return ret
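# A minimal usage sketch of CodeIndicesNet; the tensor sizes below are
# illustrative assumptions, not values taken from this recipe:
#
#   cdidxnet = CodeIndicesNet(d_model=512, quantizer_dim=512, num_codebooks=4)
#   memory = torch.rand(100, 8, 512)             # [T, N, C] encoder memory
#   target = torch.randint(0, 512, (8, 100, 4))  # [N, T, num_codebooks] indices
#   codebook_loss = cdidxnet.loss(memory, target=target)  # sum of 4 CE losses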
class Conformer(Transformer):
    """
    Args:
@@ -92,6 +150,8 @@ class Conformer(Transformer):
            # and throws an error without this change.
            self.after_norm = identity

        self.cdidxnet = CodeIndicesNet()

    def run_encoder(
        self, x: Tensor, supervisions: Optional[Supervisions] = None
    ) -> Tuple[Tensor, Optional[Tensor]]:

View File

@@ -499,14 +499,20 @@ def save_results(
    enable_log = True
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
        recog_path = (
            params.exp_dir
            / f"epoch-{params.epoch}-avg-{params.avg}-recogs-{test_set_name}-{key}.txt"
        )
        store_transcripts(filename=recog_path, texts=results)
        if enable_log:
            logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
        errs_filename = (
            params.exp_dir
            / f"epoch-{params.epoch}-avg-{params.avg}-errs-{test_set_name}-{key}.txt"
        )
        with open(errs_filename, "w") as f:
            wer = write_error_stats(
                f, f"{test_set_name}-{key}", results, enable_log=enable_log
@@ -519,7 +525,10 @@ def save_results(
            )

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
    errs_info = (
        params.exp_dir
        / f"epoch-{params.epoch}-avg-{params.avg}-wer-summary-{test_set_name}.txt"
    )
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
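# With these changes the result files are keyed by checkpoint; e.g. with the
# hypothetical options --epoch 20 --avg 5 and test set "test-clean" the files
# become epoch-20-avg-5-recogs-test-clean-<key>.txt,
# epoch-20-avg-5-errs-test-clean-<key>.txt and
# epoch-20-avg-5-wer-summary-test-clean.txt, so decoding runs with different
# epoch/avg settings no longer overwrite each other.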

View File

@@ -31,6 +31,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from lhotse.dataset.collation import collate_custom_field
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
@@ -124,6 +125,16 @@ def get_parser():
        """,
    )
    parser.add_argument(
        "--codebook-weight",
        type=float,
        default=0.1,
        help="""The weight of the codebook loss.
        Note: currently the weights of ctc_loss and att_loss sum to 1.0;
        codebook_weight is independent of those two.
        """,
    )
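    # How the weights combine (mirrors the compute_loss change later in this
    # commit; variable names as in that function):
    #
    #   loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
    #   if params.codebook_weight != 0.0:
    #       loss += params.codebook_weight * codebook_loss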
    parser.add_argument(
        "--lr-factor",
        type=float,
@@ -394,6 +405,27 @@ def compute_loss(
                eos_id=graph_compiler.eos_id,
            )
        loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
        if params.codebook_weight != 0.0:
            cuts = batch["supervisions"]["cut"]
            # -100 matches the ignore_index used in the CE loss.
            codebook_indices, codebook_indices_lens = collate_custom_field(
                cuts, "codebook_indices", pad_value=-100
            )
            assert (
                codebook_indices.shape[0] == encoder_memory.shape[1]
            )  # N: batch_size
            assert (
                codebook_indices.shape[1] == encoder_memory.shape[0]
            )  # T: num frames
            codebook_indices = codebook_indices.to(encoder_memory.device).long()
            codebook_loss = mmodel.cdidxnet.loss(
                encoder_memory, target=codebook_indices
            )
            loss += params.codebook_weight * codebook_loss
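            # Expected shapes at this point (N = batch size, T = number of
            # encoder frames): codebook_indices is [N, T, num_codebooks],
            # padded along T with -100 so padded positions are ignored by the
            # CE loss; encoder_memory is [T, N, C], which is what the asserts
            # above check.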
    else:
        loss = ctc_loss
        att_loss = torch.tensor([0])
@@ -406,6 +438,9 @@ def compute_loss(
    if params.att_rate != 0.0:
        info["att_loss"] = att_loss.detach().cpu().item()

    if params.codebook_weight != 0.0:
        info["codebook_loss"] = codebook_loss.detach().cpu().item()

    info["loss"] = loss.detach().cpu().item()

    return loss, info

View File

@@ -126,7 +126,11 @@ def setup_logger(
        level = logging.CRITICAL

    logging.basicConfig(
        filename=log_filename, format=formatter, level=level, filemode="w"
        filename=log_filename,
        format=formatter,
        level=level,
        filemode="w",
        force=True,
    )
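    # force=True (available since Python 3.8) removes any handlers already
    # attached to the root logger, so a repeated call to setup_logger()
    # reconfigures logging instead of being silently ignored by basicConfig.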
    if use_console:
        console = logging.StreamHandler()