Merge 54bcc167e17e0edb26f9d32a1f35510fbf323fee into 5b6699a8354b70b23b252b371c612a35ed186ec2

LIyong.Guo 2021-12-23 18:07:27 +08:00 committed by GitHub
commit c59bdf651a
4 changed files with 115 additions and 4 deletions


@@ -24,6 +24,64 @@ from torch import Tensor, nn
from transformer import Supervisions, Transformer, encoder_padding_mask
class CodeIndicesNet(nn.Module):
"""Used to compute codebook indices and codebook loss."""
def __init__(
self,
d_model=512,
quantizer_dim=512,
num_codebooks=4,
):
"""
Args:
d_model:
The dimension of the memory embeddings (input).
quantizer_dim:
The dimension of the quantizer, i.e. the number of classes of the CE loss.
num_codebooks:
Number of codebooks used, i.e. the number of CE losses that are summed.
"""
super().__init__()
self.linear1 = nn.Linear(d_model, num_codebooks * quantizer_dim)
# Note: the default reduction is "mean"; "sum" is used here instead.
self.ce = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
self.num_codebooks = num_codebooks
self.quantizer_dim = quantizer_dim
def forward(self, memory):
"""
Args:
memory:
memory embeddings, with shape [N, T, C]
(loss() transposes its [T, N, C] input before calling this).
output:
shape [N, T, num_codebooks * quantizer_dim]
"""
x = self.linear1(memory)
return x
def loss(self, memory: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
memory:
memory embeddings, with shape [T, N, C]
target:
codebook indices, with shape [N, T, num_codebooks]
output:
codebook loss, i.e. the sum of num_codebooks CE losses
(frames padded with -100 are ignored).
"""
memory = memory.transpose(0, 1) # T, N, C --> N, T, C
x = self.forward(memory)
x = x.reshape(-1, self.quantizer_dim)
target = target.reshape(-1)
ret = self.ce(x, target)
return ret
class Conformer(Transformer):
"""
Args:
@@ -95,6 +153,8 @@ class Conformer(Transformer):
# and throws an error without this change.
self.after_norm = identity
self.cdidxnet = CodeIndicesNet()
def run_encoder(
self, x: Tensor, supervisions: Optional[Supervisions] = None
) -> Tuple[Tensor, Optional[Tensor]]:
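
A minimal usage sketch of CodeIndicesNet (illustrative shapes only; not part of the diff):

import torch

net = CodeIndicesNet(d_model=512, quantizer_dim=512, num_codebooks=4)
T, N, C = 100, 8, 512
memory = torch.randn(T, N, C)  # encoder memory, [T, N, C]
# Codebook indices, [N, T, num_codebooks]; -100 marks padded frames.
target = torch.randint(0, 512, (N, T, 4))
target[:, 90:, :] = -100
loss = net.loss(memory, target)  # scalar: sum of the per-codebook CE losses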


@@ -497,14 +497,22 @@ def save_results(
enable_log = True
test_set_wers = dict()
for key, results in results_dict.items():
recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
recog_path = (
params.exp_dir
/ f"epoch-{params.epoch}-avg-{params.avg}-recogs-{test_set_name}-{key}.txt"
)
store_transcripts(filename=recog_path, texts=results)
if enable_log:
logging.info(f"The transcripts are stored in {recog_path}")
# The following prints out WERs, per-word error statistics and aligned
# ref/hyp pairs.
errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
errs_filename = (
params.exp_dir
/ f"epoch-{params.epoch}-avg-{params.avg}-errs-{test_set_name}-{key}.txt"
)
with open(errs_filename, "w") as f:
wer = write_error_stats(
f, f"{test_set_name}-{key}", results, enable_log=enable_log
@@ -517,7 +525,11 @@ def save_results(
)
test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
errs_info = (
params.exp_dir
/ f"epoch-{params.epoch}-avg-{params.avg}-wer-summary-{test_set_name}.txt"
)
with open(errs_info, "w") as f:
print("settings\tWER", file=f)
for key, val in test_set_wers:


@@ -30,6 +30,7 @@ import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from conformer import Conformer
from lhotse.utils import fix_random_seed
from lhotse.dataset.collation import collate_custom_field
from torch import Tensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils import clip_grad_norm_
@@ -123,6 +124,16 @@ def get_parser():
""",
)
parser.add_argument(
"--codebook-weight",
type=float,
default=0.1,
help="""The weight of code book loss.
Note: Currently rate of ctc_loss + rate of att_loss = 1.0
codebook_weight is independent with previous two.
""",
)
parser.add_argument(
"--lr-factor",
type=float,
@@ -397,6 +408,27 @@ def compute_loss(
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
if params.codebook_weight != 0.0:
cuts = batch["supervisions"]["cut"]
# -100 matches the ignore_index used by the CE loss, so padded frames are skipped.
codebook_indices, codebook_indices_lens = collate_custom_field(
cuts, "codebook_indices", pad_value=-100
)
assert (
codebook_indices.shape[0] == encoder_memory.shape[1]
) # N: batch_size
assert (
codebook_indices.shape[1] == encoder_memory.shape[0]
) # T: num frames
codebook_indices = codebook_indices.to(encoder_memory.device).long()
codebook_loss = mmodel.cdidxnet.loss(
encoder_memory, target=codebook_indices
)
loss += params.codebook_weight * codebook_loss
else:
loss = ctc_loss
att_loss = torch.tensor([0])
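
A small sanity-check sketch (not part of the diff) of why pad_value=-100 is safe here: targets equal to the CE loss's ignore_index contribute nothing.

import torch
import torch.nn as nn

ce = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
logits = torch.randn(6, 512)  # 6 frames, 512 quantizer classes
target = torch.tensor([3, 7, 1, -100, -100, -100])  # last 3 frames are padding
assert torch.allclose(ce(logits, target), ce(logits[:3], target[:3]))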
@@ -409,6 +441,9 @@ def compute_loss(
if params.att_rate != 0.0:
info["att_loss"] = att_loss.detach().cpu().item()
if params.codebook_weight != 0.0:
info["codebook_loss"] = codebook_loss.detach().cpu().item()
info["loss"] = loss.detach().cpu().item()
return loss, info


@@ -126,7 +126,11 @@ def setup_logger(
level = logging.CRITICAL
logging.basicConfig(
filename=log_filename, format=formatter, level=level, filemode="w"
filename=log_filename,
format=formatter,
level=level,
filemode="w",
force=True,
)
if use_console:
console = logging.StreamHandler()