add conformer exps

yaozengwei 2024-11-25 15:28:05 +08:00
parent b65873fb4c
commit 2fc53cd7ce
6 changed files with 1514 additions and 4 deletions

View File

@@ -44,6 +44,7 @@ from icefall.decode import (
nbest_oracle,
one_best_decoding,
rescore_with_attention_decoder,
rescore_with_attention_decoder_no_ngram_old,
rescore_with_n_best_list,
rescore_with_rnn_lm,
rescore_with_whole_lattice,
@@ -459,6 +460,27 @@ def decode_one_batch(
key = "ctc-greedy-search"
return {key: hyps}
if params.method == "attention-decoder-rescoring-no-ngram":
best_path_dict = rescore_with_attention_decoder_no_ngram_old(
lattice=lattice,
num_paths=params.num_paths,
model=model,
memory=memory,
memory_key_padding_mask=memory_key_padding_mask,
sos_id=sos_id,
eos_id=eos_id,
)
ans = dict()
for a_scale_str, best_path in best_path_dict.items():
# token_ids is a list of lists of token IDs
token_ids = get_texts(best_path)
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
hyps = bpe_model.decode(token_ids)
# hyps is a list of lists of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
hyps = [s.split() for s in hyps]
ans[a_scale_str] = hyps
return ans
if params.method == "nbest-oracle":
# Note: You can also pass rescored lattices to it.
# We choose the HLG decoded lattice for speed reasons
@@ -761,7 +783,7 @@ def main():
params.sos_id = sos_id
params.eos_id = eos_id
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search" or params.method == "attention-decoder-rescoring-no-ngram":
HLG = None
H = k2.ctc_topo(
max_token=max_token_id,

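For reference, a minimal sketch of the hypothesis post-processing convention in the new attention-decoder-rescoring-no-ngram branch above: each attention-scale key maps to a list of word lists. decode_ids() below is a made-up stand-in for bpe_model.decode(), and the token IDs are invented; in the real script they come from get_texts(best_path).

# Illustrative stand-in for bpe_model.decode(); real BPE decoding
# maps token IDs back to text with sentencepiece.
def decode_ids(token_ids_list):
    vocab = {1: "HELLO", 2: "WORLD", 3: "FOO"}
    return [" ".join(vocab[i] for i in ids) for ids in token_ids_list]

# One entry per attention scale, as returned by the rescoring function.
token_ids_per_scale = {"attention_scale_1.0": [[1, 2], [3]]}

ans = dict()
for a_scale_str, token_ids in token_ids_per_scale.items():
    hyps = decode_ids(token_ids)  # e.g., ['HELLO WORLD', 'FOO']
    ans[a_scale_str] = [s.split() for s in hyps]
print(ans)  # {'attention_scale_1.0': [['HELLO', 'WORLD'], ['FOO']]}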
View File

@@ -72,7 +72,7 @@ import sentencepiece as spm
import torch
import torch.nn as nn
from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_ctc_model, get_params
from train_cr_ctc import add_model_arguments, get_ctc_model, get_params
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
from icefall.checkpoint import (
@@ -458,8 +458,9 @@ def decode_one_batch(
else:
encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
nnet_output = model.get_ctc_output(encoder_out)
# nnet_output = model.get_ctc_output(encoder_out)
# nnet_output is (N, T, C)
nnet_output = model.ctc_output(encoder_out) # (N, T, C)
if params.decoding_method == "ctc-greedy-search":
timestamps, hyps = ctc_greedy_search(

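A minimal sketch of why this call site changed: in the CR-CTC model added below (the new model file), ctc_output is an nn.Sequential attribute rather than a get_ctc_output() method, so it is applied directly to the encoder output. nn.Linear stands in here for the ScaledLinear used in the real model, and the dimensions are made up.

import torch
import torch.nn as nn

encoder_dim, vocab_size = 4, 6
# Mirrors the ctc_output module defined in the new CTCModel.
ctc_output = nn.Sequential(
    nn.Dropout(p=0.1),
    nn.Linear(encoder_dim, vocab_size),  # ScaledLinear in the real model
    nn.LogSoftmax(dim=-1),
)
encoder_out = torch.randn(2, 10, encoder_dim)  # (N, T, encoder_dim)
nnet_output = ctc_output(encoder_out)  # (N, T, vocab_size), log-probs
print(nnet_output.shape)  # torch.Size([2, 10, 6])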
View File

@@ -0,0 +1,209 @@
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
# Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear
from icefall.utils import make_pad_mask, time_warp
from lhotse.dataset import SpecAugment
class CTCModel(nn.Module):
"""It implements https://www.cs.toronto.edu/~graves/icml_2006.pdf
"Connectionist Temporal Classification: Labelling Unsegmented
Sequence Data with Recurrent Neural Networks"
"""
def __init__(
self,
encoder: EncoderInterface,
encoder_dim: int,
vocab_size: int,
):
"""
Args:
encoder:
It is the transcription network in the paper. It accepts
two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
It returns two tensors: `logits` of shape (N, T, encoder_dim) and
`logit_lens` of shape (N,).
encoder_dim:
The feature embedding dimension.
vocab_size:
The vocabulary size.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
self.encoder = encoder
self.ctc_output = nn.Sequential(
nn.Dropout(p=0.1),
ScaledLinear(encoder_dim, vocab_size),
nn.LogSoftmax(dim=-1),
)
def forward_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> torch.Tensor:
"""Compute CTC loss.
Args:
encoder_out:
Encoder output, of shape (N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (N,).
targets:
Target Tensor of shape (sum(target_lengths),). The targets are assumed
to be un-padded and concatenated within 1 dimension.
target_lengths:
A 1-D Tensor of shape (N,) containing the number of target tokens
for each utterance.
"""
# Compute CTC log-prob
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
targets=targets.cpu(),
input_lengths=encoder_out_lens.cpu(),
target_lengths=target_lengths.cpu(),
reduction="sum",
)
return ctc_loss
def forward_cr_ctc(
self,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
targets: torch.Tensor,
target_lengths: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute CTC loss with consistency regularization loss.
Args:
encoder_out:
Encoder output, of shape (2 * N, T, C).
encoder_out_lens:
Encoder output lengths, of shape (2 * N,).
targets:
Target Tensor of shape (sum(target_lengths),). The targets are assumed
to be un-padded and concatenated within 1 dimension.
target_lengths:
A 1-D Tensor of shape (2 * N,) containing the number of target tokens
for each of the two copies of each utterance.
"""
# Compute CTC loss
ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C)
ctc_loss = torch.nn.functional.ctc_loss(
log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C)
targets=targets.cpu(),
input_lengths=encoder_out_lens.cpu(),
target_lengths=target_lengths.cpu(),
reduction="sum",
)
# Compute consistency regularization loss
exchanged_targets = ctc_output.detach().chunk(2, dim=0)
exchanged_targets = torch.cat(
[exchanged_targets[1], exchanged_targets[0]], dim=0
) # exchange: [x1, x2] -> [x2, x1]
cr_loss = nn.functional.kl_div(
input=ctc_output,
target=exchanged_targets,
reduction="none",
log_target=True,
) # (2 * N, T, C)
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
return ctc_loss, cr_loss
def forward(
self,
x: torch.Tensor,
x_lens: torch.Tensor,
y: k2.RaggedTensor,
warmup: float = 1.0,
use_cr_ctc: bool = False,
use_spec_aug: bool = False,
spec_augment: Optional[SpecAugment] = None,
supervision_segments: Optional[torch.Tensor] = None,
time_warp_factor: Optional[int] = 80,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
x:
A 3-D tensor of shape (N, T, C).
x_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
y:
A ragged tensor with two axes [utt][label]. It contains the labels
(i.e., token IDs) of each utterance.
warmup: a floating point value which increases throughout training;
values >= 1.0 are fully warmed up and have all modules present.
"""
assert x.ndim == 3, x.shape
assert x_lens.ndim == 1, x_lens.shape
assert y.num_axes == 2, y.num_axes
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
if use_cr_ctc:
if use_spec_aug:
assert spec_augment is not None and spec_augment.time_warp_factor < 1
# Apply time warping before input duplicating
assert supervision_segments is not None
x = time_warp(
x,
time_warp_factor=time_warp_factor,
supervision_segments=supervision_segments,
)
# Independently apply frequency masking and time masking to the two copies
x = spec_augment(x.repeat(2, 1, 1))
else:
x = x.repeat(2, 1, 1)
x_lens = x_lens.repeat(2)
y = k2.ragged.cat([y, y], axis=0)
# Compute encoder outputs
encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
assert torch.all(encoder_out_lens > 0)
row_splits = y.shape.row_splits(1)
y_lens = row_splits[1:] - row_splits[:-1]
# Compute CTC loss
targets = y.values
if not use_cr_ctc:
ctc_loss = self.forward_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
cr_loss = torch.empty(0)
else:
ctc_loss, cr_loss = self.forward_cr_ctc(
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
targets=targets,
target_lengths=y_lens,
)
ctc_loss = ctc_loss * 0.5
cr_loss = cr_loss * 0.5
return ctc_loss, cr_loss

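As a sanity check on the consistency-regularization term in forward_cr_ctc, here is a self-contained sketch with random log-probs standing in for real CTC outputs; the target exchange, KL call, and padding mask mirror the code above, and all shapes and values are dummies.

import torch
import torch.nn as nn

N, T, C = 2, 10, 5
ctc_output = torch.randn(2 * N, T, C).log_softmax(dim=-1)
encoder_out_lens = torch.tensor([10, 8, 10, 8])

# Exchange the two augmented copies so each view is regularized
# toward the other's detached distribution.
a, b = ctc_output.detach().chunk(2, dim=0)
exchanged_targets = torch.cat([b, a], dim=0)

cr_loss = nn.functional.kl_div(
    input=ctc_output,
    target=exchanged_targets,
    reduction="none",
    log_target=True,
)  # (2 * N, T, C)

# Equivalent of make_pad_mask(): True on padded frames.
length_mask = (
    torch.arange(T).unsqueeze(0) >= encoder_out_lens.unsqueeze(1)
).unsqueeze(-1)
print(cr_loss.masked_fill(length_mask, 0.0).sum())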
File diff suppressed because it is too large

View File

@@ -1083,6 +1083,185 @@ def rescore_with_attention_decoder(
return ans
def rescore_with_attention_decoder_no_ngram_old(
lattice: k2.Fsa,
num_paths: int,
model: torch.nn.Module,
memory: torch.Tensor,
memory_key_padding_mask: Optional[torch.Tensor],
sos_id: int,
eos_id: int,
attention_scale: Optional[float] = None,
use_double_scores: bool = True,
) -> Dict[str, k2.Fsa]:
"""This function extracts `num_paths` paths from the given lattice and uses
an attention decoder to rescore them. The path with the highest score is
the decoding output.
Args:
lattice:
An FsaVec with axes [utt][state][arc].
num_paths:
Number of paths to extract from the given lattice for rescoring.
model:
A transformer model. See the class "Transformer" in
conformer_ctc/transformer.py for its interface.
memory:
The encoder memory of the given model. It is the output of
the last torch.nn.TransformerEncoder layer in the given model.
Its shape is `(T, N, C)`.
memory_key_padding_mask:
The padding mask for memory with shape `(N, T)`.
sos_id:
The token ID for SOS.
eos_id:
The token ID for EOS.
attention_scale:
Optional. It specifies the scale for attention decoder scores.
If not given, a predefined list of scales is tried and a result
is returned for each of them.
use_double_scores:
True to use double precision when computing the random paths.
Returns:
A dict of FsaVec, whose key is a string of the form
"attention_scale_xxx" and whose value is the best decoding path
for each utterance in the lattice.
"""
# max_loop_count = 10
# loop_count = 0
# while loop_count <= max_loop_count:
# try:
# nbest = Nbest.from_lattice(
# lattice=lattice,
# num_paths=num_paths,
# use_double_scores=use_double_scores,
# nbest_scale=nbest_scale,
# )
# # nbest.fsa.scores are all 0s at this point
# nbest = nbest.intersect(lattice)
# break
# except RuntimeError as e:
# logging.info(f"Caught exception:\n{e}\n")
# logging.info(f"num_paths before decreasing: {num_paths}")
# num_paths = int(num_paths / 2)
# if loop_count >= max_loop_count or num_paths <= 0:
# logging.info("Return None as the resulting lattice is too large.")
# return None
# logging.info(
# "This OOM is not an error. You can ignore it. "
# "If your model does not converge well, or --max-duration "
# "is too large, or the input sound file is difficult to "
# "decode, you will meet this exception."
# )
# logging.info(f"num_paths after decreasing: {num_paths}")
# loop_count += 1
# # Now nbest.fsa has its scores set.
# # Also, nbest.fsa inherits the attributes from `lattice`.
# assert hasattr(nbest.fsa, "lm_scores")
# am_scores = nbest.compute_am_scores()
# ngram_lm_scores = nbest.compute_lm_scores()
# # The `tokens` attribute is set inside `compile_hlg.py`
# assert hasattr(nbest.fsa, "tokens")
# assert isinstance(nbest.fsa.tokens, torch.Tensor)
# path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
# # the shape of memory is (T, N, C), so we use axis=1 here
# expanded_memory = memory.index_select(1, path_to_utt_map)
# if memory_key_padding_mask is not None:
# # The shape of memory_key_padding_mask is (N, T), so we
# # use axis=0 here.
# expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
# 0, path_to_utt_map
# )
# else:
# expanded_memory_key_padding_mask = None
# # remove axis corresponding to states.
# tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
# tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
# tokens = tokens.remove_values_leq(0)
# token_ids = tokens.tolist()
# if len(token_ids) == 0:
# print("Warning: rescore_with_attention_decoder(): empty token-ids")
# return None
# path is a ragged tensor with dtype torch.int32.
# It has three axes [utt][path][arc_pos]
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
# Note that labels, aux_labels and scores contain 0s and -1s.
# The last entry in each sublist is -1.
# The axes are [path][token_id]
labels = k2.ragged.index(lattice.labels.contiguous(), path).remove_axis(0)
aux_labels = k2.ragged.index(lattice.aux_labels.contiguous(), path).remove_axis(0)
scores = k2.ragged.index(lattice.scores.contiguous(), path).remove_axis(0)
# Remove -1 from labels as we will use it to construct a linear FSA
labels = labels.remove_values_eq(-1)
fsa = k2.linear_fsa(labels)
fsa.aux_labels = aux_labels.values
# utt_to_path_shape has axes [utt][path]
utt_to_path_shape = path.shape.get_layer(0)
scores = k2.RaggedTensor(utt_to_path_shape, scores.sum())
path_to_utt_map = utt_to_path_shape.row_ids(1).to(torch.long)
# The commented-out lines below would apply to an encoder_out of
# shape (N, T, C), where the batch is axis 0:
# expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map)
# expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map)
# The shape of memory is (T, N, C), so we use axis=1 here.
expanded_memory = memory.index_select(1, path_to_utt_map)
if memory_key_padding_mask is not None:
# The shape of memory_key_padding_mask is (N, T), so we
# use axis=0 here.
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
0, path_to_utt_map
)
else:
expanded_memory_key_padding_mask = None
token_ids = aux_labels.remove_values_leq(0).tolist()
nll = model.decoder_nll(
memory=expanded_memory,
memory_key_padding_mask=expanded_memory_key_padding_mask,
token_ids=token_ids,
sos_id=sos_id,
eos_id=eos_id,
)
assert nll.ndim == 2
assert nll.shape[0] == len(token_ids)
attention_scores = -nll.sum(dim=1)
if attention_scale is None:
attention_scale_list = [0.01, 0.05, 0.08]
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
else:
attention_scale_list = [attention_scale]
ans = dict()
for a_scale in attention_scale_list:
tot_scores = scores.values + a_scale * attention_scores
ragged_tot_scores = k2.RaggedTensor(utt_to_path_shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(fsa, max_indexes)
key = f"attention_scale_{a_scale}"
ans[key] = best_path
return ans
def rescore_with_attention_decoder_with_ngram(
lattice: k2.Fsa,
num_paths: int,

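To make the scale sweep at the end of rescore_with_attention_decoder_no_ngram_old concrete, here is a toy sketch (assuming k2 is installed): per-path lattice scores are combined with scaled attention-decoder scores, and a ragged argmax picks the best path within each utterance. The ragged shape and all score values are invented for illustration.

import k2
import torch

# Five paths grouped into two utterances (3 + 2), like utt_to_path_shape.
utt_to_path_shape = k2.RaggedShape("[[x x x] [x x]]")
lattice_scores = torch.tensor([1.0, 2.0, 0.5, 3.0, 2.5])
attention_scores = torch.tensor([0.2, -0.1, 0.9, 0.0, 0.4])

for a_scale in [0.5, 1.0, 2.0]:
    tot_scores = lattice_scores + a_scale * attention_scores
    ragged_tot_scores = k2.RaggedTensor(utt_to_path_shape, tot_scores)
    # argmax() returns, per utterance, the linear index of its best path;
    # k2.index_fsa(fsa, max_indexes) would then select those paths.
    max_indexes = ragged_tot_scores.argmax()
    print(a_scale, max_indexes.tolist())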
View File

@@ -983,7 +983,8 @@ def write_error_stats_with_timestamps(
hyp_count = corr + hyp_sub + ins
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
return float(tot_err_rate), float(mean_delay), float(var_delay)
# return float(tot_err_rate), float(mean_delay), float(var_delay)
return float(tot_err_rate), mean_delay, var_delay
def write_surt_error_stats(