mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Support long audios recognition (#980)
* support long file transcription * rename recipe as long_file_recog * add docs * support multi-gpu decoding * style fix
This commit is contained in:
parent
f18b539fbc
commit
a7e142b7ff
94
egs/librispeech/ASR/long_file_recog.sh
Executable file
94
egs/librispeech/ASR/long_file_recog.sh
Executable file
@ -0,0 +1,94 @@
|
|||||||
|
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

export CUDA_VISIBLE_DEVICES="0,1,2,3"

set -eou pipefail

# This script is used to recognize long audios. The process is as follows:
# 1) Split long audios into chunks with overlaps.
# 2) Perform speech recognition on chunks, getting tokens and timestamps.
# 3) Merge the overlapped chunks into utterances according to the timestamps.

# Each chunk (except the first and the last) is padded with extra left side and right side.
# The chunk length is: left_side + chunk_size + right_side.
chunk=30.0
extra=2.0

stage=1
stop_stage=4

# We assume that you have downloaded the LibriLight dataset
# with audio files in $corpus_dir and texts in $text_dir
corpus_dir=$PWD/download/libri-light
text_dir=$PWD/download/librilight_text
# Path to save the manifests
output_dir=$PWD/data/librilight

# Passed to recognize.py as --world-size; when it is greater than 1 the
# per-job manifests are combined after decoding.
world_size=4


log() {
  # This function is from espnet
  # Prints "<timestamp> (<calling file>:<line>:<calling function>) <message>".
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  # We will get librilight_recordings_{subset}.jsonl.gz and librilight_supervisions_{subset}.jsonl.gz
  # saved in $output_dir/manifests
  log "Stage 1: Prepare LibriLight manifest"
  lhotse prepare librilight "$corpus_dir" "$text_dir" "$output_dir/manifests" -j 10
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  # Chunk manifests are saved to $output_dir/manifests_chunk/librilight_cuts_{subset}.jsonl.gz
  log "Stage 2: Split long audio into chunks"
  ./long_file_recog/split_into_chunks.py \
    --manifest-in-dir "$output_dir/manifests" \
    --manifest-out-dir "$output_dir/manifests_chunk" \
    --chunk $chunk \
    --extra $extra  # Extra duration (in seconds) at both sides
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  # Recognized tokens and timestamps are saved to $output_dir/manifests_chunk_recog/librilight_cuts_{subset}.jsonl.gz

  # This script loads torchscript models, exported by `torch.jit.script()`,
  # and uses it to decode waves.
  # You can download the jit model from https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11

  log "Stage 3: Perform speech recognition on splitted chunks"
  # NOTE(review): LibriLight subsets are usually named small/medium/large;
  # confirm that "median" matches the manifest file names produced in stage 1.
  for subset in small median large; do
    # Every option line except the last must end with a backslash; otherwise
    # the remaining options are executed as a separate command.
    ./long_file_recog/recognize.py \
      --world-size $world_size \
      --num-workers 8 \
      --subset $subset \
      --manifest-in-dir "$output_dir/manifests_chunk" \
      --manifest-out-dir "$output_dir/manifests_chunk_recog" \
      --nn-model-filename long_file_recog/exp/jit_model.pt \
      --bpe-model data/lang_bpe_500/bpe.model \
      --max-duration 2400 \
      --decoding-method greedy_search \
      --master 12345

    if [ $world_size -gt 1 ]; then
      # Combine manifests from different jobs
      # The -name pattern is quoted so that find, not the shell, expands it.
      # The command substitution is left unquoted on purpose: each file found
      # must become a separate argument of `lhotse combine`.
      lhotse combine $(find "$output_dir/manifests_chunk_recog" -name "librilight_cuts_${subset}_job_*.jsonl.gz" | tr "\n" " ") "$output_dir/manifests_chunk_recog/librilight_cuts_${subset}.jsonl.gz"
    fi
  done
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  # Final results are saved in $output_dir/manifests/librilight_cuts_{subset}.jsonl.gz
  log "Stage 4: Merge splitted chunks into utterances."
  ./long_file_recog/merge_chunks.py \
    --manifest-in-dir "$output_dir/manifests_chunk_recog" \
    --manifest-out-dir "$output_dir/manifests" \
    --bpe-model data/lang_bpe_500/bpe.model \
    --extra $extra
fi
|
||||||
|
|
||||||
|
|
189
egs/librispeech/ASR/long_file_recog/asr_datamodule.py
Normal file
189
egs/librispeech/ASR/long_file_recog/asr_datamodule.py
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
# Copyright 2021 Piotr Żelasko
|
||||||
|
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, List, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
|
||||||
|
from lhotse.cut import Cut
|
||||||
|
from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures
|
||||||
|
CutConcatenate,
|
||||||
|
CutMix,
|
||||||
|
DynamicBucketingSampler,
|
||||||
|
K2SpeechRecognitionDataset,
|
||||||
|
PrecomputedFeatures,
|
||||||
|
SimpleCutSampler,
|
||||||
|
SpecAugment,
|
||||||
|
)
|
||||||
|
from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples
|
||||||
|
AudioSamples,
|
||||||
|
BatchIO,
|
||||||
|
OnTheFlyFeatures,
|
||||||
|
)
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from icefall.utils import str2bool
|
||||||
|
|
||||||
|
|
||||||
|
class SpeechRecognitionDataset(K2SpeechRecognitionDataset):
    """A ``K2SpeechRecognitionDataset`` whose batches keep the incoming cut order.

    Only ``return_cuts`` and ``input_strategy`` are forwarded to the parent
    constructor.
    """

    def __init__(
        self,
        return_cuts: bool = False,
        input_strategy: BatchIO = PrecomputedFeatures(),
    ):
        super().__init__(return_cuts=return_cuts, input_strategy=input_strategy)

    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[Cut]]]:
        """
        Collate ``cuts`` into one batch.

        Returns:
          A dict with two entries:
            - "inputs": the batched inputs produced by ``self.input_strategy``;
            - "supervisions": the supervision intervals of the cuts, plus a
              "cut" entry listing the cuts when ``self.return_cuts`` is set.
        """
        self.hdf5_fix.update()

        # The cuts are intentionally left in the order they were given:
        # no sorting by duration is applied here.

        # The input strategy collates (and pads, if necessary) the inputs.
        collated = self.input_strategy(cuts)
        if len(collated) == 3:
            # Fault tolerant audio reading: the strategy also returns the cuts
            # it managed to load, which may be a subset of the ones requested.
            inputs, _lens, cuts = collated
        else:
            inputs, _lens = collated

        # Positional information about the supervisions inside the batch
        # ("sequence_idx", "start_frame/sample" and "num_frames/samples").
        supervisions = self.input_strategy.supervision_intervals(cuts)
        if self.return_cuts:
            supervisions["cut"] = list(cuts)

        return {"inputs": inputs, "supervisions": supervisions}
