mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
add conformer exps
This commit is contained in:
parent
b65873fb4c
commit
2fc53cd7ce
@ -44,6 +44,7 @@ from icefall.decode import (
|
||||
nbest_oracle,
|
||||
one_best_decoding,
|
||||
rescore_with_attention_decoder,
|
||||
rescore_with_attention_decoder_no_ngram_old,
|
||||
rescore_with_n_best_list,
|
||||
rescore_with_rnn_lm,
|
||||
rescore_with_whole_lattice,
|
||||
@ -459,6 +460,27 @@ def decode_one_batch(
|
||||
key = "ctc-greedy-search"
|
||||
return {key: hyps}
|
||||
|
||||
if params.method == "attention-decoder-rescoring-no-ngram":
|
||||
best_path_dict = rescore_with_attention_decoder_no_ngram_old(
|
||||
lattice=lattice,
|
||||
num_paths=params.num_paths,
|
||||
model=model,
|
||||
memory=memory,
|
||||
memory_key_padding_mask=memory_key_padding_mask,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
ans = dict()
|
||||
for a_scale_str, best_path in best_path_dict.items():
|
||||
# token_ids is a lit-of-list of IDs
|
||||
token_ids = get_texts(best_path)
|
||||
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||
hyps = bpe_model.decode(token_ids)
|
||||
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||
hyps = [s.split() for s in hyps]
|
||||
ans[a_scale_str] = hyps
|
||||
return ans
|
||||
|
||||
if params.method == "nbest-oracle":
|
||||
# Note: You can also pass rescored lattices to it.
|
||||
# We choose the HLG decoded lattice for speed reasons
|
||||
@ -761,7 +783,7 @@ def main():
|
||||
params.sos_id = sos_id
|
||||
params.eos_id = eos_id
|
||||
|
||||
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
|
||||
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search" or params.method == "attention-decoder-rescoring-no-ngram":
|
||||
HLG = None
|
||||
H = k2.ctc_topo(
|
||||
max_token=max_token_id,
|
||||
|
@ -72,7 +72,7 @@ import sentencepiece as spm
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from asr_datamodule import LibriSpeechAsrDataModule
|
||||
from train import add_model_arguments, get_ctc_model, get_params
|
||||
from train_cr_ctc import add_model_arguments, get_ctc_model, get_params
|
||||
|
||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||
from icefall.checkpoint import (
|
||||
@ -458,8 +458,9 @@ def decode_one_batch(
|
||||
else:
|
||||
encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
|
||||
|
||||
nnet_output = model.get_ctc_output(encoder_out)
|
||||
# nnet_output = model.get_ctc_output(encoder_out)
|
||||
# nnet_output is (N, T, C)
|
||||
nnet_output = model.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
if params.decoding_method == "ctc-greedy-search":
|
||||
timestamps, hyps = ctc_greedy_search(
|
||||
|
209
egs/librispeech/ASR/conformer_ctc3/model_cr_ctc.py
Normal file
209
egs/librispeech/ASR/conformer_ctc3/model_cr_ctc.py
Normal file
@ -0,0 +1,209 @@
|
||||
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||
# Wei Kang,
|
||||
# Zengwei Yao)
|
||||
#
|
||||
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import k2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from encoder_interface import EncoderInterface
|
||||
from scaling import ScaledLinear
|
||||
from icefall.utils import make_pad_mask, time_warp
|
||||
from lhotse.dataset import SpecAugment
|
||||
|
||||
|
||||
class CTCModel(nn.Module):
|
||||
"""It implements https://www.cs.toronto.edu/~graves/icml_2006.pdf
|
||||
"Connectionist Temporal Classification: Labelling Unsegmented
|
||||
Sequence Data with Recurrent Neural Networks"
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder: EncoderInterface,
|
||||
encoder_dim: int,
|
||||
vocab_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
encoder:
|
||||
It is the transcription network in the paper. Its accepts
|
||||
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||
`logit_lens` of shape (N,).
|
||||
encoder_dim:
|
||||
The feature embedding dimension.
|
||||
vocab_size:
|
||||
The vocabulary size.
|
||||
"""
|
||||
super().__init__()
|
||||
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||
|
||||
self.encoder = encoder
|
||||
self.ctc_output = nn.Sequential(
|
||||
nn.Dropout(p=0.1),
|
||||
ScaledLinear(encoder_dim, vocab_size),
|
||||
nn.LogSoftmax(dim=-1),
|
||||
)
|
||||
|
||||
def forward_ctc(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
target_lengths: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute CTC loss.
|
||||
Args:
|
||||
encoder_out:
|
||||
Encoder output, of shape (N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (N,).
|
||||
targets:
|
||||
Target Tensor of shape (sum(target_lengths)). The targets are assumed
|
||||
to be un-padded and concatenated within 1 dimension.
|
||||
"""
|
||||
# Compute CTC log-prob
|
||||
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
|
||||
|
||||
ctc_loss = torch.nn.functional.ctc_loss(
|
||||
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||
targets=targets.cpu(),
|
||||
input_lengths=encoder_out_lens.cpu(),
|
||||
target_lengths=target_lengths.cpu(),
|
||||
reduction="sum",
|
||||
)
|
||||
return ctc_loss
|
||||
|
||||
def forward_cr_ctc(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
targets: torch.Tensor,
|
||||
target_lengths: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute CTC loss with consistency regularization loss.
|
||||
Args:
|
||||
encoder_out:
|
||||
Encoder output, of shape (2 * N, T, C).
|
||||
encoder_out_lens:
|
||||
Encoder output lengths, of shape (2 * N,).
|
||||
targets:
|
||||
Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
|
||||
to be un-padded and concatenated within 1 dimension.
|
||||
"""
|
||||
# Compute CTC loss
|
||||
ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C)
|
||||
ctc_loss = torch.nn.functional.ctc_loss(
|
||||
log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C)
|
||||
targets=targets.cpu(),
|
||||
input_lengths=encoder_out_lens.cpu(),
|
||||
target_lengths=target_lengths.cpu(),
|
||||
reduction="sum",
|
||||
)
|
||||
|
||||
# Compute consistency regularization loss
|
||||
exchanged_targets = ctc_output.detach().chunk(2, dim=0)
|
||||
exchanged_targets = torch.cat(
|
||||
[exchanged_targets[1], exchanged_targets[0]], dim=0
|
||||
) # exchange: [x1, x2] -> [x2, x1]
|
||||
cr_loss = nn.functional.kl_div(
|
||||
input=ctc_output,
|
||||
target=exchanged_targets,
|
||||
reduction="none",
|
||||
log_target=True,
|
||||
) # (2 * N, T, C)
|
||||
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
|
||||
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
|
||||
|
||||
return ctc_loss, cr_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_lens: torch.Tensor,
|
||||
y: k2.RaggedTensor,
|
||||
warmup: float = 1.0,
|
||||
use_cr_ctc: bool = False,
|
||||
use_spec_aug: bool = False,
|
||||
spec_augment: Optional[SpecAugment] = None,
|
||||
supervision_segments: Optional[torch.Tensor] = None,
|
||||
time_warp_factor: Optional[int] = 80,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args:
|
||||
x:
|
||||
A 3-D tensor of shape (N, T, C).
|
||||
x_lens:
|
||||
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||
before padding.
|
||||
warmup: a floating point value which increases throughout training;
|
||||
values >= 1.0 are fully warmed up and have all modules present.
|
||||
"""
|
||||
assert x.ndim == 3, x.shape
|
||||
assert x_lens.ndim == 1, x_lens.shape
|
||||
assert y.num_axes == 2, y.num_axes
|
||||
|
||||
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
|
||||
|
||||
if use_cr_ctc:
|
||||
if use_spec_aug:
|
||||
assert spec_augment is not None and spec_augment.time_warp_factor < 1
|
||||
# Apply time warping before input duplicating
|
||||
assert supervision_segments is not None
|
||||
x = time_warp(
|
||||
x,
|
||||
time_warp_factor=time_warp_factor,
|
||||
supervision_segments=supervision_segments,
|
||||
)
|
||||
# Independently apply frequency masking and time masking to the two copies
|
||||
x = spec_augment(x.repeat(2, 1, 1))
|
||||
else:
|
||||
x = x.repeat(2, 1, 1)
|
||||
x_lens = x_lens.repeat(2)
|
||||
y = k2.ragged.cat([y, y], axis=0)
|
||||
|
||||
# Compute encoder outputs
|
||||
encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
|
||||
assert torch.all(encoder_out_lens > 0)
|
||||
|
||||
row_splits = y.shape.row_splits(1)
|
||||
y_lens = row_splits[1:] - row_splits[:-1]
|
||||
|
||||
# Compute CTC loss
|
||||
targets = y.values
|
||||
if not use_cr_ctc:
|
||||
ctc_loss = self.forward_ctc(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
targets=targets,
|
||||
target_lengths=y_lens,
|
||||
)
|
||||
cr_loss = torch.empty(0)
|
||||
else:
|
||||
ctc_loss, cr_loss = self.forward_cr_ctc(
|
||||
encoder_out=encoder_out,
|
||||
encoder_out_lens=encoder_out_lens,
|
||||
targets=targets,
|
||||
target_lengths=y_lens,
|
||||
)
|
||||
ctc_loss = ctc_loss * 0.5
|
||||
cr_loss = cr_loss * 0.5
|
||||
|
||||
return ctc_loss, cr_loss
|
1098
egs/librispeech/ASR/conformer_ctc3/train_cr_ctc.py
Executable file
1098
egs/librispeech/ASR/conformer_ctc3/train_cr_ctc.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -1083,6 +1083,185 @@ def rescore_with_attention_decoder(
|
||||
return ans
|
||||
|
||||
|
||||
def rescore_with_attention_decoder_no_ngram_old(
|
||||
lattice: k2.Fsa,
|
||||
num_paths: int,
|
||||
model: torch.nn.Module,
|
||||
memory: torch.Tensor,
|
||||
memory_key_padding_mask: Optional[torch.Tensor],
|
||||
sos_id: int,
|
||||
eos_id: int,
|
||||
attention_scale: Optional[float] = None,
|
||||
use_double_scores: bool = True,
|
||||
) -> Dict[str, k2.Fsa]:
|
||||
"""This function extracts `num_paths` paths from the given lattice and uses
|
||||
an attention decoder to rescore them. The path with the highest score is
|
||||
the decoding output.
|
||||
|
||||
Args:
|
||||
lattice:
|
||||
An FsaVec with axes [utt][state][arc].
|
||||
num_paths:
|
||||
Number of paths to extract from the given lattice for rescoring.
|
||||
model:
|
||||
A transformer model. See the class "Transformer" in
|
||||
conformer_ctc/transformer.py for its interface.
|
||||
memory:
|
||||
The encoder memory of the given model. It is the output of
|
||||
the last torch.nn.TransformerEncoder layer in the given model.
|
||||
Its shape is `(T, N, C)`.
|
||||
memory_key_padding_mask:
|
||||
The padding mask for memory with shape `(N, T)`.
|
||||
sos_id:
|
||||
The token ID for SOS.
|
||||
eos_id:
|
||||
The token ID for EOS.
|
||||
nbest_scale:
|
||||
It's the scale applied to `lattice.scores`. A smaller value
|
||||
leads to more unique paths at the risk of missing the correct path.
|
||||
ngram_lm_scale:
|
||||
Optional. It specifies the scale for n-gram LM scores.
|
||||
attention_scale:
|
||||
Optional. It specifies the scale for attention decoder scores.
|
||||
Returns:
|
||||
A dict of FsaVec, whose key contains a string
|
||||
ngram_lm_scale_attention_scale and the value is the
|
||||
best decoding path for each utterance in the lattice.
|
||||
"""
|
||||
# max_loop_count = 10
|
||||
# loop_count = 0
|
||||
# while loop_count <= max_loop_count:
|
||||
# try:
|
||||
# nbest = Nbest.from_lattice(
|
||||
# lattice=lattice,
|
||||
# num_paths=num_paths,
|
||||
# use_double_scores=use_double_scores,
|
||||
# nbest_scale=nbest_scale,
|
||||
# )
|
||||
# # nbest.fsa.scores are all 0s at this point
|
||||
# nbest = nbest.intersect(lattice)
|
||||
# break
|
||||
# except RuntimeError as e:
|
||||
# logging.info(f"Caught exception:\n{e}\n")
|
||||
# logging.info(f"num_paths before decreasing: {num_paths}")
|
||||
# num_paths = int(num_paths / 2)
|
||||
# if loop_count >= max_loop_count or num_paths <= 0:
|
||||
# logging.info("Return None as the resulting lattice is too large.")
|
||||
# return None
|
||||
# logging.info(
|
||||
# "This OOM is not an error. You can ignore it. "
|
||||
# "If your model does not converge well, or --max-duration "
|
||||
# "is too large, or the input sound file is difficult to "
|
||||
# "decode, you will meet this exception."
|
||||
# )
|
||||
# logging.info(f"num_paths after decreasing: {num_paths}")
|
||||
# loop_count += 1
|
||||
|
||||
# # Now nbest.fsa has its scores set.
|
||||
# # Also, nbest.fsa inherits the attributes from `lattice`.
|
||||
# assert hasattr(nbest.fsa, "lm_scores")
|
||||
|
||||
# am_scores = nbest.compute_am_scores()
|
||||
# ngram_lm_scores = nbest.compute_lm_scores()
|
||||
|
||||
# # The `tokens` attribute is set inside `compile_hlg.py`
|
||||
# assert hasattr(nbest.fsa, "tokens")
|
||||
# assert isinstance(nbest.fsa.tokens, torch.Tensor)
|
||||
|
||||
# path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
|
||||
# # the shape of memory is (T, N, C), so we use axis=1 here
|
||||
# expanded_memory = memory.index_select(1, path_to_utt_map)
|
||||
|
||||
# if memory_key_padding_mask is not None:
|
||||
# # The shape of memory_key_padding_mask is (N, T), so we
|
||||
# # use axis=0 here.
|
||||
# expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
|
||||
# 0, path_to_utt_map
|
||||
# )
|
||||
# else:
|
||||
# expanded_memory_key_padding_mask = None
|
||||
|
||||
# # remove axis corresponding to states.
|
||||
# tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
|
||||
# tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
|
||||
# tokens = tokens.remove_values_leq(0)
|
||||
# token_ids = tokens.tolist()
|
||||
|
||||
# if len(token_ids) == 0:
|
||||
# print("Warning: rescore_with_attention_decoder(): empty token-ids")
|
||||
# return None
|
||||
|
||||
# path is a ragged tensor with dtype torch.int32.
|
||||
# It has three axes [utt][path][arc_pos]
|
||||
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
|
||||
# Note that labels, aux_labels and scores contains 0s and -1s.
|
||||
# The last entry in each sublist is -1.
|
||||
# The axes are [path][token_id]
|
||||
labels = k2.ragged.index(lattice.labels.contiguous(), path).remove_axis(0)
|
||||
aux_labels = k2.ragged.index(lattice.aux_labels.contiguous(), path).remove_axis(0)
|
||||
scores = k2.ragged.index(lattice.scores.contiguous(), path).remove_axis(0)
|
||||
|
||||
# Remove -1 from labels as we will use it to construct a linear FSA
|
||||
labels = labels.remove_values_eq(-1)
|
||||
fsa = k2.linear_fsa(labels)
|
||||
fsa.aux_labels = aux_labels.values
|
||||
|
||||
# utt_to_path_shape has axes [utt][path]
|
||||
utt_to_path_shape = path.shape.get_layer(0)
|
||||
scores = k2.RaggedTensor(utt_to_path_shape, scores.sum())
|
||||
|
||||
path_to_utt_map = utt_to_path_shape.row_ids(1).to(torch.long)
|
||||
# the shape of memory is (N, T, C), so we use axis=0 here
|
||||
# expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map)
|
||||
# expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map)
|
||||
# # the shape of memory is (T, N, C), so we use axis=1 here
|
||||
expanded_memory = memory.index_select(1, path_to_utt_map)
|
||||
|
||||
if memory_key_padding_mask is not None:
|
||||
# The shape of memory_key_padding_mask is (N, T), so we
|
||||
# use axis=0 here.
|
||||
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
|
||||
0, path_to_utt_map
|
||||
)
|
||||
else:
|
||||
expanded_memory_key_padding_mask = None
|
||||
|
||||
token_ids = aux_labels.remove_values_leq(0).tolist()
|
||||
|
||||
nll = model.decoder_nll(
|
||||
memory=expanded_memory,
|
||||
memory_key_padding_mask=expanded_memory_key_padding_mask,
|
||||
token_ids=token_ids,
|
||||
sos_id=sos_id,
|
||||
eos_id=eos_id,
|
||||
)
|
||||
assert nll.ndim == 2
|
||||
assert nll.shape[0] == len(token_ids)
|
||||
|
||||
attention_scores = -nll.sum(dim=1)
|
||||
|
||||
if attention_scale is None:
|
||||
attention_scale_list = [0.01, 0.05, 0.08]
|
||||
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
|
||||
else:
|
||||
attention_scale_list = [attention_scale]
|
||||
|
||||
ans = dict()
|
||||
|
||||
for a_scale in attention_scale_list:
|
||||
tot_scores = scores.values + a_scale * attention_scores
|
||||
ragged_tot_scores = k2.RaggedTensor(utt_to_path_shape, tot_scores)
|
||||
max_indexes = ragged_tot_scores.argmax()
|
||||
best_path = k2.index_fsa(fsa, max_indexes)
|
||||
|
||||
key = f"attention_scale_{a_scale}"
|
||||
ans[key] = best_path
|
||||
|
||||
return ans
|
||||
|
||||
|
||||
def rescore_with_attention_decoder_with_ngram(
|
||||
lattice: k2.Fsa,
|
||||
num_paths: int,
|
||||
|
@ -983,7 +983,8 @@ def write_error_stats_with_timestamps(
|
||||
hyp_count = corr + hyp_sub + ins
|
||||
|
||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||
return float(tot_err_rate), float(mean_delay), float(var_delay)
|
||||
# return float(tot_err_rate), float(mean_delay), float(var_delay)
|
||||
return float(tot_err_rate), mean_delay, var_delay
|
||||
|
||||
|
||||
def write_surt_error_stats(
|
||||
|
Loading…
x
Reference in New Issue
Block a user