Mirror of https://github.com/k2-fsa/icefall.git
Support batch-wise forced-alignment (#970)
* Support batch-wise forced-alignment based on beam search
* Add length_norm to HypothesisList.topk()
* Use Hypothesis and HypothesisList instead
parent 15d48e3a6a
commit bcc5923ab9
@@ -829,10 +829,21 @@ class HypothesisList(object):
         ans.add(hyp)  # shallow copy
         return ans
 
-    def topk(self, k: int) -> "HypothesisList":
-        """Return the top-k hypothesis."""
+    def topk(self, k: int, length_norm: bool = False) -> "HypothesisList":
+        """Return the top-k hypotheses.
+
+        Args:
+          length_norm:
+            If True, the `log_prob` of a hypothesis is normalized by the
+            number of tokens in it.
+        """
         hyps = list(self._data.items())
 
-        hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
+        if length_norm:
+            hyps = sorted(
+                hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True
+            )[:k]
+        else:
+            hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
 
         ans = HypothesisList(dict(hyps))
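Length normalization divides a hypothesis's total log-probability by its token count, so longer alignments are not penalized merely for containing more tokens (each extra token multiplies in another probability <= 1). A minimal, self-contained sketch with hypothetical numbers, using plain tuples rather than the Hypothesis class:

# Each hypothesis: (num_tokens, total_log_prob). Numbers are made up.
short = (3, -3.0)  # average log-prob per token: -1.00
long = (6, -4.2)   # average log-prob per token: -0.70

hyps = [short, long]
best_plain = max(hyps, key=lambda h: h[1])        # picks short: -3.0 > -4.2
best_norm = max(hyps, key=lambda h: h[1] / h[0])  # picks long: -0.70 > -1.00
print(best_plain, best_norm)  # (3, -3.0) (6, -4.2)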
egs/librispeech/ASR/pruned_transducer_stateless7/alignment.py (new file, 206 lines)
@@ -0,0 +1,206 @@
# Copyright 2022-2023 Xiaomi Corp. (authors: Fangjun Kuang,
#                                            Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List

import k2
import torch

from beam_search import Hypothesis, HypothesisList, get_hyps_shape

# The forced-alignment problem can be formulated as finding
# a path in a rectangular lattice, where the path starts
# from the lower left corner and ends at the upper right
# corner. The horizontal axis of the lattice is `t` (representing
# acoustic frame indexes) and the vertical axis is `u` (representing
# BPE tokens of the transcript).
#
# The notations `t` and `u` are from the paper
# https://arxiv.org/pdf/1211.3711.pdf
#
# Beam search is used to find the path with the highest log probabilities.
#
# It assumes the maximum number of symbols that can be
# emitted per frame is 1.
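#
# For example (hypothetical numbers): with T = 4 acoustic frames and a
# transcript of tokens [y1, y2] (U = 2), one valid path through the 4 x 2
# lattice emits y1 at t = 0, blank at t = 1, y2 at t = 2, and blank at
# t = 3, yielding the alignment timestamps [0, 2]. Beam search keeps the
# highest-scoring such path.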

def batch_force_alignment(
    model: torch.nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ys_list: List[List[int]],
    beam_size: int = 4,
) -> List[List[int]]:
    """Compute the forced alignment of a batch of utterances given their
    transcripts in BPE tokens and the corresponding acoustic output from
    the encoder.

    Caution:
      This function is modified from `modified_beam_search` in beam_search.py.
      We assume that the maximum number of symbols per frame is 1.

    Args:
      model:
        The transducer model.
      encoder_out:
        A tensor of shape (N, T, C).
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing the number of valid frames in
        encoder_out before padding.
      ys_list:
        A list of BPE token ID lists. We require that for each utterance i,
        len(ys_list[i]) <= encoder_out_lens[i].
      beam_size:
        Size of the beam used in beam search.

    Returns:
      Return a list of frame-index lists, one list per utterance i,
      where len(ans[i]) == len(ys_list[i]).
    """
    assert encoder_out.ndim == 3, encoder_out.ndim
    assert encoder_out.size(0) == len(ys_list), (encoder_out.size(0), len(ys_list))
    assert encoder_out.size(0) > 0, encoder_out.size(0)

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )
    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    sorted_indices = packed_encoder_out.sorted_indices.tolist()
    encoder_out_lens = encoder_out_lens.tolist()
    ys_lens = [len(ys) for ys in ys_list]
    sorted_encoder_out_lens = [encoder_out_lens[i] for i in sorted_indices]
    sorted_ys_lens = [ys_lens[i] for i in sorted_indices]
    sorted_ys_list = [ys_list[i] for i in sorted_indices]

    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                timestamp=[],
            )
        )

    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    finalized_B = []
    for t, batch_size in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end

        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]
        sorted_encoder_out_lens = sorted_encoder_out_lens[:batch_size]
        sorted_ys_lens = sorted_ys_lens[:batch_size]

        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out, decoder_out, project_input=False
        )  # (num_hyps, 1, 1, vocab_size)
        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

        log_probs = logits.log_softmax(dim=-1)  # (num_hyps, vocab_size)
        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(
            shape=log_probs_shape, value=log_probs.reshape(-1)
        )  # [batch][num_hyps*vocab_size]

        for i in range(batch_size):
            for h, hyp in enumerate(A[i]):
                # pos_u is the number of transcript tokens emitted so far.
                pos_u = len(hyp.timestamp)
                idx_offset = h * vocab_size
                if (sorted_encoder_out_lens[i] - 1 - t) >= (sorted_ys_lens[i] - pos_u):
                    # Emit the blank token; this is only allowed if enough
                    # frames remain to emit all remaining transcript tokens.
                    new_hyp = Hypothesis(
                        log_prob=ragged_log_probs[i][idx_offset + blank_id],
                        ys=hyp.ys[:],
                        timestamp=hyp.timestamp[:],
                    )
                    B[i].add(new_hyp)
                if pos_u < sorted_ys_lens[i]:
                    # emit non-blank token
                    new_token = sorted_ys_list[i][pos_u]
                    new_hyp = Hypothesis(
                        log_prob=ragged_log_probs[i][idx_offset + new_token],
                        ys=hyp.ys + [new_token],
                        timestamp=hyp.timestamp + [t],
                    )
                    B[i].add(new_hyp)

            if len(B[i]) > beam_size:
                B[i] = B[i].topk(beam_size, length_norm=True)

    B = B + finalized_B
    sorted_hyps = [b.get_most_probable() for b in B]
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    hyps = [sorted_hyps[i] for i in unsorted_indices]
    ans = []
    for i, hyp in enumerate(hyps):
        assert hyp.ys[context_size:] == ys_list[i], (hyp.ys[context_size:], ys_list[i])
        ans.append(hyp.timestamp)

    return ans
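A quick usage sketch for the new API. Everything here is illustrative: `align_batch`, `model`, `sp`, `features`, `feature_lens`, and `texts` are caller-supplied placeholders, not names from this commit.

from typing import List

from alignment import batch_force_alignment


def align_batch(model, sp, features, feature_lens, texts, beam_size=4) -> List[List[int]]:
    # `model`: a trained transducer with .encoder/.decoder/.joiner.
    # `sp`: a loaded sentencepiece.SentencePieceProcessor.
    # `features`: (N, T, C) fbank tensor; `feature_lens`: (N,) tensor.
    encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lens)
    ys_list = sp.encode(texts, out_type=int)
    # Returned frame_indexes[i][j] is the encoder frame at which
    # token ys_list[i][j] is emitted.
    return batch_force_alignment(
        model, encoder_out, encoder_out_lens, ys_list, beam_size=beam_size
    )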
egs/librispeech/ASR/pruned_transducer_stateless7/compute_ali.py (new executable file, 345 lines)
@@ -0,0 +1,345 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Zengwei Yao,
#                                                 Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script gets forced alignments based on the modified_beam_search decoding
method. Both token-level and word-level alignments are saved to the new cuts
manifests.

It loads a checkpoint and uses it to get the forced alignments.
You can generate the checkpoint with the following command:

./pruned_transducer_stateless7/export.py \
  --exp-dir ./pruned_transducer_stateless7/exp \
  --bpe-model data/lang_bpe_500/bpe.model \
  --epoch 30 \
  --avg 9

Usage of this script:

./pruned_transducer_stateless7/compute_ali.py \
  --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
  --bpe-model data/lang_bpe_500/bpe.model \
  --dataset test-clean \
  --max-duration 300 \
  --beam-size 4 \
  --cuts-out-dir data/fbank_ali_beam_search
"""


import argparse
import logging
from pathlib import Path
from typing import List, Tuple

import sentencepiece as spm
import torch
import torch.nn as nn
from alignment import batch_force_alignment
from asr_datamodule import LibriSpeechAsrDataModule
from train import add_model_arguments, get_params, get_transducer_model

from icefall.utils import AttributeDict, convert_timestamp, parse_timestamp
from lhotse import CutSet
from lhotse.serialization import SequentialJsonlWriter
from lhotse.supervision import AlignmentItem


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--checkpoint",
        type=str,
        required=True,
        help="Path to the checkpoint. "
        "The checkpoint is assumed to be saved by "
        "icefall.checkpoint.save_checkpoint().",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="""The name of the dataset to compute alignments for.
        Possible values are:
        - test-clean
        - test-other
        - train-clean-100
        - train-clean-360
        - train-other-500
        - dev-clean
        - dev-other
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="""An integer indicating how many candidates we will keep for each
        frame. Used only when --decoding-method is beam_search or
        modified_beam_search.""",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; 2 means tri-gram",
    )

    parser.add_argument(
        "--cuts-out-dir",
        type=str,
        default="data/fbank_ali_beam_search",
        help="The dir to save the new cuts manifests with alignments",
    )

    add_model_arguments(parser)

    return parser


def align_one_batch(
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    batch: dict,
) -> Tuple[List[List[str]], List[List[str]], List[List[float]], List[List[float]]]:
    """Get forced alignments for one batch.

    Args:
      params:
        It's the return value of :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.

    Returns:
      token_list:
        A list of token lists.
      word_list:
        A list of word lists.
      token_time_list:
        A list of timestamp lists for tokens.
      word_time_list:
        A list of timestamp lists for words.

      where len(token_list) == len(word_list) == len(token_time_list) == len(word_time_list),
      len(token_list[i]) == len(token_time_list[i]),
      and len(word_list[i]) == len(word_time_list[i])
    """
    device = next(model.parameters()).device
    feature = batch["inputs"]
    assert feature.ndim == 3

    feature = feature.to(device)
    # at entry, feature is (N, T, C)

    supervisions = batch["supervisions"]
    feature_lens = supervisions["num_frames"].to(device)

    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)

    texts = supervisions["text"]
    ys_list: List[List[int]] = sp.encode(texts, out_type=int)

    frame_indexes = batch_force_alignment(
        model, encoder_out, encoder_out_lens, ys_list, params.beam_size
    )

    token_list = []
    word_list = []
    token_time_list = []
    word_time_list = []
    for i in range(encoder_out.size(0)):
        tokens = sp.id_to_piece(ys_list[i])
        words = texts[i].split()
        token_time = convert_timestamp(
            frame_indexes[i], params.subsampling_factor, params.frame_shift_ms
        )
        word_time = parse_timestamp(tokens, token_time)
        assert len(word_time) == len(words), (len(word_time), len(words))

        token_list.append(tokens)
        word_list.append(words)
        token_time_list.append(token_time)
        word_time_list.append(word_time)

    return token_list, word_list, token_time_list, word_time_list


def align_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    writer: SequentialJsonlWriter,
) -> None:
    """Get forced alignments for the dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model.
      writer:
        Writer to save the cuts with alignments.
    """
    log_interval = 20
    num_cuts = 0

    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"

    for batch_idx, batch in enumerate(dl):
        token_list, word_list, token_time_list, word_time_list = align_one_batch(
            params=params, model=model, sp=sp, batch=batch
        )

        cut_list = batch["supervisions"]["cut"]
        for cut, token, word, token_time, word_time in zip(
            cut_list, token_list, word_list, token_time_list, word_time_list
        ):
            assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"
            token_ali = [
                AlignmentItem(
                    symbol=token[i],
                    start=round(token_time[i], ndigits=3),
                    duration=None,
                )
                for i in range(len(token))
            ]
            word_ali = [
                AlignmentItem(
                    symbol=word[i], start=round(word_time[i], ndigits=3), duration=None
                )
                for i in range(len(word))
            ]
            cut.supervisions[0].alignment = {"word": word_ali, "token": token_ali}
            writer.write(cut, flush=True)

        num_cuts += len(cut_list)
        if batch_idx % log_interval == 0:
            batch_str = f"{batch_idx}/{num_batches}"
            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")


@torch.no_grad()
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    params = get_params()
    params.update(vars(args))

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"Device: {device}")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> and <unk> are defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(params)

    logging.info("About to create model")
    model = get_transducer_model(params)

    checkpoint = torch.load(args.checkpoint, map_location="cpu")
    model.load_state_dict(checkpoint["model"], strict=False)
    model.to(device)
    model.eval()

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    # we need cut ids to display recognition results.
    args.return_cuts = True
    librispeech = LibriSpeechAsrDataModule(args)

    if params.dataset == "test-clean":
        test_clean_cuts = librispeech.test_clean_cuts()
        dl = librispeech.test_dataloaders(test_clean_cuts)
    elif params.dataset == "test-other":
        test_other_cuts = librispeech.test_other_cuts()
        dl = librispeech.test_dataloaders(test_other_cuts)
    elif params.dataset == "train-clean-100":
        train_clean_100_cuts = librispeech.train_clean_100_cuts()
        dl = librispeech.train_dataloaders(train_clean_100_cuts)
    elif params.dataset == "train-clean-360":
        train_clean_360_cuts = librispeech.train_clean_360_cuts()
        dl = librispeech.train_dataloaders(train_clean_360_cuts)
    elif params.dataset == "train-other-500":
        train_other_500_cuts = librispeech.train_other_500_cuts()
        dl = librispeech.train_dataloaders(train_other_500_cuts)
    elif params.dataset == "dev-clean":
        dev_clean_cuts = librispeech.dev_clean_cuts()
        dl = librispeech.valid_dataloaders(dev_clean_cuts)
    else:
        assert params.dataset == "dev-other", f"{params.dataset}"
        dev_other_cuts = librispeech.dev_other_cuts()
        dl = librispeech.valid_dataloaders(dev_other_cuts)

    cuts_out_dir = Path(params.cuts_out_dir)
    cuts_out_dir.mkdir(parents=True, exist_ok=True)
    cuts_out_path = cuts_out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz"

    with CutSet.open_writer(cuts_out_path) as writer:
        align_dataset(dl=dl, params=params, model=model, sp=sp, writer=writer)

    logging.info(
        f"For dataset {params.dataset}, the cut manifest with framewise token alignments "
        f"and word alignments is saved to {cuts_out_path}"
    )
    logging.info("Done!")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
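convert_timestamp and parse_timestamp are imported from icefall.utils. As a hedged sketch of the assumed conversion (not the library source), convert_timestamp maps encoder frame indexes to seconds by undoing the encoder's subsampling:

from typing import List


def frames_to_seconds(
    frame_indexes: List[int],
    subsampling_factor: int = 4,   # assumed encoder frame-rate reduction
    frame_shift_ms: float = 10.0,  # assumed fbank frame shift
) -> List[float]:
    # Sketch of what convert_timestamp is assumed to compute: one encoder
    # frame covers `subsampling_factor` fbank frames of `frame_shift_ms` each.
    return [i * subsampling_factor * frame_shift_ms / 1000.0 for i in frame_indexes]


print(frames_to_seconds([0, 25, 60]))  # [0.0, 1.0, 2.4]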
egs/librispeech/ASR/pruned_transducer_stateless7/test_compute_ali.py (new executable file, 130 lines)
@@ -0,0 +1,130 @@
#!/usr/bin/env python3
#
# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
#                                                 Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script compares the word-level alignments generated based on
modified_beam_search decoding (in ./pruned_transducer_stateless7/compute_ali.py)
to the reference alignments generated by the torchaudio framework
(in ./add_alignments.sh).

Usage:

./pruned_transducer_stateless7/compute_ali.py \
  --checkpoint ./pruned_transducer_stateless7/exp/pretrained.pt \
  --bpe-model data/lang_bpe_500/bpe.model \
  --dataset test-clean \
  --max-duration 300 \
  --beam-size 4 \
  --cuts-out-dir data/fbank_ali_beam_search

Then you can run:

./pruned_transducer_stateless7/test_compute_ali.py \
  --cuts-out-dir ./data/fbank_ali_test \
  --cuts-ref-dir ./data/fbank_ali_torch \
  --dataset train-clean-100
"""
import argparse
import logging
from pathlib import Path

import torch
from lhotse import load_manifest


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--cuts-out-dir",
        type=Path,
        default="./data/fbank_ali",
        help="The dir that saves the generated cuts manifests with alignments",
    )

    parser.add_argument(
        "--cuts-ref-dir",
        type=Path,
        default="./data/fbank_ali_torch",
        help="The dir that saves the reference cuts manifests with alignments",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="""The name of the dataset:
        Possible values are:
        - test-clean
        - test-other
        - train-clean-100
        - train-clean-360
        - train-other-500
        - dev-clean
        - dev-other
        """,
    )

    return parser


def main():
    args = get_parser().parse_args()

    cuts_out_jsonl = args.cuts_out_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz"
    cuts_ref_jsonl = args.cuts_ref_dir / f"librispeech_cuts_{args.dataset}.jsonl.gz"

    logging.info(f"Loading {cuts_out_jsonl} and {cuts_ref_jsonl}")
    cuts_out = load_manifest(cuts_out_jsonl)
    cuts_ref = load_manifest(cuts_ref_jsonl)
    cuts_ref = cuts_ref.sort_like(cuts_out)

    all_time_diffs = []
    for cut_out, cut_ref in zip(cuts_out, cuts_ref):
        time_out = [
            ali.start
            for ali in cut_out.supervisions[0].alignment["word"]
            if ali.symbol != ""
        ]
        time_ref = [
            ali.start
            for ali in cut_ref.supervisions[0].alignment["word"]
            if ali.symbol != ""
        ]
        assert len(time_out) == len(time_ref), (len(time_out), len(time_ref))
        diff = [
            round(abs(out - ref), ndigits=3) for out, ref in zip(time_out, time_ref)
        ]
        all_time_diffs += diff

    all_time_diffs = torch.tensor(all_time_diffs)
    logging.info(
        f"For the word-level alignments abs difference on dataset {args.dataset}, "
        f"mean: {'%.2f' % all_time_diffs.mean()}s, std: {'%.2f' % all_time_diffs.std()}s"
    )
    logging.info("Done!")


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()
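To eyeball one aligned cut from the output manifest, a minimal snippet; the path below is an assumption based on the default --cuts-out-dir used earlier:

from lhotse import load_manifest

# Hypothetical path; produced by compute_ali.py above.
cuts = load_manifest("data/fbank_ali_beam_search/librispeech_cuts_test-clean.jsonl.gz")
cut = next(iter(cuts))
for item in cut.supervisions[0].alignment["word"]:
    print(item.symbol, item.start)  # each word and its start time in seconds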
@@ -1378,7 +1378,7 @@ def parse_timestamp(tokens: List[str], timestamp: List[float]) -> List[float]:
       List of timestamp of each word.
     """
     start_token = b"\xe2\x96\x81".decode()  # '▁'
-    assert len(tokens) == len(timestamp)
+    assert len(tokens) == len(timestamp), (len(tokens), len(timestamp))
     ans = []
     for i in range(len(tokens)):
         flag = False
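The '▁' marker (U+2581) is how word starts are recovered from BPE tokens: each word begins at a token carrying that prefix. A minimal sketch of the grouping, with hypothetical tokens and times, simplified relative to the full parse_timestamp logic:

# Each word's start time is the timestamp of its first BPE token,
# identified by the '\u2581' word-start prefix.
tokens = ["\u2581HE", "LLO", "\u2581WORLD"]
timestamp = [0.08, 0.24, 0.56]
word_times = [t for tok, t in zip(tokens, timestamp) if tok.startswith("\u2581")]
print(word_times)  # [0.08, 0.56] -> start times of "HELLO" and "WORLD"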