|
||||||
|
|
||||||
|
|
||||||
|
class AsrDataModule:
    """
    Data module used for recognizing chunks of long audio files.

    It registers the data related command line options, builds a single
    dataloader over a given ``CutSet`` (features are computed on the fly
    as 80-bin fbank, cuts are neither shuffled nor dropped) and loads cut
    manifests lazily.

    Note: some of the registered options (``--manifest-dir``,
    ``--bucketing-sampler`` and ``--input-strategy``) are not read by the
    methods of this class.
    """

    def __init__(self, args: argparse.Namespace):
        """
        Args:
          args:
            The parsed command line arguments; ``dataloaders`` reads
            ``return_cuts``, ``max_duration`` and ``num_workers`` from it.
        """
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Add the data related options to ``parser`` as a separate group."""
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/manifests_chunk"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            # A duration in seconds. argparse does not pass a non-string
            # default through `type`, so with `type=int` the option was a
            # float by default but an int when given on the command line.
            type=float,
            default=600.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=8,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

    def dataloaders(self, cuts: CutSet) -> DataLoader:
        """
        Build the dataloader used for decoding ``cuts``.

        Args:
          cuts:
            The cuts to iterate over.
        Returns:
          Return a DataLoader yielding batches limited by
          ``args.max_duration``; the sampler does not shuffle and keeps
          the last (possibly smaller) batch.
        """
        logging.debug("About to create test dataset")
        test = SpeechRecognitionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
            return_cuts=self.args.return_cuts,
        )

        sampler = SimpleCutSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
            drop_last=False,
        )

        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            # The sampler already yields whole batches of cuts.
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
            persistent_workers=False,
        )
        return test_dl

    # NOTE(review): lru_cache on a method keys on `self` and keeps the
    # instance alive for the lifetime of the cache.
    @lru_cache()
    def load_subset(self, cuts_filename: Path) -> CutSet:
        """Lazily load the cut manifest stored in ``cuts_filename``."""
        return load_manifest_lazy(cuts_filename)
|
613
egs/librispeech/ASR/long_file_recog/beam_search.py
Normal file
613
egs/librispeech/ASR/long_file_recog/beam_search.py
Normal file
@ -0,0 +1,613 @@
|
|||||||
|
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
|
||||||
|
# Xiaoyu Yang)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from icefall.decode import one_best_decoding
|
||||||
|
from icefall.utils import DecodingResults, get_texts, get_texts_with_timestamp
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search(
    model: torch.nn.Module,
    decoding_graph: k2.Fsa,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: float,
    max_states: int,
    max_contexts: int,
    temperature: float = 1.0,
) -> k2.Fsa:
    """Run fast beam search, emitting at most one symbol per frame.

    Args:
      model:
        An instance of `Transducer`.
      decoding_graph:
        Decoding graph used for decoding, may be a TrivialGraph or a LG.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      encoder_out_lens:
        A tensor of shape (N,) containing the number of frames in `encoder_out`
        before padding.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
      temperature:
        Softmax temperature.
    Returns:
      Return an FsaVec with axes [utt][state][arc] containing the decoded
      lattice. Note: When the input graph is a TrivialGraph, the returned
      lattice is actually an acceptor.
    """
    assert encoder_out.ndim == 3

    rnnt_config = k2.RnntDecodingConfig(
        vocab_size=model.decoder.vocab_size,
        decoder_history_len=model.decoder.context_size,
        beam=beam,
        max_contexts=max_contexts,
        max_states=max_states,
    )

    num_utts, num_frames, _ = encoder_out.shape

    # One decoding stream per utterance, all sharing the same graph.
    streams = k2.RnntDecodingStreams(
        [k2.RnntDecodingStream(decoding_graph) for _ in range(num_utts)],
        rnnt_config,
    )

    # Project the encoder output once, outside of the frame loop.
    projected = model.joiner.encoder_proj(encoder_out)

    for t in range(num_frames):
        # shape is a RaggedShape of shape (B, context)
        # contexts is a Tensor of shape (shape.NumElements(), context_size)
        shape, contexts = streams.get_contexts()
        # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64
        contexts = contexts.to(torch.int64)

        # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim)
        decoder_out = model.joiner.decoder_proj(
            model.decoder(contexts, need_pad=False)
        )

        # Repeat frame t of each utterance once per active context of that
        # utterance: (shape.NumElements(), 1, joiner_dim)
        utt_of_context = shape.row_ids(1).to(torch.int64)
        frame_t = torch.index_select(
            projected[:, t : t + 1, :], 0, utt_of_context
        )

        logits = model.joiner(
            frame_t.unsqueeze(2),
            decoder_out.unsqueeze(1),
            project_input=False,
        )
        logits = logits.squeeze(1).squeeze(1)
        streams.advance((logits / temperature).log_softmax(dim=-1))

    streams.terminate_and_flush_to_streams()
    return streams.format_output(encoder_out_lens.tolist())
|
||||||
|
|
||||||
|
|
||||||
|
def fast_beam_search_one_best(
    model: torch.nn.Module,
    decoding_graph: k2.Fsa,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: float,
    max_states: int,
    max_contexts: int,
    temperature: float = 1.0,
    return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
    """Decode with fast beam search and keep the single best path.

    The lattice is produced by `fast_beam_search` (at most one symbol per
    frame) and reduced with `one_best_decoding`; the result is read off
    that best path.

    Args:
      model:
        An instance of `Transducer`.
      decoding_graph:
        Decoding graph used for decoding, may be a TrivialGraph or a LG.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder.
      encoder_out_lens:
        A tensor of shape (N,) containing the number of frames in `encoder_out`
        before padding.
      beam:
        Beam value, similar to the beam used in Kaldi.
      max_states:
        Max states per stream per frame.
      max_contexts:
        Max contexts per stream per frame.
      temperature:
        Softmax temperature.
      return_timestamps:
        Whether to return timestamps.
    Returns:
      If return_timestamps is False, return the decoded result.
      Else, return a DecodingResults object containing
      decoded result and corresponding timestamps.
    """
    best_path = one_best_decoding(
        fast_beam_search(
            model=model,
            decoding_graph=decoding_graph,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            beam=beam,
            max_states=max_states,
            max_contexts=max_contexts,
            temperature=temperature,
        )
    )

    if return_timestamps:
        return get_texts_with_timestamp(best_path)
    return get_texts(best_path)
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_search_batch(
    model: torch.nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
    """Batched greedy search with --max-sym-per-frame=1 hardcoded.

    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C), where N >= 1.
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing number of valid frames in
        encoder_out before padding.
      return_timestamps:
        Whether to return timestamps.
    Returns:
      If return_timestamps is False, return the decoded result.
      Else, return a DecodingResults object containing
      decoded result and corresponding timestamps.
    """
    assert encoder_out.ndim == 3
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    # Packing orders the utterances by decreasing length, so that at frame t
    # only the first batch_sizes[t] utterances are still active.
    packed = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    device = next(model.parameters()).device

    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size

    active_per_frame = packed.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == active_per_frame[0], (N, active_per_frame)

    # Every hypothesis starts with a left context of context_size entries.
    hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)]

    # timestamps[n][i] is the frame index after subsampling
    # on which hyps[n][i] is decoded
    timestamps = [[] for _ in range(N)]
    # scores[n][i] is the log-prob with which hyps[n][i] is decoded
    scores = [[] for _ in range(N)]

    # (N, context_size)
    decoder_input = torch.tensor(hyps, device=device, dtype=torch.int64)

    # decoder_out: (N, 1, decoder_out_dim)
    decoder_out = model.joiner.decoder_proj(
        model.decoder(decoder_input, need_pad=False)
    )

    projected = model.joiner.encoder_proj(packed.data)

    ignored_ids = (blank_id, unk_id)
    start = 0
    for t, batch_size in enumerate(active_per_frame):
        end = start + batch_size
        # (batch_size, 1, 1, encoder_out_dim)
        frame_t = projected.data[start:end].unsqueeze(1).unsqueeze(1)
        start = end

        # Drop the decoder states of utterances that have already ended.
        decoder_out = decoder_out[:batch_size]

        # (batch_size, 1, 1, vocab_size)
        logits = model.joiner(
            frame_t, decoder_out.unsqueeze(1), project_input=False
        )
        logits = logits.squeeze(1).squeeze(1)  # (batch_size, vocab_size)
        log_probs = logits.log_softmax(dim=-1)
        assert log_probs.ndim == 2, log_probs.shape

        emitted = False
        for i, token in enumerate(log_probs.argmax(dim=1).tolist()):
            if token in ignored_ids:
                continue
            hyps[i].append(token)
            timestamps[i].append(t)
            scores[i].append(log_probs[i, token].item())
            emitted = True

        if emitted:
            # At least one hypothesis grew: recompute the decoder output
            # for the utterances that are still active.
            decoder_input = torch.tensor(
                [h[-context_size:] for h in hyps[:batch_size]],
                device=device,
                dtype=torch.int64,
            )
            decoder_out = model.joiner.decoder_proj(
                model.decoder(decoder_input, need_pad=False)
            )

    # Undo the length-based ordering introduced by the packing.
    original_order = packed.unsorted_indices.tolist()
    ans = [hyps[j][context_size:] for j in original_order]
    ans_timestamps = [timestamps[j] for j in original_order]
    ans_scores = [scores[j] for j in original_order]

    if return_timestamps:
        return DecodingResults(
            hyps=ans,
            timestamps=ans_timestamps,
            scores=ans_scores,
        )
    return ans
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Hypothesis:
    """A partial decoding result together with its log-probability."""

    # Token ids predicted so far; each new prediction is appended at the end.
    ys: List[int]

    # Log-probability of `ys`, stored as a tensor with a single entry.
    log_prob: torch.Tensor

    # timestamp[i] is the frame index (after subsampling) on which ys[i]
    # was decoded.
    timestamp: List[int] = field(default_factory=list)

    @property
    def key(self) -> str:
        """Return the tokens of `self.ys` joined by underscores."""
        return "_".join(str(token) for token in self.ys)
|
||||||
|
|
||||||
|
|
||||||
|
class HypothesisList(object):
    """A collection of hypotheses indexed by their `key`.

    Iterating over the list yields the `Hypothesis` objects, while
    `in` tests for membership of a key (a string).
    """

    def __init__(self, data: Optional[Dict[str, "Hypothesis"]] = None) -> None:
        """
        Args:
          data:
            A dict of Hypotheses. Its key is its `value.key`.
            The dict is stored as is, i.e., it is not copied.
        """
        if data is None:
            self._data = {}
        else:
            self._data = data

    @property
    def data(self) -> Dict[str, "Hypothesis"]:
        """Return the underlying dict mapping keys to hypotheses."""
        return self._data

    def add(self, hyp: "Hypothesis") -> None:
        """Add a Hypothesis to `self`.

        If `hyp` already exists in `self`, its probability is updated using
        `log-sum-exp` with the existed one.

        Args:
          hyp:
            The hypothesis to be added.
        """
        key = hyp.key
        if key in self:
            old_hyp = self._data[key]  # shallow copy
            # The result is written into the stored tensor in-place.
            torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob)
        else:
            self._data[key] = hyp

    def get_most_probable(self, length_norm: bool = False) -> "Hypothesis":
        """Get the most probable hypothesis, i.e., the one with
        the largest `log_prob`.

        Args:
          length_norm:
            If True, the `log_prob` of a hypothesis is normalized by the
            number of tokens in it.
        Returns:
          Return the hypothesis that has the largest `log_prob`.
        Raises:
          ValueError: if `self` is empty.
        """
        if length_norm:
            return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys))
        else:
            return max(self._data.values(), key=lambda hyp: hyp.log_prob)

    def remove(self, hyp: "Hypothesis") -> None:
        """Remove a given hypothesis.

        Caution:
          `self` is modified **in-place**.

        Args:
          hyp:
            The hypothesis to be removed from `self`.
            Note: It must be contained in `self`. Otherwise,
            an exception is raised.
        """
        key = hyp.key
        assert key in self, f"{key} does not exist"
        del self._data[key]

    def filter(self, threshold: torch.Tensor) -> "HypothesisList":
        """Remove all Hypotheses whose log_prob is less than threshold.

        Caution:
          `self` is not modified. Instead, a new HypothesisList is returned.

        Returns:
          Return a new HypothesisList containing all hypotheses from `self`
          with `log_prob` being greater than the given `threshold`.
        """
        ans = HypothesisList()
        for hyp in self._data.values():
            if hyp.log_prob > threshold:
                ans.add(hyp)  # shallow copy
        return ans

    def topk(self, k: int) -> "HypothesisList":
        """Return a new list holding the `k` hypotheses with the largest
        `log_prob` (all of them if there are fewer than `k`)."""
        best = sorted(
            self._data.items(), key=lambda item: item[1].log_prob, reverse=True
        )[:k]
        return HypothesisList(dict(best))

    def __contains__(self, key: str):
        return key in self._data

    def __iter__(self):
        return iter(self._data.values())

    def __len__(self) -> int:
        return len(self._data)

    def __str__(self) -> str:
        # Join the keys of the dict. Iterating over `self` would yield
        # Hypothesis objects (see __iter__), which str.join() rejects.
        return ", ".join(self._data)
|
||||||
|
|
||||||
|
|
||||||
|
def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
    """Build a ragged shape with axes [utt][num_hyps].

    Args:
      hyps:
        len(hyps) == batch_size. It contains the current hypothesis for
        each utterance in the batch.
    Returns:
      Return a ragged shape with 2 axes [utt][num_hyps]. Note that
      the shape is on CPU.
    """
    # torch.cumsum() is an inclusive sum; the leading 0 turns its output
    # into row splits, i.e., an exclusive sum followed by the total.
    counts = torch.tensor([0] + [len(h) for h in hyps])
    row_splits = torch.cumsum(counts, dim=0, dtype=torch.int32)

    return k2.ragged.create_ragged_shape2(
        row_splits=row_splits, cached_tot_size=row_splits[-1].item()
    )
|
||||||
|
|
||||||
|
|
||||||
|
def modified_beam_search(
    model: torch.nn.Module,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    beam: int = 4,
    temperature: float = 1.0,
    return_timestamps: bool = False,
) -> Union[List[List[int]], DecodingResults]:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.

    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C).
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing number of valid frames in
        encoder_out before padding. Every entry must be positive.
      beam:
        Number of active paths during the beam search.
      temperature:
        Softmax temperature. The joiner logits are divided by it before
        log_softmax is applied.
      return_timestamps:
        Whether to return timestamps.
    Returns:
      If return_timestamps is False, return the decoded result, i.e. one
      list of token IDs per utterance, in the order of the input batch.
      Else, return a DecodingResults object containing
      decoded result and corresponding timestamps.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    # Packing groups the frames by time step: batch_sizes[t] is the number of
    # utterances that still have a frame at step t. With enforce_sorted=False
    # the utterances are reordered internally; unsorted_indices is used at the
    # end of this function to restore the caller's order.
    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = model.decoder.blank_id
    # Fall back to blank_id when the model does not define unk_id.
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    # B[i] holds the active hypotheses of the i-th utterance (packed order).
    # Each utterance starts from a single hypothesis made of `context_size`
    # blanks with log-probability 0.
    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                timestamp=[],
            )
        )

    # Project all packed frames once; the joiner is later called with
    # project_input=False.
    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    finalized_B = []
    for (t, batch_size) in enumerate(batch_size_list):
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end

        # Utterances beyond `batch_size` have no frame at step t: set their
        # hypotheses aside. Prepending keeps finalized_B in packed order so
        # that `B + finalized_B` below is ordered consistently.
        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]

        hyps_shape = get_hyps_shape(B).to(device)

        # A keeps the hypotheses from the previous step; B is rebuilt from
        # their best extensions.
        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

        ys_log_probs = torch.cat(
            [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps]
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        # Repeat each utterance's frame once per hypothesis of that utterance.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            project_input=False,
        )  # (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

        log_probs = (logits / temperature).log_softmax(dim=-1)  # (num_hyps, vocab_size)

        # Add the score of each parent hypothesis (broadcast over the
        # vocabulary) to get the total score of every possible extension.
        log_probs.add_(ys_log_probs)

        vocab_size = log_probs.size(-1)

        log_probs = log_probs.reshape(-1)

        # Row i of the ragged tensor holds the flattened
        # (num_hyps_of_utt_i * vocab_size) scores of the i-th utterance.
        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs)

        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            # A flat index encodes (hypothesis index, token index).
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                new_timestamp = hyp.timestamp[:]
                # Blank and unk do not extend the token sequence; for any
                # other token, record the time-step index at which it was
                # emitted.
                if new_token not in (blank_id, unk_id):
                    new_ys.append(new_token)
                    new_timestamp.append(t)

                new_log_prob = topk_log_probs[k]
                new_hyp = Hypothesis(
                    ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp
                )
                B[i].add(new_hyp)

    B = B + finalized_B
    best_hyps = [b.get_most_probable(length_norm=True) for b in B]

    # Drop the `context_size` blanks every hypothesis was initialised with.
    sorted_ans = [h.ys[context_size:] for h in best_hyps]
    sorted_timestamps = [h.timestamp for h in best_hyps]
    ans = []
    ans_timestamps = []
    # Map the results from packed order back to the order of the input batch.
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])
        ans_timestamps.append(sorted_timestamps[unsorted_indices[i]])

    if not return_timestamps:
        return ans
    else:
        return DecodingResults(
            hyps=ans,
            timestamps=ans_timestamps,
        )
|
240
egs/librispeech/ASR/long_file_recog/merge_chunks.py
Executable file
240
egs/librispeech/ASR/long_file_recog/merge_chunks.py
Executable file
@ -0,0 +1,240 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This file merges overlapped chunks into utterances according to recording ids.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import sentencepiece as spm
|
||||||
|
from lhotse import (
|
||||||
|
CutSet,
|
||||||
|
MonoCut,
|
||||||
|
SupervisionSegment,
|
||||||
|
SupervisionSet,
|
||||||
|
load_manifest,
|
||||||
|
load_manifest_lazy,
|
||||||
|
)
|
||||||
|
from lhotse.cut import Cut
|
||||||
|
from lhotse.serialization import SequentialJsonlWriter
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
    """Parse the command-line options of the chunk-merging script.

    Note that, despite its name, this function returns the namespace
    produced by `parse_args()`, not the parser itself.
    """
    option_specs = (
        (
            "--bpe-model",
            {
                "type": str,
                "default": "data/lang_bpe_500/bpe.model",
                "help": "Path to the BPE model",
            },
        ),
        (
            "--manifest-in-dir",
            {
                "type": Path,
                "default": Path("data/librilight/manifests_chunk_recog"),
                "help": "Path to directory of chunk cuts with recognition results.",
            },
        ),
        (
            "--manifest-out-dir",
            {
                "type": Path,
                "default": Path("data/manifests"),
                "help": "Path to directory to save full utterance by merging overlapped chunks.",
            },
        ),
        (
            "--extra",
            {
                "type": float,
                "default": 2.0,
                "help": """Extra duration (in seconds) at both sides.""",
            },
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def merge_chunks(
    cuts_chunk: CutSet,
    supervisions: SupervisionSet,
    cuts_writer: SequentialJsonlWriter,
    sp: spm.SentencePieceProcessor,
    extra: float,
) -> int:
    """Merge chunk-wise cuts according to recording ids.

    Consecutive cuts that share a recording id are grouped, the overlapped
    `extra` margins of each chunk are dropped, and one full-recording cut
    carrying the merged symbol alignment is written per recording.

    Args:
      cuts_chunk:
        The chunk-wise cuts opened in a lazy mode. Cuts of the same recording
        are expected to be adjacent, since a change of recording id closes
        the current group.
      supervisions:
        The supervision manifest containing text file path. It is indexed
        by recording id.
      cuts_writer:
        Writer to save the cuts with recognition results.
      sp:
        The BPE model. It is not used inside this function.
      extra:
        Extra duration (in seconds) to drop at both sides of each chunk.
    Returns:
      The number of merged utterances (one per recording) submitted to
      `cuts_writer`.
    """

    # Background worker that writes one merged cut to disk.
    def _save_worker(utt_cut: Cut, flush=False):
        cuts_writer.write(utt_cut, flush=flush)

    def _merge(cut_list: List[Cut], rec_id: str, utt_idx: int):
        """Merge chunks with same recording_id."""
        for cut in cut_list:
            assert cut.recording.id == rec_id, (cut.recording.id, rec_id)

        # For each group with a same recording, sort it according to the start time
        # In fact, we don't need to do this since the cuts have been sorted
        # according to the start time
        cut_list = sorted(cut_list, key=(lambda cut: cut.start))

        rec = cut_list[0].recording
        alignments = []
        cur_end = 0
        for cut in cut_list:
            # Get left and right borders
            # The first chunk has no left margin to drop and a chunk reaching
            # the end of the recording has no right margin to drop.
            left = cut.start + extra if cut.start > 0 else 0
            chunk_end = cut.start + cut.duration
            right = chunk_end - extra if chunk_end < rec.duration else rec.duration

            # Assert the chunks are continuous
            assert left == cur_end, (left, cur_end)
            cur_end = right

            assert len(cut.supervisions) == 1, len(cut.supervisions)
            for ali in cut.supervisions[0].alignment["symbol"]:
                # Alignment times are relative to the chunk; shift them to
                # recording time and keep only symbols inside [left, right).
                t = ali.start + cut.start
                if left <= t < right:
                    alignments.append(ali.with_offset(cut.start))

        old_sup = supervisions[rec_id]
        # Assuming the supervisions are sorted with the same recording order as in cuts_chunk
        # old_sup = supervisions[utt_idx]
        assert old_sup.recording_id == rec_id, (old_sup.recording_id, rec_id)

        new_sup = SupervisionSegment(
            id=rec_id,
            recording_id=rec_id,
            start=0,
            duration=rec.duration,
            alignment={"symbol": alignments},
            language=old_sup.language,
            speaker=old_sup.speaker,
        )

        utt_cut = MonoCut(
            id=rec_id,
            start=0,
            duration=rec.duration,
            channel=0,
            recording=rec,
            supervisions=[new_sup],
        )
        # Set a custom attribute to the cut
        utt_cut.text_path = old_sup.book

        return utt_cut

    last_rec_id = None
    cut_list = []
    utt_idx = 0

    futures = []
    # A single background thread performs the writes while the main thread
    # keeps merging.
    with ThreadPoolExecutor(max_workers=1) as executor:

        for cut in cuts_chunk:
            cur_rec_id = cut.recording.id
            if len(cut_list) == 0:
                # Case of the first cut
                last_rec_id = cur_rec_id
                cut_list.append(cut)
            elif cur_rec_id == last_rec_id:
                cut_list.append(cut)
            else:
                # Case of a cut belonging to a new recording
                utt_cut = _merge(cut_list, last_rec_id, utt_idx)
                utt_idx += 1

                futures.append(executor.submit(_save_worker, utt_cut))

                last_rec_id = cur_rec_id
                cut_list = [cut]

                if utt_idx % 5000 == 0:
                    logging.info(f"Procesed {utt_idx} utterances.")

        # For the cuts belonging to the last recording
        if len(cut_list) != 0:
            utt_cut = _merge(cut_list, last_rec_id, utt_idx)
            utt_idx += 1

            futures.append(executor.submit(_save_worker, utt_cut))
            logging.info("Finished")

        # Re-raise any exception that happened in the background writer.
        for f in futures:
            f.result()

    return utt_idx
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Merge the recognized chunks of every LibriLight subset.

    For each subset, the chunk cuts with recognition results are read from
    `--manifest-in-dir`, merged per recording by :func:`merge_chunks`, and
    saved as `librilight_cuts_{subset}.jsonl.gz` in `--manifest-out-dir`.
    A subset whose output manifest already exists is skipped.
    """
    args = get_parser()

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    # It contains "librilight_recordings_*.jsonl.gz" and "librilight_supervisions_small.jsonl.gz"
    manifest_out_dir = args.manifest_out_dir

    # The subset names must match the manifest file names produced by the
    # other stages of this recipe, which use "medium" (not "median").
    subsets = ["small", "medium", "large"]

    for subset in subsets:
        logging.info(f"Processing {subset} subset")

        manifest_out = manifest_out_dir / f"librilight_cuts_{subset}.jsonl.gz"
        if manifest_out.is_file():
            logging.info(f"{manifest_out} already exists - skipping.")
            continue

        supervisions = load_manifest(
            manifest_out_dir / f"librilight_supervisions_{subset}.jsonl.gz"
        )  # We will use the text path from supervisions

        cuts_chunk = load_manifest_lazy(
            args.manifest_in_dir / f"librilight_cuts_{subset}.jsonl.gz"
        )

        cuts_writer = CutSet.open_writer(manifest_out, overwrite=True)
        try:
            num_utt = merge_chunks(
                cuts_chunk, supervisions, cuts_writer=cuts_writer, sp=sp, extra=args.extra
            )
        finally:
            # Always release the output file, even if merging fails.
            cuts_writer.close()
        logging.info(f"{num_utt} cuts saved to {manifest_out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Configure the root logger (timestamp, level, file:line) before running.
    logging.basicConfig(
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
        level=logging.INFO,
    )

    main()
|
435
egs/librispeech/ASR/long_file_recog/recognize.py
Executable file
435
egs/librispeech/ASR/long_file_recog/recognize.py
Executable file
@ -0,0 +1,435 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
This script loads torchscript models, exported by `torch.jit.script()`,
|
||||||
|
and uses them to decode waves.
|
||||||
|
You can use the following command to get the exported models:
|
||||||
|
|
||||||
|
./pruned_transducer_stateless7/export.py \
|
||||||
|
--exp-dir ./pruned_transducer_stateless7/exp \
|
||||||
|
--bpe-model data/lang_bpe_500/bpe.model \
|
||||||
|
--epoch 20 \
|
||||||
|
--avg 10 \
|
||||||
|
--jit 1
|
||||||
|
|
||||||
|
You can also download the jit model from
|
||||||
|
https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import logging
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import k2
|
||||||
|
import sentencepiece as spm
|
||||||
|
from asr_datamodule import AsrDataModule
|
||||||
|
from beam_search import (
|
||||||
|
fast_beam_search_one_best,
|
||||||
|
greedy_search_batch,
|
||||||
|
modified_beam_search,
|
||||||
|
)
|
||||||
|
from icefall.utils import AttributeDict, convert_timestamp, setup_logger
|
||||||
|
from lhotse import CutSet, load_manifest_lazy
|
||||||
|
from lhotse.cut import Cut
|
||||||
|
from lhotse.supervision import AlignmentItem
|
||||||
|
from lhotse.serialization import SequentialJsonlWriter
|
||||||
|
|
||||||
|
|
||||||
|
def get_parser():
    """Build the argument parser of the chunk recognition script.

    `--subset` and `--decoding-method` are restricted with `choices`, so an
    unsupported value is rejected by argparse with a usage message instead
    of failing later on an `assert`.

    Returns:
      The configured `argparse.ArgumentParser`. Arguments are not parsed
      here so that the caller can register further options first.
    """
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--world-size",
        type=int,
        default=1,
        help="Number of GPUs (i.e., parallel processes) used for decoding.",
    )

    parser.add_argument(
        "--master-port",
        type=int,
        default=12354,
        help="Master port to use for DDP training.",
    )

    parser.add_argument(
        "--subset",
        type=str,
        default="small",
        choices=["small", "medium", "large"],
        help="Subset to process. Possible values are 'small', 'medium', 'large'",
    )

    parser.add_argument(
        "--manifest-in-dir",
        type=Path,
        default=Path("data/librilight/manifests_chunk"),
        help="Path to directory with chunks cuts.",
    )

    parser.add_argument(
        "--manifest-out-dir",
        type=Path,
        default=Path("data/librilight/manifests_chunk_recog"),
        help="Path to directory to save the chunk cuts with recognition results.",
    )

    parser.add_argument(
        "--log-dir",
        type=Path,
        default=Path("long_file_recog/log"),
        help="Path to directory to save logs.",
    )

    parser.add_argument(
        "--nn-model-filename",
        type=str,
        required=True,
        help="Path to the torchscript model cpu_jit.pt",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--decoding-method",
        type=str,
        default="greedy_search",
        choices=["greedy_search", "modified_beam_search", "fast_beam_search"],
        help="""Possible values are:
          - greedy_search
          - modified_beam_search
          - fast_beam_search
        """,
    )

    return parser
