Support long audios recognition (#980)

* support long file transcription

* rename recipe as long_file_recog

* add docs

* support multi-gpu decoding

* style fix
Zengwei Yao 2023-05-19 20:27:55 +08:00 committed by GitHub
parent f18b539fbc
commit a7e142b7ff
8 changed files with 1681 additions and 1 deletion

View File

@ -0,0 +1,94 @@
#!/usr/bin/env bash
# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
export CUDA_VISIBLE_DEVICES="0,1,2,3"
set -eou pipefail
# This script is used to recognize long audios. The process is as follows:
# 1) Split long audios into chunks with overlaps.
# 2) Perform speech recognition on chunks, getting tokens and timestamps.
# 3) Merge the overlapped chunks into utterances according to the timestamps.
# Each chunk (except the first and the last) is padded with extra context on its left and right sides.
# The chunk length is: left_side + chunk_size + right_side.
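# For example, with the defaults below (chunk=30.0, extra=2.0), chunks are cut
# with a hop of chunk - 2 * extra = 26.0 s, so consecutive chunks overlap by
# 4 s; when merging, each interior chunk keeps only its middle 26 s, so the
# kept regions are contiguous and cover the whole recording.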
chunk=30.0
extra=2.0
stage=1
stop_stage=4
# We assume that you have downloaded the LibriLight dataset
# with audio files in $corpus_dir and texts in $text_dir
corpus_dir=$PWD/download/libri-light
text_dir=$PWD/download/librilight_text
# Path to save the manifests
output_dir=$PWD/data/librilight
world_size=4
log() {
# This function is from espnet
local fname=${BASH_SOURCE[1]##*/}
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
# We will get librilight_recordings_{subset}.jsonl.gz and librilight_supervisions_{subset}.jsonl.gz
# saved in $output_dir/manifests
log "Stage 1: Prepare LibriLight manifest"
lhotse prepare librilight $corpus_dir $text_dir $output_dir/manifests -j 10
fi
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
# Chunk manifests are saved to $output_dir/manifests_chunk/librilight_cuts_{subset}.jsonl.gz
log "Stage 2: Split long audio into chunks"
./long_file_recog/split_into_chunks.py \
--manifest-in-dir $output_dir/manifests \
--manifest-out-dir $output_dir/manifests_chunk \
--chunk $chunk \
--extra $extra # Extra duration (in seconds) at both sides
fi
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
# Recognized tokens and timestamps are saved to $output_dir/manifests_chunk_recog/librilight_cuts_{subset}.jsonl.gz
# This script loads torchscript models, exported by `torch.jit.script()`,
# and uses them to decode waves.
# You can download the jit model from https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
log "Stage 3: Perform speech recognition on splitted chunks"
for subset in small medium large; do
./long_file_recog/recognize.py \
--world-size $world_size \
--num-workers 8 \
--subset $subset \
--manifest-in-dir $output_dir/manifests_chunk \
--manifest-out-dir $output_dir/manifests_chunk_recog \
--nn-model-filename long_file_recog/exp/jit_model.pt \
--bpe-model data/lang_bpe_500/bpe.model \
--max-duration 2400 \
--decoding-method greedy_search \
--master-port 12345
if [ $world_size -gt 1 ]; then
# Combine manifests from different jobs
lhotse combine $(find $output_dir/manifests_chunk_recog -name librilight_cuts_${subset}_job_*.jsonl.gz | tr "\n" " ") $output_dir/manifests_chunk_recog/librilight_cuts_${subset}.jsonl.gz
fi
done
fi
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
# Final results are saved in $output_dir/manifests/librilight_cuts_{subset}.jsonl.gz
log "Stage 4: Merge splitted chunks into utterances."
./long_file_recog/merge_chunks.py \
--manifest-in-dir $output_dir/manifests_chunk_recog \
--manifest-out-dir $output_dir/manifests \
--bpe-model data/lang_bpe_500/bpe.model \
--extra $extra
fi

View File

@ -0,0 +1,189 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Union
import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
from lhotse.cut import Cut
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
CutConcatenate,
CutMix,
DynamicBucketingSampler,
K2SpeechRecognitionDataset,
PrecomputedFeatures,
SimpleCutSampler,
SpecAugment,
)
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
AudioSamples,
BatchIO,
OnTheFlyFeatures,
)
from torch.utils.data import DataLoader
from icefall.utils import str2bool
class SpeechRecognitionDataset(K2SpeechRecognitionDataset):
def __init__(
self,
return_cuts: bool = False,
input_strategy: BatchIO = PrecomputedFeatures(),
):
super().__init__(return_cuts=return_cuts, input_strategy=input_strategy)
def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[Cut]]]:
"""
Return a new batch, with the batch size automatically determined using the constraints
of max_frames and max_cuts.
"""
self.hdf5_fix.update()
# Note: don't sort cuts here
# Sort the cuts by duration so that the first one determines the batch time dimensions.
# cuts = cuts.sort_by_duration(ascending=False)
# Get a tensor with batched feature matrices, shape (B, T, F)
# Collation performs auto-padding, if necessary.
input_tpl = self.input_strategy(cuts)
if len(input_tpl) == 3:
# An input strategy with fault tolerant audio reading mode.
# "cuts" may be a subset of the original "cuts" variable,
# that only has cuts for which we successfully read the audio.
inputs, _, cuts = input_tpl
else:
inputs, _ = input_tpl
# Get a dict of tensors that encode the positional information about supervisions
# in the batch of feature matrices. The tensors are named "sequence_idx",
# "start_frame/sample" and "num_frames/samples".
supervision_intervals = self.input_strategy.supervision_intervals(cuts)
batch = {"inputs": inputs, "supervisions": supervision_intervals}
if self.return_cuts:
batch["supervisions"]["cut"] = [cut for cut in cuts]
return batch
class AsrDataModule:
"""
DataModule for k2 ASR experiments.
It assumes there is always one train and valid dataloader,
but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
and test-other).
It contains all the common data pipeline modules used in ASR
experiments, e.g.:
- dynamic batch size,
- bucketing samplers,
- cut concatenation,
- augmentation,
- on-the-fly feature extraction
This class should be derived for specific corpora used in ASR tasks.
"""
def __init__(self, args: argparse.Namespace):
self.args = args
@classmethod
def add_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(
title="ASR data related options",
description="These options are used for the preparation of "
"PyTorch DataLoaders from Lhotse CutSet's -- they control the "
"effective batch sizes, sampling strategies, applied data "
"augmentations, etc.",
)
group.add_argument(
"--manifest-dir",
type=Path,
default=Path("data/manifests_chunk"),
help="Path to directory with train/valid/test cuts.",
)
group.add_argument(
"--max-duration",
type=int,
default=600.0,
help="Maximum pooled recordings duration (seconds) in a "
"single batch. You can reduce it if it causes CUDA OOM.",
)
group.add_argument(
"--bucketing-sampler",
type=str2bool,
default=True,
help="When enabled, the batches will come from buckets of "
"similar duration (saves padding frames).",
)
group.add_argument(
"--return-cuts",
type=str2bool,
default=True,
help="When enabled, each batch will have the "
"field: batch['supervisions']['cut'] with the cuts that "
"were used to construct it.",
)
group.add_argument(
"--num-workers",
type=int,
default=8,
help="The number of training dataloader workers that "
"collect the batches.",
)
group.add_argument(
"--input-strategy",
type=str,
default="PrecomputedFeatures",
help="AudioSamples or PrecomputedFeatures",
)
def dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = SpeechRecognitionDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
return_cuts=self.args.return_cuts,
)
sampler = SimpleCutSampler(
cuts,
max_duration=self.args.max_duration,
shuffle=False,
drop_last=False,
)
logging.debug("About to create test dataloader")
test_dl = DataLoader(
test,
batch_size=None,
sampler=sampler,
num_workers=self.args.num_workers,
persistent_workers=False,
)
return test_dl
@lru_cache()
def load_subset(self, cuts_filename: Path) -> CutSet:
return load_manifest_lazy(cuts_filename)

View File

@ -0,0 +1,613 @@
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union
import k2
import torch
from icefall.decode import one_best_decoding
from icefall.utils import DecodingResults, get_texts, get_texts_with_timestamp
def fast_beam_search(
model: torch.nn.Module,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
temperature: float = 1.0,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
temperature:
Softmax temperature.
Returns:
Return an FsaVec with axes [utt][state][arc] containing the decoded
lattice. Note: When the input graph is a TrivialGraph, the returned
lattice is actually an acceptor.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
vocab_size = model.decoder.vocab_size
B, T, C = encoder_out.shape
config = k2.RnntDecodingConfig(
vocab_size=vocab_size,
decoder_history_len=context_size,
beam=beam,
max_contexts=max_contexts,
max_states=max_states,
)
individual_streams = []
for i in range(B):
individual_streams.append(k2.RnntDecodingStream(decoding_graph))
decoding_streams = k2.RnntDecodingStreams(individual_streams, config)
encoder_out = model.joiner.encoder_proj(encoder_out)
for t in range(T):
# shape is a RaggedShape of shape (B, context)
# contexts is a Tensor of shape (shape.NumElements(), context_size)
shape, contexts = decoding_streams.get_contexts()
# `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
contexts = contexts.to(torch.int64)
# decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
decoder_out = model.decoder(contexts, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# current_encoder_out is of shape
# (shape.NumElements(), 1, joiner_dim)
# fmt: off
current_encoder_out = torch.index_select(
encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64)
)
# fmt: on
logits = model.joiner(
current_encoder_out.unsqueeze(2),
decoder_out.unsqueeze(1),
project_input=False,
)
logits = logits.squeeze(1).squeeze(1)
log_probs = (logits / temperature).log_softmax(dim=-1)
decoding_streams.advance(log_probs)
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
return lattice
def fast_beam_search_one_best(
model: torch.nn.Module,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
temperature: float = 1.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then
the shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a LG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
temperature=temperature,
)
best_path = one_best_decoding(lattice)
if not return_timestamps:
return get_texts(best_path)
else:
return get_texts_with_timestamp(best_path)
def greedy_search_batch(
model: torch.nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
return_timestamps:
Whether to return timestamps.
Returns:
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
"""
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
device = next(model.parameters()).device
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]
# timestamp[n][i] is the frame index after subsampling
# on which hyp[n][i] is decoded
timestamps = [[] for _ in range(N)]
# scores[n][i] is the log-prob with which hyp[n][i] is decoded
scores = [[] for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (N, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out: (N, 1, decoder_out_dim)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
# logits' shape: (batch_size, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size)
log_probs = logits.log_softmax(dim=-1)
assert log_probs.ndim == 2, log_probs.shape
y = log_probs.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v not in (blank_id, unk_id):
hyps[i].append(v)
timestamps[i].append(t)
scores[i].append(log_probs[i, v].item())
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
dtype=torch.int64,
)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
sorted_ans = [h[context_size:] for h in hyps]
ans = []
ans_timestamps = []
ans_scores = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(timestamps[unsorted_indices[i]])
ans_scores.append(scores[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
hyps=ans,
timestamps=ans_timestamps,
scores=ans_scores,
)
@dataclass
class Hypothesis:
# The predicted tokens so far.
# Newly predicted tokens are appended to `ys`.
ys: List[int]
# The log prob of ys.
# It contains only one entry.
log_prob: torch.Tensor
# timestamp[i] is the frame index after subsampling
# on which ys[i] is decoded
timestamp: List[int] = field(default_factory=list)
@property
def key(self) -> str:
"""Return a string representation of self.ys"""
return "_".join(map(str, self.ys))
class HypothesisList(object):
def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None:
"""
Args:
data:
A dict of Hypotheses. Its key is its `value.key`.
"""
if data is None:
self._data = {}
else:
self._data = data
@property
def data(self) -> Dict[str, Hypothesis]:
return self._data
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
`log-sum-exp` with the existing one.
Args:
hyp:
The hypothesis to be added.
"""
key = hyp.key
if key in self:
old_hyp = self._data[key] # shallow copy
torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
else:
self._data[key] = hyp
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `log_prob`.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
number of tokens in it.
Returns:
Return the hypothesis that has the largest `log_prob`.
"""
if length_norm:
return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
else:
return max(self._data.values(), key=lambda hyp: hyp.log_prob)
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
Note: It must be contained in `self`. Otherwise,
an exception is raised.
"""
key = hyp.key
assert key in self, f"{key} does not exist"
del self._data[key]
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
with `log_prob` being greater than the given `threshold`.
"""
ans = HypothesisList()
for _, hyp in self._data.items():
if hyp.log_prob > threshold:
ans.add(hyp) # shallow copy
return ans
def topk(self, k: int) -> "HypothesisList":
"""Return the top-k hypothesis."""
hyps = list(self._data.items())
hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k]
ans = HypothesisList(dict(hyps))
return ans
def __contains__(self, key: str):
return key in self._data
def __iter__(self):
return iter(self._data.values())
def __len__(self) -> int:
return len(self._data)
def __str__(self) -> str:
s = []
for key in self:
s.append(key)
return ", ".join(s)
def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
"""Return a ragged shape with axes [utt][num_hyps].
Args:
hyps:
len(hyps) == batch_size. It contains the current hypothesis for
each utterance in the batch.
Returns:
Return a ragged shape with 2 axes [utt][num_hyps]. Note that
the shape is on CPU.
"""
num_hyps = [len(h) for h in hyps]
# torch.cumsum() is inclusive sum, so we put a 0 at the beginning
# to get exclusive sum later.
num_hyps.insert(0, 0)
num_hyps = torch.tensor(num_hyps)
row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32)
ans = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=row_splits[-1].item()
)
return ans
def modified_beam_search(
model: torch.nn.Module,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
temperature: float = 1.0,
return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
beam:
Number of active paths during the beam search.
temperature:
Softmax temperature.
return_timestamps:
Whether to return timestamps.
Returns:
If return_timestamps is False, return the decoded result.
Else, return a DecodingResults object containing
decoded result and corresponding timestamps.
"""
assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
timestamp=[],
)
)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
finalized_B = []
for (t, batch_size) in enumerate(batch_size_list):
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = get_hyps_shape(B).to(device)
A = [list(b) for b in B]
B = [HypothesisList() for _ in range(batch_size)]
ys_log_probs = torch.cat(
[hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
) # (num_hyps, 1)
decoder_input = torch.tensor(
[hyp.ys[-context_size:] for hyps in A for hyp in hyps],
device=device,
dtype=torch.int64,
) # (num_hyps, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
decoder_out = model.joiner.decoder_proj(decoder_out)
# decoder_out is of shape (num_hyps, 1, 1, joiner_dim)
# Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
# as index, so we use `to(torch.int64)` below.
current_encoder_out = torch.index_select(
current_encoder_out,
dim=0,
index=hyps_shape.row_ids(1).to(torch.int64),
) # (num_hyps, 1, 1, encoder_out_dim)
logits = model.joiner(
current_encoder_out,
decoder_out,
project_input=False,
) # (num_hyps, 1, 1, vocab_size)
logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size)
log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size)
log_probs.add_(ys_log_probs)
vocab_size = log_probs.size(-1)
log_probs = log_probs.reshape(-1)
row_splits = hyps_shape.row_splits(1) * vocab_size
log_probs_shape = k2.ragged.create_ragged_shape2(
row_splits=row_splits, cached_tot_size=log_probs.numel()
)
ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k]
hyp = A[i][hyp_idx]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
new_timestamp = hyp.timestamp[:]
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_timestamp.append(t)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(
ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
)
B[i].add(new_hyp)
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
sorted_timestamps = [h.timestamp for h in best_hyps]
ans = []
ans_timestamps = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])
if not return_timestamps:
return ans
else:
return DecodingResults(
hyps=ans,
timestamps=ans_timestamps,
)

View File

@ -0,0 +1,240 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file merges overlapped chunks into utterances according to recording ids.
"""
import argparse
import logging
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from typing import List
import sentencepiece as spm
from lhotse import (
CutSet,
MonoCut,
SupervisionSegment,
SupervisionSet,
load_manifest,
load_manifest_lazy,
)
from lhotse.cut import Cut
from lhotse.serialization import SequentialJsonlWriter
def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--manifest-in-dir",
type=Path,
default=Path("data/librilight/manifests_chunk_recog"),
help="Path to directory of chunk cuts with recognition results.",
)
parser.add_argument(
"--manifest-out-dir",
type=Path,
default=Path("data/manifests"),
help="Path to directory to save full utterance by merging overlapped chunks.",
)
parser.add_argument(
"--extra",
type=float,
default=2.0,
help="""Extra duration (in seconds) at both sides.""",
)
return parser.parse_args()
def merge_chunks(
cuts_chunk: CutSet,
supervisions: SupervisionSet,
cuts_writer: SequentialJsonlWriter,
sp: spm.SentencePieceProcessor,
extra: float,
) -> int:
"""Merge chunk-wise cuts accroding to recording ids.
Args:
cuts_chunk:
The chunk-wise cuts opened in a lazy mode.
supervisions:
The supervision manifest containing text file path, opened in a lazy mode.
cuts_writer:
Writer to save the cuts with recognition results.
sp:
The BPE model.
extra:
Extra duration (in seconds) to drop at both sides of each chunk.
"""
# Background worker to add alignment and save cuts to disk.
def _save_worker(utt_cut: Cut, flush=False):
cuts_writer.write(utt_cut, flush=flush)
def _merge(cut_list: List[Cut], rec_id: str, utt_idx: int):
"""Merge chunks with same recording_id."""
for cut in cut_list:
assert cut.recording.id == rec_id, (cut.recording.id, rec_id)
# For each group with the same recording, sort it according to the start time
# In fact, we don't need to do this since the cuts have been sorted
# according to the start time
cut_list = sorted(cut_list, key=(lambda cut: cut.start))
rec = cut_list[0].recording
alignments = []
cur_end = 0
for cut in cut_list:
# Get left and right borders
left = cut.start + extra if cut.start > 0 else 0
chunk_end = cut.start + cut.duration
right = chunk_end - extra if chunk_end < rec.duration else rec.duration
# Assert the chunks are continuous
assert left == cur_end, (left, cur_end)
cur_end = right
assert len(cut.supervisions) == 1, len(cut.supervisions)
for ali in cut.supervisions[0].alignment["symbol"]:
t = ali.start + cut.start
if left <= t < right:
alignments.append(ali.with_offset(cut.start))
old_sup = supervisions[rec_id]
# Assuming the supervisions are sorted in the same recording order as in cuts_chunk
# old_sup = supervisions[utt_idx]
assert old_sup.recording_id == rec_id, (old_sup.recording_id, rec_id)
new_sup = SupervisionSegment(
id=rec_id,
recording_id=rec_id,
start=0,
duration=rec.duration,
alignment={"symbol": alignments},
language=old_sup.language,
speaker=old_sup.speaker,
)
utt_cut = MonoCut(
id=rec_id,
start=0,
duration=rec.duration,
channel=0,
recording=rec,
supervisions=[new_sup],
)
# Set a custom attribute to the cut
utt_cut.text_path = old_sup.book
return utt_cut
last_rec_id = None
cut_list = []
utt_idx = 0
futures = []
with ThreadPoolExecutor(max_workers=1) as executor:
for cut in cuts_chunk:
cur_rec_id = cut.recording.id
if len(cut_list) == 0:
# Case of the first cut
last_rec_id = cur_rec_id
cut_list.append(cut)
elif cur_rec_id == last_rec_id:
cut_list.append(cut)
else:
# Case of a cut belonging to a new recording
utt_cut = _merge(cut_list, last_rec_id, utt_idx)
utt_idx += 1
futures.append(executor.submit(_save_worker, utt_cut))
last_rec_id = cur_rec_id
cut_list = [cut]
if utt_idx % 5000 == 0:
logging.info(f"Procesed {utt_idx} utterances.")
# For the cuts belonging to the last recording
if len(cut_list) != 0:
utt_cut = _merge(cut_list, last_rec_id, utt_idx)
utt_idx += 1
futures.append(executor.submit(_save_worker, utt_cut))
logging.info("Finished")
for f in futures:
f.result()
return utt_idx
def main():
args = get_parser()
sp = spm.SentencePieceProcessor()
sp.load(args.bpe_model)
# It contains "librilight_recordings_*.jsonl.gz" and "librilight_supervisions_small.jsonl.gz"
manifest_out_dir = args.manifest_out_dir
subsets = ["small", "median", "large"]
for subset in subsets:
logging.info(f"Processing {subset} subset")
manifest_out = manifest_out_dir / f"librilight_cuts_{subset}.jsonl.gz"
if manifest_out.is_file():
logging.info(f"{manifest_out} already exists - skipping.")
continue
supervisions = load_manifest(
manifest_out_dir / f"librilight_supervisions_{subset}.jsonl.gz"
) # We will use the text path from supervisions
cuts_chunk = load_manifest_lazy(
args.manifest_in_dir / f"librilight_cuts_{subset}.jsonl.gz"
)
cuts_writer = CutSet.open_writer(manifest_out, overwrite=True)
num_utt = merge_chunks(
cuts_chunk, supervisions, cuts_writer=cuts_writer, sp=sp, extra=args.extra
)
cuts_writer.close()
logging.info(f"{num_utt} cuts saved to {manifest_out}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -0,0 +1,435 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script loads torchscript models, exported by `torch.jit.script()`,
and uses them to decode waves.
You can use the following command to get the exported models:
./pruned_transducer_stateless7/export.py \
--exp-dir ./pruned_transducer_stateless7/exp \
--bpe-model data/lang_bpe_500/bpe.model \
--epoch 20 \
--avg 10 \
--jit 1
You can also download the jit model from
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
"""
import argparse
import torch.multiprocessing as mp
import torch
import torch.nn as nn
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import List, Optional, Tuple
from pathlib import Path
import k2
import sentencepiece as spm
from asr_datamodule import AsrDataModule
from beam_search import (
fast_beam_search_one_best,
greedy_search_batch,
modified_beam_search,
)
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
from lhotse import CutSet, load_manifest_lazy
from lhotse.cut import Cut
from lhotse.supervision import AlignmentItem
from lhotse.serialization import SequentialJsonlWriter
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--world-size",
type=int,
default=1,
help="Number of GPUs for DDP training.",
)
parser.add_argument(
"--master-port",
type=int,
default=12354,
help="Master port to use for DDP training.",
)
parser.add_argument(
"--subset",
type=str,
default="small",
help="Subset to process. Possible values are 'small', 'medium', 'large'",
)
parser.add_argument(
"--manifest-in-dir",
type=Path,
default=Path("data/librilight/manifests_chunk"),
help="Path to directory with chunks cuts.",
)
parser.add_argument(
"--manifest-out-dir",
type=Path,
default=Path("data/librilight/manifests_chunk_recog"),
help="Path to directory to save the chunk cuts with recognition results.",
)
parser.add_argument(
"--log-dir",
type=Path,
default=Path("long_file_recog/log"),
help="Path to directory to save logs.",
)
parser.add_argument(
"--nn-model-filename",
type=str,
required=True,
help="Path to the torchscript model cpu_jit.pt",
)
parser.add_argument(
"--bpe-model",
type=str,
default="data/lang_bpe_500/bpe.model",
help="Path to the BPE model",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
return parser
def get_params() -> AttributeDict:
"""Return a dict containing decoding parameters."""
params = AttributeDict(
{
"subsampling_factor": 4,
"frame_shift_ms": 10,
# Used only when --method is beam_search or modified_beam_search.
"beam_size": 4,
# Used only when --method is beam_search or fast_beam_search.
# A floating point value to calculate the cutoff score during beam
# search (i.e., `cutoff = max-score - beam`), which is the same as the
# `beam` in Kaldi.
"beam": 4,
"max_contexts": 4,
"max_states": 8,
}
)
return params
def decode_one_batch(
params: AttributeDict,
model: nn.Module,
batch: dict,
decoding_graph: Optional[k2.Fsa] = None,
) -> Tuple[List[List[str]], List[List[float]], List[List[float]]]:
"""Decode one batch.
Args:
params:
It's the return value of :func:`get_params`.
model:
The neural model.
batch:
It is the return value from iterating
`lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
for the format of the `batch`.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an LG. Used
only when --decoding-method is fast_beam_search.
Returns:
Return the decoding result, timestamps, and scores.
"""
device = next(model.parameters()).device
feature = batch["inputs"]
assert feature.ndim == 3
feature = feature.to(device)
# at entry, feature is (N, T, C)
supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device)
encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)
if params.decoding_method == "fast_beam_search":
res = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
return_timestamps=True,
)
elif params.decoding_method == "greedy_search":
res = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
return_timestamps=True,
)
elif params.decoding_method == "modified_beam_search":
res = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
return_timestamps=True,
)
else:
raise ValueError(f"Unsupported decoding method: {params.decoding_method}")
hyps = []
timestamps = []
scores = []
for i in range(feature.shape[0]):
hyps.append(res.hyps[i])
timestamps.append(
convert_timestamp(
res.timestamps[i], params.subsampling_factor, params.frame_shift_ms
)
)
scores.append(res.scores[i])
return hyps, timestamps, scores
def decode_dataset(
dl: torch.utils.data.DataLoader,
params: AttributeDict,
model: nn.Module,
sp: spm.SentencePieceProcessor,
cuts_writer: SequentialJsonlWriter,
decoding_graph: Optional[k2.Fsa] = None,
) -> None:
"""Decode dataset and store the recognition results to manifest.
Args:
dl:
PyTorch's dataloader containing the dataset to decode.
params:
It is returned by :func:`get_params`.
model:
The neural model.
sp:
The BPE model.
cuts_writer:
Writer to save the cuts with recognition results.
decoding_graph:
The decoding graph. Can be either a `k2.trivial_graph` or an LG. Used
only when --decoding-method is fast_beam_search.
Returns:
Return nothing. The recognition results (tokens, timestamps, and scores)
are written to the output manifest via the given cuts_writer.
"""
# Background worker to add alignment and save cuts to disk.
def _save_worker(
cuts: List[Cut],
hyps: List[List[str]],
timestamps: List[List[float]],
scores: List[List[float]],
):
for cut, symbol_list, time_list, score_list in zip(
cuts, hyps, timestamps, scores
):
symbol_list = sp.id_to_piece(symbol_list)
ali = [
AlignmentItem(symbol=symbol, start=start, duration=None, score=score)
for symbol, start, score in zip(symbol_list, time_list, score_list)
]
assert len(cut.supervisions) == 1, len(cut.supervisions)
cut.supervisions[0].alignment = {"symbol": ali}
cuts_writer.write(cut, flush=True)
num_cuts = 0
log_interval = 10
futures = []
with ThreadPoolExecutor(max_workers=1) as executor:
# We only want one background worker so that serialization is deterministic.
for batch_idx, batch in enumerate(dl):
cuts = batch["supervisions"]["cut"]
hyps, timestamps, scores = decode_one_batch(
params=params,
model=model,
decoding_graph=decoding_graph,
batch=batch,
)
futures.append(
executor.submit(_save_worker, cuts, hyps, timestamps, scores)
)
num_cuts += len(cuts)
if batch_idx % log_interval == 0:
logging.info(f"cuts processed until now is {num_cuts}")
for f in futures:
f.result()
@torch.no_grad()
def run(rank, world_size, args, in_cuts):
"""
Args:
rank:
It is a value between 0 and `world_size-1`.
world_size:
Number of GPUs used for parallel decoding.
args:
The return value of get_parser().parse_args()
"""
params = get_params()
params.update(vars(args))
setup_logger(f"{params.log_dir}/log-decode")
logging.info("Decoding started")
assert params.decoding_method in (
"greedy_search",
"fast_beam_search",
"modified_beam_search",
), params.decoding_method
sp = spm.SentencePieceProcessor()
sp.load(params.bpe_model)
# <blk> is defined in local/train_bpe_model.py
params.blank_id = sp.piece_to_id("<blk>")
params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size()
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", rank)
logging.info(f"device: {device}")
logging.info("Loading jit model")
model = torch.jit.load(params.nn_model_filename)
model.to(device)
model.eval()
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
# we will store new cuts with recognition results.
args.return_cuts = True
asr_data_module = AsrDataModule(args)
if world_size > 1:
in_cuts = in_cuts[rank]
out_cuts_filename = params.manifest_out_dir / (
f"{params.cuts_filename}_job_{rank}" + params.suffix
)
else:
out_cuts_filename = params.manifest_out_dir / (
f"{params.cuts_filename}" + params.suffix
)
dl = asr_data_module.dataloaders(in_cuts)
cuts_writer = CutSet.open_writer(out_cuts_filename, overwrite=True)
decode_dataset(
dl=dl,
params=params,
model=model,
sp=sp,
decoding_graph=decoding_graph,
cuts_writer=cuts_writer,
)
cuts_writer.close()
logging.info(f"Cuts saved to {out_cuts_filename}")
logging.info("Done!")
def main():
parser = get_parser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args()
subset = args.subset
assert subset in ["small", "medium", "large"], subset
manifest_out_dir = args.manifest_out_dir
manifest_out_dir.mkdir(parents=True, exist_ok=True)
args.suffix = ".jsonl.gz"
args.cuts_filename = f"librilight_cuts_{args.subset}"
out_cuts_filename = manifest_out_dir / (args.cuts_filename + args.suffix)
if out_cuts_filename.is_file():
logging.info(f"{out_cuts_filename} already exists - skipping.")
return
in_cuts_filename = args.manifest_in_dir / (args.cuts_filename + args.suffix)
in_cuts = load_manifest_lazy(in_cuts_filename)
world_size = args.world_size
assert world_size >= 1
if world_size > 1:
chunk_size = (len(in_cuts) + (world_size - 1)) // world_size
# Each manifest is saved at: ``{output_dir}/{prefix}.{split_idx}.jsonl.gz``
splits = in_cuts.split_lazy(
output_dir=args.manifest_in_dir / "split",
chunk_size=chunk_size,
prefix=args.cuts_filename,
)
assert len(splits) == world_size, (len(splits), world_size)
mp.spawn(run, args=(world_size, args, splits), nprocs=world_size, join=True)
else:
run(rank=0, world_size=world_size, args=args, in_cuts=in_cuts)
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
if __name__ == "__main__":
main()

View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script splits long utterances into chunks with overlaps.
Each chunk (except the first and the last) is padded with extra context on its left and right sides.
The chunk length is: left_side + chunk_size + right_side.
"""
import argparse
import logging
from pathlib import Path
from lhotse import CutSet, load_manifest
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--manifest-in-dir",
type=Path,
default=Path("data/librilight/manifests"),
help="Path to directory of full utterances.",
)
parser.add_argument(
"--manifest-out-dir",
type=Path,
default=Path("data/librilight/manifests_chunk"),
help="Path to directory to save splitted chunks.",
)
parser.add_argument(
"--chunk",
type=float,
default=300.0,
help="""Duration (in seconds) of each chunk.""",
)
parser.add_argument(
"--extra",
type=float,
default=2.0,
help="""Extra duration (in seconds) at both sides.""",
)
return parser.parse_args()
def main():
args = get_args()
logging.info(vars(args))
manifest_out_dir = args.manifest_out_dir
manifest_out_dir.mkdir(parents=True, exist_ok=True)
subsets = ["small", "medium", "large"]
for subset in subsets:
logging.info(f"Processing {subset} subset")
manifest_out = manifest_out_dir / f"librilight_cuts_{subset}.jsonl.gz"
if manifest_out.is_file():
logging.info(f"{manifest_out} already exists - skipping.")
continue
manifest_in = args.manifest_in_dir / f"librilight_recordings_{subset}.jsonl.gz"
recordings = load_manifest(manifest_in)
cuts = CutSet.from_manifests(recordings=recordings)
cuts = cuts.cut_into_windows(
duration=args.chunk, hop=args.chunk - args.extra * 2
)
cuts = cuts.fill_supervisions()
cuts.to_file(manifest_out)
logging.info(f"Cuts saved to {manifest_out}")
if __name__ == "__main__":
formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -670,6 +670,8 @@ def greedy_search_batch(
# timestamp[n][i] is the frame index after subsampling
# on which hyp[n][i] is decoded
timestamps = [[] for _ in range(N)]
# scores[n][i] is the logit value with which hyp[n][i] is decoded
scores = [[] for _ in range(N)]
decoder_input = torch.tensor(
hyps,
@ -707,6 +709,7 @@ def greedy_search_batch(
if v not in (blank_id, unk_id):
hyps[i].append(v)
timestamps[i].append(t)
scores[i].append(logits[i, v].item())
emitted = True
if emitted:
# update decoder output
@ -722,10 +725,12 @@ def greedy_search_batch(
sorted_ans = [h[context_size:] for h in hyps]
ans = []
ans_timestamps = []
ans_scores = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
ans_timestamps.append(timestamps[unsorted_indices[i]])
ans_scores.append(scores[unsorted_indices[i]])
if not return_timestamps:
return ans
@ -733,6 +738,7 @@ def greedy_search_batch(
return DecodingResults(
hyps=ans,
timestamps=ans_timestamps,
scores=ans_scores,
)

View File

@ -272,6 +272,9 @@ class DecodingResults:
# for the i-th utterance with fast_beam_search_nbest_LG.
hyps: Union[List[List[int]], k2.RaggedTensor]
# scores[i][k] contains the log-prob of tokens[i][k]
scores: Optional[List[List[float]]] = None
def get_texts_with_timestamp(
best_paths: k2.Fsa, return_ragged: bool = False
@ -1442,7 +1445,7 @@ def convert_timestamp(
frame_shift = frame_shift_ms / 1000.0
time = []
for f in frames:
time.append(f * subsampling_factor * frame_shift)
time.append(round(f * subsampling_factor * frame_shift, ndigits=3))
return time