Mirror of https://github.com/k2-fsa/icefall.git

Merge 54bcc167e17e0edb26f9d32a1f35510fbf323fee into 5b6699a8354b70b23b252b371c612a35ed186ec2

Commit c59bdf651a
@@ -24,6 +24,64 @@ from torch import Tensor, nn
 from transformer import Supervisions, Transformer, encoder_padding_mask
 
 
+class CodeIndicesNet(nn.Module):
+    """Used to compute codebook indices and codebook loss."""
+
+    def __init__(
+        self,
+        d_model=512,
+        quantizer_dim=512,
+        num_codebooks=4,
+    ):
+        """
+        Args:
+          d_model:
+            The dimension of memory embeddings (input).
+          quantizer_dim:
+            The dimension of the quantizer, i.e. the number of classes of each CE loss.
+          num_codebooks:
+            Number of codebooks used, i.e. number of CE losses actually used.
+        """
+
+        super().__init__()
+        self.linear1 = nn.Linear(d_model, num_codebooks * quantizer_dim)
+        # Default reduction is 'mean'; 'sum' is used here instead.
+        self.ce = nn.CrossEntropyLoss(ignore_index=-100, reduction="sum")
+        self.num_codebooks = num_codebooks
+        self.quantizer_dim = quantizer_dim
+
+    def forward(self, memory):
+        """
+        Args:
+          memory:
+            memory embeddings, with shape [N, T, C]
+        output:
+            shape [N, T, num_codebooks * quantizer_dim]
+        """
+        x = self.linear1(memory)
+        return x
+
+    def loss(self, memory: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+          memory:
+            memory embeddings, with shape [T, N, C]
+          target:
+            codebook indices, with shape [N, T, num_codebooks]
+
+        output:
+            codebook loss;
+            actually it is the sum of the num_codebooks CE losses
+        """
+
+        memory = memory.transpose(0, 1)  # [T, N, C] -> [N, T, C]
+        x = self.forward(memory)
+        x = x.reshape(-1, self.quantizer_dim)
+        target = target.reshape(-1)
+        ret = self.ce(x, target)
+        return ret
+
+
 class Conformer(Transformer):
     """
     Args:
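
For orientation, here is a minimal shape-check sketch of the CodeIndicesNet class added above. The batch size, frame count, and random targets are made up for illustration, and it assumes the class is importable from the conformer module edited in this hunk:

import torch

from conformer import CodeIndicesNet  # the class added above (assumed importable)

# Hypothetical sizes: 8 utterances, 100 encoder frames.
N, T = 8, 100
net = CodeIndicesNet(d_model=512, quantizer_dim=512, num_codebooks=4)

memory = torch.randn(T, N, 512)            # encoder memory, [T, N, C]
target = torch.randint(0, 512, (N, T, 4))  # codebook indices, [N, T, num_codebooks]

logits = net(memory.transpose(0, 1))       # [N, T, num_codebooks * quantizer_dim]
loss = net.loss(memory, target)            # scalar: sum of 4 CE losses
print(logits.shape, loss.item())           # torch.Size([8, 100, 2048]) and some float
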
@@ -95,6 +153,8 @@ class Conformer(Transformer):
         # and throws an error without this change.
         self.after_norm = identity
 
+        self.cdidxnet = CodeIndicesNet()
+
     def run_encoder(
         self, x: Tensor, supervisions: Optional[Supervisions] = None
     ) -> Tuple[Tensor, Optional[Tensor]]:
@@ -497,14 +497,22 @@ def save_results(
     enable_log = True
     test_set_wers = dict()
     for key, results in results_dict.items():
-        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
+        recog_path = (
+            params.exp_dir
+            / f"epoch-{params.epoch}-avg-{params.avg}-"
+            f"recogs-{test_set_name}-{key}.txt"
+        )
         store_transcripts(filename=recog_path, texts=results)
         if enable_log:
             logging.info(f"The transcripts are stored in {recog_path}")
 
         # The following prints out WERs, per-word error statistics and aligned
         # ref/hyp pairs.
-        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
+        errs_filename = (
+            params.exp_dir
+            / f"epoch-{params.epoch}-avg-{params.avg}-"
+            f"errs-{test_set_name}-{key}.txt"
+        )
         with open(errs_filename, "w") as f:
             wer = write_error_stats(
                 f, f"{test_set_name}-{key}", results, enable_log=enable_log
@@ -517,7 +525,11 @@ def save_results(
             )
 
     test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
-    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
+    errs_info = (
+        params.exp_dir
+        / f"epoch-{params.epoch}-avg-{params.avg}-"
+        f"wer-summary-{test_set_name}.txt"
+    )
     with open(errs_info, "w") as f:
         print("settings\tWER", file=f)
         for key, val in test_set_wers:
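
As a quick illustration of the new naming scheme (the exp dir, epoch, avg, test set, and key below are hypothetical values, not taken from this PR), note that the two adjacent f-strings are concatenated into one file name before the path join:

from pathlib import Path

exp_dir, epoch, avg = Path("conformer_ctc/exp"), 20, 10  # hypothetical
test_set_name, key = "test-clean", "ctc-decoding"        # hypothetical

recog_path = (
    exp_dir
    / f"epoch-{epoch}-avg-{avg}-"
    f"recogs-{test_set_name}-{key}.txt"
)
print(recog_path)  # conformer_ctc/exp/epoch-20-avg-10-recogs-test-clean-ctc-decoding.txt
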
@@ -30,6 +30,7 @@ import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
 from conformer import Conformer
 from lhotse.utils import fix_random_seed
+from lhotse.dataset.collation import collate_custom_field
 from torch import Tensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.nn.utils import clip_grad_norm_
@@ -123,6 +124,16 @@ def get_parser():
         """,
     )
 
+    parser.add_argument(
+        "--codebook-weight",
+        type=float,
+        default=0.1,
+        help="""The weight of the codebook loss.
+        Note: currently rate of ctc_loss + rate of att_loss = 1.0;
+        codebook_weight is independent of those two.
+        """,
+    )
+
     parser.add_argument(
         "--lr-factor",
         type=float,
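
Spelled out with the names used in compute_loss (see the hunk below), the total objective when both att_rate and codebook_weight are non-zero is:

    loss = (1.0 - att_rate) * ctc_loss + att_rate * att_loss + codebook_weight * codebook_loss

Setting --codebook-weight 0.0 skips the codebook branch entirely.
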
@@ -397,6 +408,27 @@ def compute_loss(
                 eos_id=graph_compiler.eos_id,
             )
         loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
+
+        if params.codebook_weight != 0.0:
+
+            cuts = batch["supervisions"]["cut"]
+            # -100 matches ignore_index in the CE loss computation.
+            codebook_indices, codebook_indices_lens = collate_custom_field(
+                cuts, "codebook_indices", pad_value=-100
+            )
+
+            assert (
+                codebook_indices.shape[0] == encoder_memory.shape[1]
+            )  # N: batch size
+            assert (
+                codebook_indices.shape[1] == encoder_memory.shape[0]
+            )  # T: num frames
+            codebook_indices = codebook_indices.to(encoder_memory.device).long()
+            codebook_loss = mmodel.cdidxnet.loss(
+                encoder_memory, target=codebook_indices
+            )
+
+            loss += params.codebook_weight * codebook_loss
     else:
         loss = ctc_loss
         att_loss = torch.tensor([0])
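
For context, collate_custom_field pads each cut's codebook_indices array to a common frame count. Below is a small self-contained sketch of the shape and padding contract assumed by the asserts above; the frame counts and values are made up, and lhotse itself is not used here:

import torch

num_codebooks, quantizer_dim = 4, 512
frames = [7, 5]                  # hypothetical per-utterance frame counts
N, T = len(frames), max(frames)  # batch size, padded frame count

# Padded frames get -100, which CrossEntropyLoss(ignore_index=-100) skips.
codebook_indices = torch.full((N, T, num_codebooks), -100, dtype=torch.long)
for i, t in enumerate(frames):
    codebook_indices[i, :t] = torch.randint(0, quantizer_dim, (t, num_codebooks))

encoder_memory = torch.randn(T, N, 512)  # [T, N, C], as produced by the encoder
assert codebook_indices.shape[0] == encoder_memory.shape[1]  # N: batch size
assert codebook_indices.shape[1] == encoder_memory.shape[0]  # T: num frames
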
@@ -409,6 +441,9 @@ def compute_loss(
     if params.att_rate != 0.0:
         info["att_loss"] = att_loss.detach().cpu().item()
 
+    if params.codebook_weight != 0.0:
+        info["codebook_loss"] = codebook_loss.detach().cpu().item()
+
     info["loss"] = loss.detach().cpu().item()
 
     return loss, info
@@ -126,7 +126,11 @@ def setup_logger(
         level = logging.CRITICAL
 
     logging.basicConfig(
-        filename=log_filename, format=formatter, level=level, filemode="w"
+        filename=log_filename,
+        format=formatter,
+        level=level,
+        filemode="w",
+        force=True,
     )
     if use_console:
         console = logging.StreamHandler()
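
The functional change here is force=True (Python 3.8+), which removes and closes any handlers already attached to the root logger before reconfiguring, so a later setup_logger call is not silently ignored. A standalone sketch of that behavior with illustrative file names:

import logging

logging.basicConfig(filename="first.log", level=logging.INFO)
# Without force=True the second call would be a no-op, because the root
# logger already has a handler; force=True replaces the existing handlers.
logging.basicConfig(filename="second.log", level=logging.INFO, force=True)
logging.info("this record goes to second.log")
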