|
||||||
|
|
||||||
|
|
||||||
|
def get_params() -> AttributeDict:
    """Return the fixed decoding parameters wrapped in an AttributeDict."""
    defaults = {
        # Passed to convert_timestamp() together with the frame indexes.
        "subsampling_factor": 4,
        "frame_shift_ms": 10,
        # Used only when --method is beam_search or modified_beam_search.
        "beam_size": 4,
        # Used only when --method is beam_search or fast_beam_search.
        # A floating point value to calculate the cutoff score during beam
        # search (i.e., `cutoff = max-score - beam`), which is the same as the
        # `beam` in Kaldi.
        "beam": 4,
        "max_contexts": 4,
        "max_states": 8,
    }
    return AttributeDict(defaults)
|
||||||
|
|
||||||
|
|
||||||
|
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    batch: dict,
    decoding_graph: Optional[k2.Fsa] = None,
) -> Tuple[List[List[str]], List[List[float]], List[List[float]]]:
    """Run the encoder and the selected search on a single batch.

    Args:
      params:
        It's the return value of :func:`get_params`, extended with the
        command-line options (`decoding_method` is read from it).
      model:
        The neural model.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
        only when --decoding_method is fast_beam_search.

    Returns:
      A tuple (hyps, timestamps, scores) with one entry per utterance of
      the batch. `hyps` are taken unchanged from the search result, and
      `timestamps` are the search timestamps passed through
      `convert_timestamp`.

    Raises:
      ValueError: If `params.decoding_method` is not supported.
    """
    device = next(model.parameters()).device

    feature = batch["inputs"]
    assert feature.ndim == 3
    # feature is (N, T, C)
    feature = feature.to(device)

    feature_lens = batch["supervisions"]["num_frames"].to(device)

    encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens)

    # Arguments shared by all supported search functions.
    shared_kwargs = dict(
        model=model,
        encoder_out=encoder_out,
        encoder_out_lens=encoder_out_lens,
        return_timestamps=True,
    )

    method = params.decoding_method
    if method == "fast_beam_search":
        res = fast_beam_search_one_best(
            decoding_graph=decoding_graph,
            beam=params.beam,
            max_contexts=params.max_contexts,
            max_states=params.max_states,
            **shared_kwargs,
        )
    elif method == "greedy_search":
        res = greedy_search_batch(**shared_kwargs)
    elif method == "modified_beam_search":
        res = modified_beam_search(beam=params.beam_size, **shared_kwargs)
    else:
        raise ValueError(f"Unsupported decoding method: {params.decoding_method}")

    # NOTE(review): this assumes every search function fills `res.scores`
    # when return_timestamps=True - confirm for each decoding method.
    hyps, timestamps, scores = [], [], []
    for idx in range(feature.shape[0]):
        hyps.append(res.hyps[idx])
        timestamps.append(
            convert_timestamp(
                res.timestamps[idx], params.subsampling_factor, params.frame_shift_ms
            )
        )
        scores.append(res.scores[idx])

    return hyps, timestamps, scores
