mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-13 12:02:21 +00:00
add conformer exps
This commit is contained in:
parent
b65873fb4c
commit
2fc53cd7ce
@ -44,6 +44,7 @@ from icefall.decode import (
|
|||||||
nbest_oracle,
|
nbest_oracle,
|
||||||
one_best_decoding,
|
one_best_decoding,
|
||||||
rescore_with_attention_decoder,
|
rescore_with_attention_decoder,
|
||||||
|
rescore_with_attention_decoder_no_ngram_old,
|
||||||
rescore_with_n_best_list,
|
rescore_with_n_best_list,
|
||||||
rescore_with_rnn_lm,
|
rescore_with_rnn_lm,
|
||||||
rescore_with_whole_lattice,
|
rescore_with_whole_lattice,
|
||||||
@ -459,6 +460,27 @@ def decode_one_batch(
|
|||||||
key = "ctc-greedy-search"
|
key = "ctc-greedy-search"
|
||||||
return {key: hyps}
|
return {key: hyps}
|
||||||
|
|
||||||
|
if params.method == "attention-decoder-rescoring-no-ngram":
|
||||||
|
best_path_dict = rescore_with_attention_decoder_no_ngram_old(
|
||||||
|
lattice=lattice,
|
||||||
|
num_paths=params.num_paths,
|
||||||
|
model=model,
|
||||||
|
memory=memory,
|
||||||
|
memory_key_padding_mask=memory_key_padding_mask,
|
||||||
|
sos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
)
|
||||||
|
ans = dict()
|
||||||
|
for a_scale_str, best_path in best_path_dict.items():
|
||||||
|
# token_ids is a lit-of-list of IDs
|
||||||
|
token_ids = get_texts(best_path)
|
||||||
|
# hyps is a list of str, e.g., ['xxx yyy zzz', ...]
|
||||||
|
hyps = bpe_model.decode(token_ids)
|
||||||
|
# hyps is a list of list of str, e.g., [['xxx', 'yyy', 'zzz'], ... ]
|
||||||
|
hyps = [s.split() for s in hyps]
|
||||||
|
ans[a_scale_str] = hyps
|
||||||
|
return ans
|
||||||
|
|
||||||
if params.method == "nbest-oracle":
|
if params.method == "nbest-oracle":
|
||||||
# Note: You can also pass rescored lattices to it.
|
# Note: You can also pass rescored lattices to it.
|
||||||
# We choose the HLG decoded lattice for speed reasons
|
# We choose the HLG decoded lattice for speed reasons
|
||||||
@ -761,7 +783,7 @@ def main():
|
|||||||
params.sos_id = sos_id
|
params.sos_id = sos_id
|
||||||
params.eos_id = eos_id
|
params.eos_id = eos_id
|
||||||
|
|
||||||
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search":
|
if params.method == "ctc-decoding" or params.method == "ctc-greedy-search" or params.method == "attention-decoder-rescoring-no-ngram":
|
||||||
HLG = None
|
HLG = None
|
||||||
H = k2.ctc_topo(
|
H = k2.ctc_topo(
|
||||||
max_token=max_token_id,
|
max_token=max_token_id,
|
||||||
|
@ -72,7 +72,7 @@ import sentencepiece as spm
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
from train import add_model_arguments, get_ctc_model, get_params
|
from train_cr_ctc import add_model_arguments, get_ctc_model, get_params
|
||||||
|
|
||||||
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -458,8 +458,9 @@ def decode_one_batch(
|
|||||||
else:
|
else:
|
||||||
encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
|
encoder_out, encoder_out_lens = model.encoder(feature, feature_lens)
|
||||||
|
|
||||||
nnet_output = model.get_ctc_output(encoder_out)
|
# nnet_output = model.get_ctc_output(encoder_out)
|
||||||
# nnet_output is (N, T, C)
|
# nnet_output is (N, T, C)
|
||||||
|
nnet_output = model.ctc_output(encoder_out) # (N, T, C)
|
||||||
|
|
||||||
if params.decoding_method == "ctc-greedy-search":
|
if params.decoding_method == "ctc-greedy-search":
|
||||||
timestamps, hyps = ctc_greedy_search(
|
timestamps, hyps = ctc_greedy_search(
|
||||||
|
209
egs/librispeech/ASR/conformer_ctc3/model_cr_ctc.py
Normal file
209
egs/librispeech/ASR/conformer_ctc3/model_cr_ctc.py
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
|
||||||
|
# Wei Kang,
|
||||||
|
# Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from encoder_interface import EncoderInterface
|
||||||
|
from scaling import ScaledLinear
|
||||||
|
from icefall.utils import make_pad_mask, time_warp
|
||||||
|
from lhotse.dataset import SpecAugment
|
||||||
|
|
||||||
|
|
||||||
|
class CTCModel(nn.Module):
|
||||||
|
"""It implements https://www.cs.toronto.edu/~graves/icml_2006.pdf
|
||||||
|
"Connectionist Temporal Classification: Labelling Unsegmented
|
||||||
|
Sequence Data with Recurrent Neural Networks"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
encoder: EncoderInterface,
|
||||||
|
encoder_dim: int,
|
||||||
|
vocab_size: int,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder:
|
||||||
|
It is the transcription network in the paper. Its accepts
|
||||||
|
two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,).
|
||||||
|
It returns two tensors: `logits` of shape (N, T, encoder_dm) and
|
||||||
|
`logit_lens` of shape (N,).
|
||||||
|
encoder_dim:
|
||||||
|
The feature embedding dimension.
|
||||||
|
vocab_size:
|
||||||
|
The vocabulary size.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
assert isinstance(encoder, EncoderInterface), type(encoder)
|
||||||
|
|
||||||
|
self.encoder = encoder
|
||||||
|
self.ctc_output = nn.Sequential(
|
||||||
|
nn.Dropout(p=0.1),
|
||||||
|
ScaledLinear(encoder_dim, vocab_size),
|
||||||
|
nn.LogSoftmax(dim=-1),
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward_ctc(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
targets: torch.Tensor,
|
||||||
|
target_lengths: torch.Tensor,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Compute CTC loss.
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
Encoder output, of shape (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
Encoder output lengths, of shape (N,).
|
||||||
|
targets:
|
||||||
|
Target Tensor of shape (sum(target_lengths)). The targets are assumed
|
||||||
|
to be un-padded and concatenated within 1 dimension.
|
||||||
|
"""
|
||||||
|
# Compute CTC log-prob
|
||||||
|
ctc_output = self.ctc_output(encoder_out) # (N, T, C)
|
||||||
|
|
||||||
|
ctc_loss = torch.nn.functional.ctc_loss(
|
||||||
|
log_probs=ctc_output.permute(1, 0, 2), # (T, N, C)
|
||||||
|
targets=targets.cpu(),
|
||||||
|
input_lengths=encoder_out_lens.cpu(),
|
||||||
|
target_lengths=target_lengths.cpu(),
|
||||||
|
reduction="sum",
|
||||||
|
)
|
||||||
|
return ctc_loss
|
||||||
|
|
||||||
|
def forward_cr_ctc(
|
||||||
|
self,
|
||||||
|
encoder_out: torch.Tensor,
|
||||||
|
encoder_out_lens: torch.Tensor,
|
||||||
|
targets: torch.Tensor,
|
||||||
|
target_lengths: torch.Tensor,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""Compute CTC loss with consistency regularization loss.
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
Encoder output, of shape (2 * N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
Encoder output lengths, of shape (2 * N,).
|
||||||
|
targets:
|
||||||
|
Target Tensor of shape (2 * sum(target_lengths)). The targets are assumed
|
||||||
|
to be un-padded and concatenated within 1 dimension.
|
||||||
|
"""
|
||||||
|
# Compute CTC loss
|
||||||
|
ctc_output = self.ctc_output(encoder_out) # (2 * N, T, C)
|
||||||
|
ctc_loss = torch.nn.functional.ctc_loss(
|
||||||
|
log_probs=ctc_output.permute(1, 0, 2), # (T, 2 * N, C)
|
||||||
|
targets=targets.cpu(),
|
||||||
|
input_lengths=encoder_out_lens.cpu(),
|
||||||
|
target_lengths=target_lengths.cpu(),
|
||||||
|
reduction="sum",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute consistency regularization loss
|
||||||
|
exchanged_targets = ctc_output.detach().chunk(2, dim=0)
|
||||||
|
exchanged_targets = torch.cat(
|
||||||
|
[exchanged_targets[1], exchanged_targets[0]], dim=0
|
||||||
|
) # exchange: [x1, x2] -> [x2, x1]
|
||||||
|
cr_loss = nn.functional.kl_div(
|
||||||
|
input=ctc_output,
|
||||||
|
target=exchanged_targets,
|
||||||
|
reduction="none",
|
||||||
|
log_target=True,
|
||||||
|
) # (2 * N, T, C)
|
||||||
|
length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
|
||||||
|
cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()
|
||||||
|
|
||||||
|
return ctc_loss, cr_loss
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x: torch.Tensor,
|
||||||
|
x_lens: torch.Tensor,
|
||||||
|
y: k2.RaggedTensor,
|
||||||
|
warmup: float = 1.0,
|
||||||
|
use_cr_ctc: bool = False,
|
||||||
|
use_spec_aug: bool = False,
|
||||||
|
spec_augment: Optional[SpecAugment] = None,
|
||||||
|
supervision_segments: Optional[torch.Tensor] = None,
|
||||||
|
time_warp_factor: Optional[int] = 80,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
x:
|
||||||
|
A 3-D tensor of shape (N, T, C).
|
||||||
|
x_lens:
|
||||||
|
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||||
|
before padding.
|
||||||
|
warmup: a floating point value which increases throughout training;
|
||||||
|
values >= 1.0 are fully warmed up and have all modules present.
|
||||||
|
"""
|
||||||
|
assert x.ndim == 3, x.shape
|
||||||
|
assert x_lens.ndim == 1, x_lens.shape
|
||||||
|
assert y.num_axes == 2, y.num_axes
|
||||||
|
|
||||||
|
assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)
|
||||||
|
|
||||||
|
if use_cr_ctc:
|
||||||
|
if use_spec_aug:
|
||||||
|
assert spec_augment is not None and spec_augment.time_warp_factor < 1
|
||||||
|
# Apply time warping before input duplicating
|
||||||
|
assert supervision_segments is not None
|
||||||
|
x = time_warp(
|
||||||
|
x,
|
||||||
|
time_warp_factor=time_warp_factor,
|
||||||
|
supervision_segments=supervision_segments,
|
||||||
|
)
|
||||||
|
# Independently apply frequency masking and time masking to the two copies
|
||||||
|
x = spec_augment(x.repeat(2, 1, 1))
|
||||||
|
else:
|
||||||
|
x = x.repeat(2, 1, 1)
|
||||||
|
x_lens = x_lens.repeat(2)
|
||||||
|
y = k2.ragged.cat([y, y], axis=0)
|
||||||
|
|
||||||
|
# Compute encoder outputs
|
||||||
|
encoder_out, encoder_out_lens = self.encoder(x, x_lens, warmup=warmup)
|
||||||
|
assert torch.all(encoder_out_lens > 0)
|
||||||
|
|
||||||
|
row_splits = y.shape.row_splits(1)
|
||||||
|
y_lens = row_splits[1:] - row_splits[:-1]
|
||||||
|
|
||||||
|
# Compute CTC loss
|
||||||
|
targets = y.values
|
||||||
|
if not use_cr_ctc:
|
||||||
|
ctc_loss = self.forward_ctc(
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
targets=targets,
|
||||||
|
target_lengths=y_lens,
|
||||||
|
)
|
||||||
|
cr_loss = torch.empty(0)
|
||||||
|
else:
|
||||||
|
ctc_loss, cr_loss = self.forward_cr_ctc(
|
||||||
|
encoder_out=encoder_out,
|
||||||
|
encoder_out_lens=encoder_out_lens,
|
||||||
|
targets=targets,
|
||||||
|
target_lengths=y_lens,
|
||||||
|
)
|
||||||
|
ctc_loss = ctc_loss * 0.5
|
||||||
|
cr_loss = cr_loss * 0.5
|
||||||
|
|
||||||
|
return ctc_loss, cr_loss
|
1098
egs/librispeech/ASR/conformer_ctc3/train_cr_ctc.py
Executable file
1098
egs/librispeech/ASR/conformer_ctc3/train_cr_ctc.py
Executable file
File diff suppressed because it is too large
Load Diff
@ -1083,6 +1083,185 @@ def rescore_with_attention_decoder(
|
|||||||
return ans
|
return ans
|
||||||
|
|
||||||
|
|
||||||
|
def rescore_with_attention_decoder_no_ngram_old(
|
||||||
|
lattice: k2.Fsa,
|
||||||
|
num_paths: int,
|
||||||
|
model: torch.nn.Module,
|
||||||
|
memory: torch.Tensor,
|
||||||
|
memory_key_padding_mask: Optional[torch.Tensor],
|
||||||
|
sos_id: int,
|
||||||
|
eos_id: int,
|
||||||
|
attention_scale: Optional[float] = None,
|
||||||
|
use_double_scores: bool = True,
|
||||||
|
) -> Dict[str, k2.Fsa]:
|
||||||
|
"""This function extracts `num_paths` paths from the given lattice and uses
|
||||||
|
an attention decoder to rescore them. The path with the highest score is
|
||||||
|
the decoding output.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
lattice:
|
||||||
|
An FsaVec with axes [utt][state][arc].
|
||||||
|
num_paths:
|
||||||
|
Number of paths to extract from the given lattice for rescoring.
|
||||||
|
model:
|
||||||
|
A transformer model. See the class "Transformer" in
|
||||||
|
conformer_ctc/transformer.py for its interface.
|
||||||
|
memory:
|
||||||
|
The encoder memory of the given model. It is the output of
|
||||||
|
the last torch.nn.TransformerEncoder layer in the given model.
|
||||||
|
Its shape is `(T, N, C)`.
|
||||||
|
memory_key_padding_mask:
|
||||||
|
The padding mask for memory with shape `(N, T)`.
|
||||||
|
sos_id:
|
||||||
|
The token ID for SOS.
|
||||||
|
eos_id:
|
||||||
|
The token ID for EOS.
|
||||||
|
nbest_scale:
|
||||||
|
It's the scale applied to `lattice.scores`. A smaller value
|
||||||
|
leads to more unique paths at the risk of missing the correct path.
|
||||||
|
ngram_lm_scale:
|
||||||
|
Optional. It specifies the scale for n-gram LM scores.
|
||||||
|
attention_scale:
|
||||||
|
Optional. It specifies the scale for attention decoder scores.
|
||||||
|
Returns:
|
||||||
|
A dict of FsaVec, whose key contains a string
|
||||||
|
ngram_lm_scale_attention_scale and the value is the
|
||||||
|
best decoding path for each utterance in the lattice.
|
||||||
|
"""
|
||||||
|
# max_loop_count = 10
|
||||||
|
# loop_count = 0
|
||||||
|
# while loop_count <= max_loop_count:
|
||||||
|
# try:
|
||||||
|
# nbest = Nbest.from_lattice(
|
||||||
|
# lattice=lattice,
|
||||||
|
# num_paths=num_paths,
|
||||||
|
# use_double_scores=use_double_scores,
|
||||||
|
# nbest_scale=nbest_scale,
|
||||||
|
# )
|
||||||
|
# # nbest.fsa.scores are all 0s at this point
|
||||||
|
# nbest = nbest.intersect(lattice)
|
||||||
|
# break
|
||||||
|
# except RuntimeError as e:
|
||||||
|
# logging.info(f"Caught exception:\n{e}\n")
|
||||||
|
# logging.info(f"num_paths before decreasing: {num_paths}")
|
||||||
|
# num_paths = int(num_paths / 2)
|
||||||
|
# if loop_count >= max_loop_count or num_paths <= 0:
|
||||||
|
# logging.info("Return None as the resulting lattice is too large.")
|
||||||
|
# return None
|
||||||
|
# logging.info(
|
||||||
|
# "This OOM is not an error. You can ignore it. "
|
||||||
|
# "If your model does not converge well, or --max-duration "
|
||||||
|
# "is too large, or the input sound file is difficult to "
|
||||||
|
# "decode, you will meet this exception."
|
||||||
|
# )
|
||||||
|
# logging.info(f"num_paths after decreasing: {num_paths}")
|
||||||
|
# loop_count += 1
|
||||||
|
|
||||||
|
# # Now nbest.fsa has its scores set.
|
||||||
|
# # Also, nbest.fsa inherits the attributes from `lattice`.
|
||||||
|
# assert hasattr(nbest.fsa, "lm_scores")
|
||||||
|
|
||||||
|
# am_scores = nbest.compute_am_scores()
|
||||||
|
# ngram_lm_scores = nbest.compute_lm_scores()
|
||||||
|
|
||||||
|
# # The `tokens` attribute is set inside `compile_hlg.py`
|
||||||
|
# assert hasattr(nbest.fsa, "tokens")
|
||||||
|
# assert isinstance(nbest.fsa.tokens, torch.Tensor)
|
||||||
|
|
||||||
|
# path_to_utt_map = nbest.shape.row_ids(1).to(torch.long)
|
||||||
|
# # the shape of memory is (T, N, C), so we use axis=1 here
|
||||||
|
# expanded_memory = memory.index_select(1, path_to_utt_map)
|
||||||
|
|
||||||
|
# if memory_key_padding_mask is not None:
|
||||||
|
# # The shape of memory_key_padding_mask is (N, T), so we
|
||||||
|
# # use axis=0 here.
|
||||||
|
# expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
|
||||||
|
# 0, path_to_utt_map
|
||||||
|
# )
|
||||||
|
# else:
|
||||||
|
# expanded_memory_key_padding_mask = None
|
||||||
|
|
||||||
|
# # remove axis corresponding to states.
|
||||||
|
# tokens_shape = nbest.fsa.arcs.shape().remove_axis(1)
|
||||||
|
# tokens = k2.RaggedTensor(tokens_shape, nbest.fsa.tokens)
|
||||||
|
# tokens = tokens.remove_values_leq(0)
|
||||||
|
# token_ids = tokens.tolist()
|
||||||
|
|
||||||
|
# if len(token_ids) == 0:
|
||||||
|
# print("Warning: rescore_with_attention_decoder(): empty token-ids")
|
||||||
|
# return None
|
||||||
|
|
||||||
|
# path is a ragged tensor with dtype torch.int32.
|
||||||
|
# It has three axes [utt][path][arc_pos]
|
||||||
|
path = k2.random_paths(lattice, num_paths=num_paths, use_double_scores=True)
|
||||||
|
# Note that labels, aux_labels and scores contains 0s and -1s.
|
||||||
|
# The last entry in each sublist is -1.
|
||||||
|
# The axes are [path][token_id]
|
||||||
|
labels = k2.ragged.index(lattice.labels.contiguous(), path).remove_axis(0)
|
||||||
|
aux_labels = k2.ragged.index(lattice.aux_labels.contiguous(), path).remove_axis(0)
|
||||||
|
scores = k2.ragged.index(lattice.scores.contiguous(), path).remove_axis(0)
|
||||||
|
|
||||||
|
# Remove -1 from labels as we will use it to construct a linear FSA
|
||||||
|
labels = labels.remove_values_eq(-1)
|
||||||
|
fsa = k2.linear_fsa(labels)
|
||||||
|
fsa.aux_labels = aux_labels.values
|
||||||
|
|
||||||
|
# utt_to_path_shape has axes [utt][path]
|
||||||
|
utt_to_path_shape = path.shape.get_layer(0)
|
||||||
|
scores = k2.RaggedTensor(utt_to_path_shape, scores.sum())
|
||||||
|
|
||||||
|
path_to_utt_map = utt_to_path_shape.row_ids(1).to(torch.long)
|
||||||
|
# the shape of memory is (N, T, C), so we use axis=0 here
|
||||||
|
# expanded_encoder_out = encoder_out.index_select(0, path_to_utt_map)
|
||||||
|
# expanded_encoder_out_lens = encoder_out_lens.index_select(0, path_to_utt_map)
|
||||||
|
# # the shape of memory is (T, N, C), so we use axis=1 here
|
||||||
|
expanded_memory = memory.index_select(1, path_to_utt_map)
|
||||||
|
|
||||||
|
if memory_key_padding_mask is not None:
|
||||||
|
# The shape of memory_key_padding_mask is (N, T), so we
|
||||||
|
# use axis=0 here.
|
||||||
|
expanded_memory_key_padding_mask = memory_key_padding_mask.index_select(
|
||||||
|
0, path_to_utt_map
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
expanded_memory_key_padding_mask = None
|
||||||
|
|
||||||
|
token_ids = aux_labels.remove_values_leq(0).tolist()
|
||||||
|
|
||||||
|
nll = model.decoder_nll(
|
||||||
|
memory=expanded_memory,
|
||||||
|
memory_key_padding_mask=expanded_memory_key_padding_mask,
|
||||||
|
token_ids=token_ids,
|
||||||
|
sos_id=sos_id,
|
||||||
|
eos_id=eos_id,
|
||||||
|
)
|
||||||
|
assert nll.ndim == 2
|
||||||
|
assert nll.shape[0] == len(token_ids)
|
||||||
|
|
||||||
|
attention_scores = -nll.sum(dim=1)
|
||||||
|
|
||||||
|
if attention_scale is None:
|
||||||
|
attention_scale_list = [0.01, 0.05, 0.08]
|
||||||
|
attention_scale_list += [0.1, 0.3, 0.5, 0.6, 0.7, 0.9, 1.0]
|
||||||
|
attention_scale_list += [1.1, 1.2, 1.3, 1.5, 1.7, 1.9, 2.0]
|
||||||
|
attention_scale_list += [2.1, 2.2, 2.3, 2.5, 3.0, 4.0, 5.0]
|
||||||
|
else:
|
||||||
|
attention_scale_list = [attention_scale]
|
||||||
|
|
||||||
|
ans = dict()
|
||||||
|
|
||||||
|
for a_scale in attention_scale_list:
|
||||||
|
tot_scores = scores.values + a_scale * attention_scores
|
||||||
|
ragged_tot_scores = k2.RaggedTensor(utt_to_path_shape, tot_scores)
|
||||||
|
max_indexes = ragged_tot_scores.argmax()
|
||||||
|
best_path = k2.index_fsa(fsa, max_indexes)
|
||||||
|
|
||||||
|
key = f"attention_scale_{a_scale}"
|
||||||
|
ans[key] = best_path
|
||||||
|
|
||||||
|
return ans
|
||||||
|
|
||||||
|
|
||||||
def rescore_with_attention_decoder_with_ngram(
|
def rescore_with_attention_decoder_with_ngram(
|
||||||
lattice: k2.Fsa,
|
lattice: k2.Fsa,
|
||||||
num_paths: int,
|
num_paths: int,
|
||||||
|
@ -983,7 +983,8 @@ def write_error_stats_with_timestamps(
|
|||||||
hyp_count = corr + hyp_sub + ins
|
hyp_count = corr + hyp_sub + ins
|
||||||
|
|
||||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||||
return float(tot_err_rate), float(mean_delay), float(var_delay)
|
# return float(tot_err_rate), float(mean_delay), float(var_delay)
|
||||||
|
return float(tot_err_rate), mean_delay, var_delay
|
||||||
|
|
||||||
|
|
||||||
def write_surt_error_stats(
|
def write_surt_error_stats(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user