|
||||||
|
|
||||||
|
|
||||||
|
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    sp: spm.SentencePieceProcessor,
    cuts_writer: SequentialJsonlWriter,
    decoding_graph: Optional[k2.Fsa] = None,
) -> None:
    """Decode dataset and store the recognition results to manifest.

    Every batch is decoded with :func:`decode_one_batch`; the resulting
    symbols, start times and scores are attached to the single supervision
    of each cut as a "symbol" alignment, and the cut is written with
    `cuts_writer`. Nothing is returned.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      sp:
        The BPE model, used to map token IDs to pieces.
      cuts_writer:
        Writer to save the cuts with recognition results.
      decoding_graph:
        The decoding graph. Can be either a `k2.trivial_graph` or LG, Used
        only when --decoding_method is fast_beam_search.
    """

    # Background worker to add alignment and save cuts to disk.
    def _save_worker(
        cuts: List[Cut],
        hyps: List[List[str]],
        timestamps: List[List[float]],
        scores: List[List[float]],
    ):
        for cut, token_ids, starts, token_scores in zip(
            cuts, hyps, timestamps, scores
        ):
            pieces = sp.id_to_piece(token_ids)
            items = []
            for piece, begin, token_score in zip(pieces, starts, token_scores):
                items.append(
                    AlignmentItem(
                        symbol=piece, start=begin, duration=None, score=token_score
                    )
                )
            assert len(cut.supervisions) == 1, len(cut.supervisions)
            cut.supervisions[0].alignment = {"symbol": items}
            cuts_writer.write(cut, flush=True)

    log_interval = 10
    num_cuts = 0
    pending = []
    # We only want one background worker so that serialization is deterministic.
    with ThreadPoolExecutor(max_workers=1) as executor:
        for batch_idx, batch in enumerate(dl):
            cuts = batch["supervisions"]["cut"]

            hyps, timestamps, scores = decode_one_batch(
                params=params,
                model=model,
                decoding_graph=decoding_graph,
                batch=batch,
            )

            pending.append(
                executor.submit(_save_worker, cuts, hyps, timestamps, scores)
            )

            num_cuts += len(cuts)
            if batch_idx % log_interval == 0:
                logging.info(f"cuts processed until now is {num_cuts}")

        # Surface any exception raised inside the background worker.
        for future in pending:
            future.result()
|
||||||
|
@torch.no_grad()
def run(rank, world_size, args, in_cuts):
    """Decode one share of the chunk cuts and save the results.

    Args:
      rank:
        It is a value between 0 and `world_size-1`. It selects the CUDA
        device and, when `world_size > 1`, the split of `in_cuts` to decode.
      world_size:
        Number of processes decoding in parallel.
      args:
        The return value of get_parser().parse_args()
      in_cuts:
        The cuts to decode when `world_size` is 1; otherwise a sequence of
        splits indexed by `rank`.
    """
    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.log_dir}/log-decode")
    logging.info("Decoding started")

    supported_methods = (
        "greedy_search",
        "fast_beam_search",
        "modified_beam_search",
    )
    assert params.decoding_method in supported_methods, params.decoding_method

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.unk_id = sp.piece_to_id("<unk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(f"{params}")

    device = (
        torch.device("cuda", rank)
        if torch.cuda.is_available()
        else torch.device("cpu")
    )
    logging.info(f"device: {device}")

    logging.info("Loading jit model")
    model = torch.jit.load(params.nn_model_filename)
    model.to(device)
    model.eval()

    # A decoding graph is needed by fast_beam_search only.
    decoding_graph = None
    if params.decoding_method == "fast_beam_search":
        decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)

    # we will store new cuts with recognition results.
    args.return_cuts = True
    asr_data_module = AsrDataModule(args)

    if world_size > 1:
        # Every process decodes its own split and writes its own file.
        in_cuts = in_cuts[rank]
        out_stem = f"{params.cuts_filename}_job_{rank}"
    else:
        out_stem = f"{params.cuts_filename}"
    out_cuts_filename = params.manifest_out_dir / (out_stem + params.suffix)

    dl = asr_data_module.dataloaders(in_cuts)

    cuts_writer = CutSet.open_writer(out_cuts_filename, overwrite=True)
    decode_dataset(
        dl=dl,
        params=params,
        model=model,
        sp=sp,
        decoding_graph=decoding_graph,
        cuts_writer=cuts_writer,
    )
    cuts_writer.close()
    logging.info(f"Cuts saved to {out_cuts_filename}")

    logging.info("Done!")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Parse the options and decode the chunk cuts of one subset.

    With `--world-size` greater than 1 the input manifest is split into
    `world_size` parts and one process per part is spawned; otherwise the
    decoding runs in the current process. Nothing is done if the output
    manifest of the subset already exists.
    """
    parser = get_parser()
    AsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    subset = args.subset
    assert subset in ["small", "medium", "large"], subset

    manifest_out_dir = args.manifest_out_dir
    manifest_out_dir.mkdir(parents=True, exist_ok=True)

    args.suffix = ".jsonl.gz"
    args.cuts_filename = f"librilight_cuts_{args.subset}"
    manifest_name = args.cuts_filename + args.suffix

    out_cuts_filename = manifest_out_dir / manifest_name
    if out_cuts_filename.is_file():
        logging.info(f"{out_cuts_filename} already exists - skipping.")
        return

    in_cuts = load_manifest_lazy(args.manifest_in_dir / manifest_name)

    world_size = args.world_size
    assert world_size >= 1

    if world_size <= 1:
        run(rank=0, world_size=world_size, args=args, in_cuts=in_cuts)
        return

    # Ceiling division, so that at most `world_size` splits are produced.
    chunk_size = (len(in_cuts) + (world_size - 1)) // world_size
    # Each manifest is saved at: ``{output_dir}/{prefix}.{split_idx}.jsonl.gz``
    splits = in_cuts.split_lazy(
        output_dir=args.manifest_in_dir / "split",
        chunk_size=chunk_size,
        prefix=args.cuts_filename,
    )
    assert len(splits) == world_size, (len(splits), world_size)
    mp.spawn(run, args=(world_size, args, splits), nprocs=world_size, join=True)
|
||||||
|
|
||||||
|
|
||||||
|
# These run at import time, i.e. before main() is called.
# NOTE(review): presumably this restricts torch to a single intra-op and a
# single inter-op thread to avoid CPU oversubscription when several decoding
# processes run in parallel - confirm the intent.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == "__main__":
    main()
|
100
egs/librispeech/ASR/long_file_recog/split_into_chunks.py
Executable file
100
egs/librispeech/ASR/long_file_recog/split_into_chunks.py
Executable file
@ -0,0 +1,100 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao)
|
||||||
|
#
|
||||||
|
# See ../../../../LICENSE for clarification regarding multiple authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
"""
|
||||||
|
This script splits long utterances into chunks with overlaps.
|
||||||
|
Each chunk (except the first and the last) is padded with extra left side and right side.
|
||||||
|
The chunk length is: left_side + chunk_size + right_side.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from lhotse import CutSet, load_manifest
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
    """Parse and validate the command-line options of this script.

    Returns:
      An ``argparse.Namespace`` with the attributes:
        - manifest_in_dir (Path): directory holding the input manifests.
        - manifest_out_dir (Path): directory the chunked manifests go to.
        - chunk (float): duration, in seconds, of each chunk.
        - extra (float): extra duration, in seconds, at both sides.

    Exits through ``parser.error`` (``SystemExit`` with status 2) when
    ``--extra`` is negative or ``--chunk`` is not larger than
    ``2 * --extra``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--manifest-in-dir",
        type=Path,
        default=Path("data/librilight/manifests"),
        help="Path to directory of full utterances.",
    )

    parser.add_argument(
        "--manifest-out-dir",
        type=Path,
        default=Path("data/librilight/manifests_chunk"),
        help="Path to directory to save splitted chunks.",
    )

    parser.add_argument(
        "--chunk",
        type=float,
        default=300.0,
        help="""Duration (in seconds) of each chunk.""",
    )

    parser.add_argument(
        "--extra",
        type=float,
        default=2.0,
        help="""Extra duration (in seconds) at both sides.""",
    )

    args = parser.parse_args()

    # main() cuts windows of length `chunk` with a hop of `chunk - 2 * extra`.
    # A negative `extra` would leave gaps between windows, and
    # `chunk <= 2 * extra` would give a zero or negative hop, so reject both
    # here with a clear message instead of failing later during windowing.
    if args.extra < 0:
        parser.error(f"--extra must be non-negative, got {args.extra}")
    if args.chunk <= 2 * args.extra:
        parser.error(
            f"--chunk ({args.chunk}) must be larger than twice --extra "
            f"({args.extra}); otherwise the hop between chunks is not positive"
        )

    return args
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Cut the recordings of every LibriLight subset into overlapping chunks.

    For each of the "small", "medium" and "large" subsets, the recording
    manifest is read from ``args.manifest_in_dir``, turned into cuts, cut into
    windows of ``args.chunk`` seconds taken every
    ``args.chunk - 2 * args.extra`` seconds, passed through
    ``fill_supervisions()`` and written to ``args.manifest_out_dir``.
    A subset whose output manifest already exists is skipped.
    """
    args = get_args()
    logging.info(vars(args))

    out_dir = args.manifest_out_dir
    out_dir.mkdir(parents=True, exist_ok=True)

    # Consecutive windows share 2 * extra seconds of audio.
    window = dict(duration=args.chunk, hop=args.chunk - args.extra * 2)

    for name in ("small", "medium", "large"):
        logging.info(f"Processing {name} subset")

        dst = out_dir / f"librilight_cuts_{name}.jsonl.gz"
        if dst.is_file():
            logging.info(f"{dst} already exists - skipping.")
            continue

        src = args.manifest_in_dir / f"librilight_recordings_{name}.jsonl.gz"
        chunks = (
            CutSet.from_manifests(recordings=load_manifest(src))
            .cut_into_windows(**window)
            .fill_supervisions()
        )

        chunks.to_file(dst)
        logging.info(f"Cuts saved to {dst}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Each log record shows the time, the level and the emitting file:line.
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s",
    )

    main()
|
@ -670,6 +670,8 @@ def greedy_search_batch(
|
|||||||
# timestamp[n][i] is the frame index after subsampling
|
# timestamp[n][i] is the frame index after subsampling
|
||||||
# on which hyp[n][i] is decoded
|
# on which hyp[n][i] is decoded
|
||||||
timestamps = [[] for _ in range(N)]
|
timestamps = [[] for _ in range(N)]
|
||||||
|
# scores[n][i] is the logit of the token hyps[n][i] at the frame on which it is decoded
|
||||||
|
scores = [[] for _ in range(N)]
|
||||||
|
|
||||||
decoder_input = torch.tensor(
|
decoder_input = torch.tensor(
|
||||||
hyps,
|
hyps,
|
||||||
@ -707,6 +709,7 @@ def greedy_search_batch(
|
|||||||
if v not in (blank_id, unk_id):
|
if v not in (blank_id, unk_id):
|
||||||
hyps[i].append(v)
|
hyps[i].append(v)
|
||||||
timestamps[i].append(t)
|
timestamps[i].append(t)
|
||||||
|
scores[i].append(logits[i, v].item())
|
||||||
emitted = True
|
emitted = True
|
||||||
if emitted:
|
if emitted:
|
||||||
# update decoder output
|
# update decoder output
|
||||||
@ -722,10 +725,12 @@ def greedy_search_batch(
|
|||||||
sorted_ans = [h[context_size:] for h in hyps]
|
sorted_ans = [h[context_size:] for h in hyps]
|
||||||
ans = []
|
ans = []
|
||||||
ans_timestamps = []
|
ans_timestamps = []
|
||||||
|
ans_scores = []
|
||||||
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
|
||||||
for i in range(N):
|
for i in range(N):
|
||||||
ans.append(sorted_ans[unsorted_indices[i]])
|
ans.append(sorted_ans[unsorted_indices[i]])
|
||||||
ans_timestamps.append(timestamps[unsorted_indices[i]])
|
ans_timestamps.append(timestamps[unsorted_indices[i]])
|
||||||
|
ans_scores.append(scores[unsorted_indices[i]])
|
||||||
|
|
||||||
if not return_timestamps:
|
if not return_timestamps:
|
||||||
return ans
|
return ans
|
||||||
@ -733,6 +738,7 @@ def greedy_search_batch(
|
|||||||
return DecodingResults(
|
return DecodingResults(
|
||||||
hyps=ans,
|
hyps=ans,
|
||||||
timestamps=ans_timestamps,
|
timestamps=ans_timestamps,
|
||||||
|
scores=ans_scores,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -272,6 +272,9 @@ class DecodingResults:
|
|||||||
# for the i-th utterance with fast_beam_search_nbest_LG.
|
# for the i-th utterance with fast_beam_search_nbest_LG.
|
||||||
hyps: Union[List[List[int]], k2.RaggedTensor]
|
hyps: Union[List[List[int]], k2.RaggedTensor]
|
||||||
|
|
||||||
|
# scores[i][k] contains the score of hyps[i][k] (greedy_search_batch stores the logit of the emitted token)
|
||||||
|
scores: Optional[List[List[float]]] = None
|
||||||
|
|
||||||
|
|
||||||
def get_texts_with_timestamp(
|
def get_texts_with_timestamp(
|
||||||
best_paths: k2.Fsa, return_ragged: bool = False
|
best_paths: k2.Fsa, return_ragged: bool = False
|
||||||
@ -1442,7 +1445,7 @@ def convert_timestamp(
|
|||||||
frame_shift = frame_shift_ms / 1000.0
|
frame_shift = frame_shift_ms / 1000.0
|
||||||
time = []
|
time = []
|
||||||
for f in frames:
|
for f in frames:
|
||||||
time.append(f * subsampling_factor * frame_shift)
|
time.append(round(f * subsampling_factor * frame_shift, ndigits=3))
|
||||||
|
|
||||||
return time
|
return time
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user