Add force alignment for stateless transducer. (#239)

* Add force alignment for stateless transducer.
* Add more documentation.
* Compute word starting time from framewise token alignment.
* Update README to include force alignment information.
* Fix typos.
* Fix more typos.
* Fixes after review.

parent: 1603744469
commit: 2f4e71f433
egs/librispeech/ASR/prepare.sh

@@ -60,8 +60,11 @@ log "dl_dir: $dl_dir"
 
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "Stage -1: Download LM"
-  [ ! -e $dl_dir/lm ] && mkdir -p $dl_dir/lm
-  ./local/download_lm.py --out-dir=$dl_dir/lm
+  mkdir -p $dl_dir/lm
+  if [ ! -e $dl_dir/lm/.done ]; then
+    ./local/download_lm.py --out-dir=$dl_dir/lm
+    touch $dl_dir/lm/.done
+  fi
 fi
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
@@ -91,7 +94,10 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
   # We assume that you have downloaded the LibriSpeech corpus
   # to $dl_dir/LibriSpeech
   mkdir -p data/manifests
-  lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
+  if [ ! -e data/manifests/.librispeech.done ]; then
+    lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
+    touch data/manifests/.librispeech.done
+  fi
 fi
 
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -99,19 +105,28 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   # We assume that you have downloaded the musan corpus
   # to data/musan
   mkdir -p data/manifests
-  lhotse prepare musan $dl_dir/musan data/manifests
+  if [ ! -e data/manifests/.musan.done ]; then
+    lhotse prepare musan $dl_dir/musan data/manifests
+    touch data/manifests/.musan.done
+  fi
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
   log "Stage 3: Compute fbank for librispeech"
   mkdir -p data/fbank
-  ./local/compute_fbank_librispeech.py
+  if [ ! -e data/fbank/.librispeech.done ]; then
+    ./local/compute_fbank_librispeech.py
+    touch data/fbank/.librispeech.done
+  fi
 fi
 
 if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
   log "Stage 4: Compute fbank for musan"
   mkdir -p data/fbank
-  ./local/compute_fbank_musan.py
+  if [ ! -e data/fbank/.musan.done ]; then
+    ./local/compute_fbank_musan.py
+    touch data/fbank/.musan.done
+  fi
 fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
asr_datamodule.py

@@ -180,14 +180,14 @@ class LibriSpeechAsrDataModule:
         )
 
     def train_dataloaders(self, cuts_train: CutSet) -> DataLoader:
-        logging.info("About to get Musan cuts")
-        cuts_musan = load_manifest(
-            self.args.manifest_dir / "cuts_musan.json.gz"
-        )
-
         transforms = []
         if self.args.enable_musan:
             logging.info("Enable MUSAN")
+            logging.info("About to get Musan cuts")
+            cuts_musan = load_manifest(
+                self.args.manifest_dir / "cuts_musan.json.gz"
+            )
             transforms.append(
                 CutMix(
                     cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True
egs/librispeech/ASR/transducer_stateless/README.md

@@ -20,3 +20,120 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --max-duration 250 \
   --lr-factor 2.5
```
## How to get framewise token alignment

Assume that you already have a trained model. If not, you can either
train one yourself or download a pre-trained model from Hugging Face:
<https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01>

**Caution**: If you are going to use your own trained model, remember
to set `--modified-transducer-prob` to a nonzero value during training, since
the force alignment code assumes that `--max-sym-per-frame` is 1.

The following shows how to get framewise token alignment using the above
pre-trained model.

```bash
git clone https://github.com/k2-fsa/icefall
cd icefall/egs/librispeech/ASR
mkdir tmp
sudo apt-get install git-lfs
git lfs install
git clone https://huggingface.co/csukuangfj/icefall-asr-librispeech-transducer-stateless-multi-datasets-bpe-500-2022-03-01 ./tmp/

ln -s $PWD/tmp/exp/pretrained.pt $PWD/tmp/exp/epoch-999.pt

./transducer_stateless/compute_ali.py \
  --exp-dir ./tmp/exp \
  --bpe-model ./tmp/data/lang_bpe_500/bpe.model \
  --epoch 999 \
  --avg 1 \
  --max-duration 100 \
  --dataset dev-clean \
  --out-dir data/ali
```

After running the above commands, you will find the following two files
in the folder `./data/ali`:

```
-rw-r--r-- 1 xxx xxx 412K Mar  7 15:45 cuts_dev-clean.json.gz
-rw-r--r-- 1 xxx xxx 2.9M Mar  7 15:45 token_ali_dev-clean.h5
```

You can find usage examples in `./test_compute_ali.py` showing how to
extract framewise token alignment information from the above two files;
a minimal sketch follows below.
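
For instance, here is a minimal sketch (not part of the recipe) of reading the
stored alignment back. It assumes lhotse's `load_manifest` and the cut's
`load_custom` accessor for custom arrays; `<blk>` is the blank token defined in
`local/train_bpe_model.py`:

```python
import sentencepiece as spm
from lhotse import load_manifest

cuts = load_manifest("data/ali/cuts_dev-clean.json.gz")
sp = spm.SentencePieceProcessor()
sp.load("./tmp/data/lang_bpe_500/bpe.model")

blank_id = sp.piece_to_id("<blk>")

cut = next(iter(cuts))
# token_alignment was stored by compute_ali.py via NumpyHdf5Writer;
# load_custom() reads it back as a numpy array with one token ID per frame.
ali = cut.load_custom("token_alignment").tolist()
print(cut.id)
# Removing blanks from the framewise alignment recovers the transcript.
print(sp.decode([t for t in ali if t != blank_id]))
```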
## How to get word starting time from framewise token alignment

Assume you have run the above commands to get framewise token alignment
using the pre-trained model from `tmp/exp/epoch-999.pt`. You can use the
following commands to obtain word starting times:

```bash
./transducer_stateless/test_compute_ali.py \
  --bpe-model ./tmp/data/lang_bpe_500/bpe.model \
  --ali-dir data/ali \
  --dataset dev-clean
```

**Caution**: Since the frame shift is 10 ms and the subsampling factor
of the model is 4, the time resolution is 0.04 second.

**Note**: The script `test_compute_ali.py` is for illustration only;
it processes only one batch and then exits.
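
To make the time resolution concrete, here is a tiny sketch (the frame index
is a hypothetical value) converting a starting frame from the alignment into
seconds:

```python
# 10 ms frame shift x subsampling factor 4 => 0.04 s per output frame
frame_shift_in_second = 10 * 4 / 1000.0

starting_frame = 5  # hypothetical frame index taken from the alignment
print("{:.2f}".format(starting_frame * frame_shift_in_second))  # prints 0.20
```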

You will get the following output:

```
5694-64029-0022-1998-0
[('THE', '0.20'), ('LEADEN', '0.36'), ('HAIL', '0.72'), ('STORM', '1.00'), ('SWEPT', '1.48'), ('THEM', '1.88'), ('OFF', '2.00'), ('THE', '2.24'), ('FIELD', '2.36'), ('THEY', '3.20'), ('FELL', '3.36'), ('BACK', '3.64'), ('AND', '3.92'), ('RE', '4.04'), ('FORMED', '4.20')]

3081-166546-0040-308-0
[('IN', '0.32'), ('OLDEN', '0.60'), ('DAYS', '1.00'), ('THEY', '1.40'), ('WOULD', '1.56'), ('HAVE', '1.76'), ('SAID', '1.92'), ('STRUCK', '2.60'), ('BY', '3.16'), ('A', '3.36'), ('BOLT', '3.44'), ('FROM', '3.84'), ('HEAVEN', '4.04')]

2035-147960-0016-1283-0
[('A', '0.44'), ('SNAKE', '0.52'), ('OF', '0.84'), ('HIS', '0.96'), ('SIZE', '1.12'), ('IN', '1.60'), ('FIGHTING', '1.72'), ('TRIM', '2.12'), ('WOULD', '2.56'), ('BE', '2.76'), ('MORE', '2.88'), ('THAN', '3.08'), ('ANY', '3.28'), ('BOY', '3.56'), ('COULD', '3.88'), ('HANDLE', '4.04')]

2428-83699-0020-1734-0
[('WHEN', '0.28'), ('THE', '0.48'), ('TRAP', '0.60'), ('DID', '0.88'), ('APPEAR', '1.08'), ('IT', '1.80'), ('LOOKED', '1.96'), ('TO', '2.24'), ('ME', '2.36'), ('UNCOMMONLY', '2.52'), ('LIKE', '3.16'), ('AN', '3.40'), ('OPEN', '3.56'), ('SPRING', '3.92'), ('CART', '4.28')]

8297-275154-0026-2108-0
[('LET', '0.44'), ('ME', '0.72'), ('REST', '0.92'), ('A', '1.32'), ('LITTLE', '1.40'), ('HE', '1.80'), ('PLEADED', '2.00'), ('IF', '3.04'), ("I'M", '3.28'), ('NOT', '3.52'), ('IN', '3.76'), ('THE', '3.88'), ('WAY', '4.00')]

652-129742-0007-1002-0
[('SURROUND', '0.28'), ('WITH', '0.80'), ('A', '0.92'), ('GARNISH', '1.00'), ('OF', '1.44'), ('COOKED', '1.56'), ('AND', '1.88'), ('DICED', '4.16'), ('CARROTS', '4.28'), ('TURNIPS', '4.44'), ('GREEN', '4.60'), ('PEAS', '4.72')]
```

For the row:

```
5694-64029-0022-1998-0
[('THE', '0.20'), ('LEADEN', '0.36'), ('HAIL', '0.72'), ('STORM', '1.00'), ('SWEPT', '1.48'),
('THEM', '1.88'), ('OFF', '2.00'), ('THE', '2.24'), ('FIELD', '2.36'), ('THEY', '3.20'), ('FELL', '3.36'),
('BACK', '3.64'), ('AND', '3.92'), ('RE', '4.04'), ('FORMED', '4.20')]
```

- `5694-64029-0022-1998-0` is the cut ID.
- `('THE', '0.20')` means the word `THE` starts at 0.20 seconds.
- `('LEADEN', '0.36')` means the word `LEADEN` starts at 0.36 seconds.

You can compare the above word starting times with the ones
from <https://github.com/CorentinJ/librispeech-alignments>:

```
5694-64029-0022 ",THE,LEADEN,HAIL,STORM,SWEPT,THEM,OFF,THE,FIELD,,THEY,FELL,BACK,AND,RE,FORMED," "0.230,0.360,0.670,1.010,1.440,1.860,1.990,2.230,2.350,2.870,3.230,3.390,3.660,3.960,4.060,4.160,4.850,4.9"
```

We reformat it below for readability:

```
5694-64029-0022 ",THE,LEADEN,HAIL,STORM,SWEPT,THEM,OFF,THE,FIELD,,THEY,FELL,BACK,AND,RE,FORMED,"
"0.230,0.360,0.670,1.010,1.440,1.860,1.990,2.230,2.350,2.870,3.230,3.390,3.660,3.960,4.060,4.160,4.850,4.9"
the leaden hail storm swept them off the field sil they fell back and re formed sil
```
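
If you want to compare the two programmatically, below is a hedged sketch. It
assumes that in the CorentinJ format the quoted numbers are word *end* times,
so a word starts where the previous word (or the leading silence `""`) ends:

```python
# a sketch: turn one line of the CorentinJ alignment format
# into (word, starting_time) pairs for comparison
line = (
    '5694-64029-0022 '
    '",THE,LEADEN,HAIL,STORM,SWEPT,THEM,OFF,THE,FIELD,,THEY,FELL,BACK,AND,RE,FORMED," '
    '"0.230,0.360,0.670,1.010,1.440,1.860,1.990,2.230,2.350,2.870,'
    '3.230,3.390,3.660,3.960,4.060,4.160,4.850,4.9"'
)
utt_id, words_str, times_str = line.split(" ")
words = words_str.strip('"').split(",")
ends = [float(t) for t in times_str.strip('"').split(",")]
# a word starts where the previous segment ends; the first segment starts at 0
starts = [0.0] + ends[:-1]

print(utt_id)
for word, start in zip(words, starts):
    if word:  # empty strings mark silence segments
        print((word, "{:.2f}".format(start)))
```

The starting times it prints, e.g. `('THE', '0.23')`, can then be compared
with the `('THE', '0.20')` obtained above.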
egs/librispeech/ASR/transducer_stateless/alignment.py (new file, 268 lines)

@@ -0,0 +1,268 @@
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Iterator, List, Optional

import sentencepiece as spm
import torch
from model import Transducer

# The force alignment problem can be formulated as finding
# a path in a rectangular lattice, where the path starts
# from the lower left corner and ends at the upper right
# corner. The horizontal axis of the lattice is `t` (representing
# acoustic frame indexes) and the vertical axis is `u` (representing
# BPE tokens of the transcript).
#
# The notations `t` and `u` are from the paper
# https://arxiv.org/pdf/1211.3711.pdf
#
# Beam search is used to find the path with the
# highest log probability.
#
# It assumes the maximum number of symbols that can be
# emitted per frame is 1. You can use `--modified-transducer-prob`
# from `./train.py` to train a model that satisfies this assumption.
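#
# For example (an illustration, not in the original file): with T = 3
# frames and a one-token transcript ys = [5], the only valid paths are
#   [5, blk, blk], [blk, 5, blk], [blk, blk, 5],
# and force alignment returns the one with the highest total log prob.
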
# AlignItem is the ending node of a path originating from the starting node.
# len(ys) equals `t` and pos_u is the `u` coordinate
# in the lattice.
@dataclass
class AlignItem:
    # total log prob of the path that ends at this item.
    # The path originates from the starting node.
    log_prob: float

    # It contains the framewise token alignment
    ys: List[int]

    # It equals the number of non-blank entries in ys
    pos_u: int


class AlignItemList:
    def __init__(self, items: Optional[List[AlignItem]] = None):
        """
        Args:
          items:
            A list of AlignItem.
        """
        if items is None:
            items = []
        self.data = items

    def __iter__(self) -> Iterator:
        return iter(self.data)

    def __len__(self) -> int:
        """Return the number of AlignItem in this object."""
        return len(self.data)

    def __getitem__(self, i: int) -> AlignItem:
        """Return the i-th item in this object."""
        return self.data[i]

    def append(self, item: AlignItem) -> None:
        """Append an item to the end of this object."""
        self.data.append(item)

    def get_decoder_input(
        self,
        ys: List[int],
        context_size: int,
        blank_id: int,
    ) -> List[List[int]]:
        """Get input for the decoder for each item in this object.

        Args:
          ys:
            The transcript of the utterance in BPE tokens.
          context_size:
            Context size of the NN decoder model.
          blank_id:
            The ID of the blank symbol.
        Returns:
          Return a list of lists of int. `ans[i]` contains the decoder
          input for the i-th item in this object and its length
          is `context_size`.
        """
        ans: List[List[int]] = []
        buf = [blank_id] * context_size + ys
        for item in self:
            # fmt: off
            ans.append(buf[item.pos_u:(item.pos_u + context_size)])
            # fmt: on
        return ans

    def topk(self, k: int) -> "AlignItemList":
        """Return the top-k items.

        Items are ordered by their log probs in descending order
        and the top-k items are returned.

        Args:
          k:
            Size of top-k.
        Returns:
          Return a new AlignItemList that contains the top-k items
          in this object. Caution: It uses shallow copy.
        """
        items = list(self)
        items = sorted(items, key=lambda i: i.log_prob, reverse=True)
        return AlignItemList(items[:k])


def force_alignment(
    model: Transducer,
    encoder_out: torch.Tensor,
    ys: List[int],
    beam_size: int = 4,
) -> List[int]:
    """Compute the force alignment of an utterance given its transcript
    in BPE tokens and the corresponding acoustic output from the encoder.

    Caution:
      We assume that the maximum number of symbols per frame is 1.
      That is, the model should be trained using a nonzero value
      for the option `--modified-transducer-prob` in train.py.

    Args:
      model:
        The transducer model.
      encoder_out:
        A tensor of shape (N, T, C). Only N==1 is supported at present.
      ys:
        A list of BPE token IDs. We require that len(ys) <= T.
      beam_size:
        Size of the beam used in beam search.
    Returns:
      Return a list of int such that
        - len(ans) == T
        - After removing blanks from ans, we have ans == ys.
    """
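    # An illustrative example (not in the original file), assuming blank_id == 0:
    # for ys == [3, 8] and T == 4, a possible return value is [0, 3, 0, 8];
    # it has length T and removing blanks from it recovers ys.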
    assert encoder_out.ndim == 3, encoder_out.ndim
    assert encoder_out.size(0) == 1, encoder_out.size(0)
    assert 0 < len(ys) <= encoder_out.size(1), (len(ys), encoder_out.size(1))

    blank_id = model.decoder.blank_id
    context_size = model.decoder.context_size

    device = model.device

    T = encoder_out.size(1)
    U = len(ys)
    assert 0 < U <= T

    encoder_out_len = torch.tensor([1])
    decoder_out_len = encoder_out_len

    start = AlignItem(log_prob=0.0, ys=[], pos_u=0)
    B = AlignItemList([start])

    for t in range(T):
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :]
        # current_encoder_out is of shape (1, 1, encoder_out_dim)
        # fmt: on

        A = B  # shallow copy
        B = AlignItemList()

        decoder_input = A.get_decoder_input(
            ys=ys, context_size=context_size, blank_id=blank_id
        )
        decoder_input = torch.tensor(decoder_input, device=device)
        # decoder_input is of shape (num_active_items, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False)
        # decoder_out is of shape (num_active_items, 1, decoder_output_dim)

        current_encoder_out = current_encoder_out.expand(
            decoder_out.size(0), 1, -1
        )

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            encoder_out_len.expand(decoder_out.size(0)),
            decoder_out_len.expand(decoder_out.size(0)),
        )

        # logits is of shape (num_active_items, vocab_size)
        log_probs = logits.log_softmax(dim=-1).tolist()

        for i, item in enumerate(A):
            # A blank can be emitted at frame t only if the remaining
            # frames are enough to emit the remaining tokens.
            if (T - 1 - t) >= (U - item.pos_u):
                # horizontal transition (left -> right)
                new_item = AlignItem(
                    log_prob=item.log_prob + log_probs[i][blank_id],
                    ys=item.ys + [blank_id],
                    pos_u=item.pos_u,
                )
                B.append(new_item)

            if item.pos_u < U:
                # diagonal transition (lower left -> upper right)
                u = ys[item.pos_u]
                new_item = AlignItem(
                    log_prob=item.log_prob + log_probs[i][u],
                    ys=item.ys + [u],
                    pos_u=item.pos_u + 1,
                )
                B.append(new_item)

        if len(B) > beam_size:
            B = B.topk(beam_size)

    ans = B.topk(1)[0].ys

    assert len(ans) == T
    assert list(filter(lambda i: i != blank_id, ans)) == ys

    return ans


def get_word_starting_frames(
    ali: List[int], sp: spm.SentencePieceProcessor
) -> List[int]:
    """Get the starting frame of each word from the given token alignment.

    When a word is encoded into BPE tokens, its first token starts
    with the underscore-like character "▁" (U+2581), which can be used
    to identify the starting frame of a word.

    Args:
      ali:
        Framewise token alignment. It can be the return value of
        :func:`force_alignment`.
      sp:
        The sentencepiece model.
    Returns:
      Return a list of int representing the starting frame of each word
      in the alignment.
      Caution:
        You have to take into account the model subsampling factor when
        converting the starting frame into time.
    """
    underscore = b"\xe2\x96\x81".decode()  # '▁'
    ans = []
    for i in range(len(ali)):
        if sp.id_to_piece(ali[i]).startswith(underscore):
            ans.append(i)
    return ans
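

# Illustration (not in the original file): if the transcript "HELLO WORLD"
# were encoded as the pieces ['▁HELLO', '▁WORLD'] (hypothetical pieces),
# only frames whose aligned token starts with '▁' mark a word start, so
# get_word_starting_frames() would return exactly two frame indexes.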
egs/librispeech/ASR/transducer_stateless/compute_ali.py (new executable file, 326 lines)

@@ -0,0 +1,326 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Usage:
./transducer_stateless/compute_ali.py \
        --exp-dir ./transducer_stateless/exp \
        --bpe-model ./data/lang_bpe_500/bpe.model \
        --epoch 20 \
        --avg 10 \
        --max-duration 300 \
        --dataset train-clean-100 \
        --out-dir data/ali
"""

import argparse
import logging
from pathlib import Path
from typing import List

import numpy as np
import sentencepiece as spm
import torch
from alignment import force_alignment
from asr_datamodule import LibriSpeechAsrDataModule
from lhotse import CutSet
from lhotse.features.io import FeaturesWriter, NumpyHdf5Writer
from train import get_params, get_transducer_model

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.utils import AttributeDict, setup_logger


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--epoch",
        type=int,
        default=34,
        help="It specifies the checkpoint to use for decoding. "
        "Note: Epoch counts from 0.",
    )
    parser.add_argument(
        "--avg",
        type=int,
        default=20,
        help="Number of checkpoints to average. Automatically select "
        "consecutive checkpoints before the checkpoint specified by "
        "'--epoch'. ",
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--exp-dir",
        type=str,
        default="transducer_stateless/exp",
        help="The experiment dir",
    )

    parser.add_argument(
        "--out-dir",
        type=str,
        required=True,
        help="""Output directory.
        It contains 2 generated files:

        - token_ali_xxx.h5
        - cuts_xxx.json.gz

        where xxx is the value of `--dataset`. For instance, if
        `--dataset` is `train-clean-100`, it will contain 2 files:

        - `token_ali_train-clean-100.h5`
        - `cuts_train-clean-100.json.gz`
        """,
    )

    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="""The name of the dataset to compute alignments for.
        Possible values are:
        - test-clean
        - test-other
        - train-clean-100
        - train-clean-360
        - train-other-500
        - dev-clean
        - dev-other
        """,
    )

    parser.add_argument(
        "--beam-size",
        type=int,
        default=4,
        help="Size of the beam used in beam search.",
    )

    parser.add_argument(
        "--context-size",
        type=int,
        default=2,
        help="The context size in the decoder. 1 means bigram; "
        "2 means trigram",
    )

    return parser


def compute_alignments(
    model: torch.nn.Module,
    dl: torch.utils.data.DataLoader,
    ali_writer: FeaturesWriter,
    params: AttributeDict,
    sp: spm.SentencePieceProcessor,
):
    try:
        num_batches = len(dl)
    except TypeError:
        num_batches = "?"
    num_cuts = 0

    device = model.device
    cuts = []

    for batch_idx, batch in enumerate(dl):
        feature = batch["inputs"]

        # at entry, feature is of shape (N, T, C)
        assert feature.ndim == 3
        feature = feature.to(device)

        supervisions = batch["supervisions"]

        cut_list = supervisions["cut"]
        for cut in cut_list:
            assert len(cut.supervisions) == 1, f"{len(cut.supervisions)}"

        feature_lens = supervisions["num_frames"].to(device)

        encoder_out, encoder_out_lens = model.encoder(
            x=feature, x_lens=feature_lens
        )

        batch_size = encoder_out.size(0)

        texts = supervisions["text"]

        ys_list: List[List[int]] = sp.encode(texts, out_type=int)

        ali_list = []
        for i in range(batch_size):
            # fmt: off
            encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
            # fmt: on

            ali = force_alignment(
                model=model,
                encoder_out=encoder_out_i,
                ys=ys_list[i],
                beam_size=params.beam_size,
            )
            ali_list.append(ali)
        assert len(ali_list) == len(cut_list)

        for cut, ali in zip(cut_list, ali_list):
            cut.token_alignment = ali_writer.store_array(
                key=cut.id,
                value=np.asarray(ali, dtype=np.int32),
                # frame shift is 0.01 s, subsampling_factor is 4
                frame_shift=0.04,
                temporal_dim=0,
                start=0,
            )

        cuts += cut_list

        num_cuts += len(cut_list)

        if batch_idx % 2 == 0:
            batch_str = f"{batch_idx}/{num_batches}"

            logging.info(
                f"batch {batch_str}, cuts processed until now is {num_cuts}"
            )

    return CutSet.from_cuts(cuts)


@torch.no_grad()
def main():
    parser = get_parser()
    LibriSpeechAsrDataModule.add_arguments(parser)
    args = parser.parse_args()

    args.enable_spec_aug = False
    args.enable_musan = False
    args.return_cuts = True
    args.concatenate_cuts = False

    params = get_params()
    params.update(vars(args))

    setup_logger(f"{params.exp_dir}/log-ali")

    sp = spm.SentencePieceProcessor()
    sp.load(params.bpe_model)

    # <blk> is defined in local/train_bpe_model.py
    params.blank_id = sp.piece_to_id("<blk>")
    params.vocab_size = sp.get_piece_size()

    logging.info(f"Computing alignments for {params.dataset} - started")
    logging.info(params)

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)
    logging.info(f"Device: {device}")

    out_dir = Path(params.out_dir)
    out_dir.mkdir(exist_ok=True)

    out_ali_filename = out_dir / f"token_ali_{params.dataset}.h5"
    out_manifest_filename = out_dir / f"cuts_{params.dataset}.json.gz"

    done_file = out_dir / f".{params.dataset}.done"
    if done_file.is_file():
        logging.info(f"{done_file} exists - skipping")
        exit()

    logging.info("About to create model")
    model = get_transducer_model(params)

    if params.avg == 1:
        load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
    else:
        start = params.epoch - params.avg + 1
        filenames = []
        for i in range(start, params.epoch + 1):
            # skip nonexistent checkpoints (epoch counts from 0)
            if i >= 0:
                filenames.append(f"{params.exp_dir}/epoch-{i}.pt")
        logging.info(f"averaging {filenames}")
        model.to(device)
        model.load_state_dict(
            average_checkpoints(filenames, device=device), strict=False
        )

    model.to(device)
    model.eval()
    model.device = device

    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    librispeech = LibriSpeechAsrDataModule(args)
    if params.dataset == "test-clean":
        test_clean_cuts = librispeech.test_clean_cuts()
        dl = librispeech.test_dataloaders(test_clean_cuts)
    elif params.dataset == "test-other":
        test_other_cuts = librispeech.test_other_cuts()
        dl = librispeech.test_dataloaders(test_other_cuts)
    elif params.dataset == "train-clean-100":
        train_clean_100_cuts = librispeech.train_clean_100_cuts()
        dl = librispeech.train_dataloaders(train_clean_100_cuts)
    elif params.dataset == "train-clean-360":
        train_clean_360_cuts = librispeech.train_clean_360_cuts()
        dl = librispeech.train_dataloaders(train_clean_360_cuts)
    elif params.dataset == "train-other-500":
        train_other_500_cuts = librispeech.train_other_500_cuts()
        dl = librispeech.train_dataloaders(train_other_500_cuts)
    elif params.dataset == "dev-clean":
        dev_clean_cuts = librispeech.dev_clean_cuts()
        dl = librispeech.valid_dataloaders(dev_clean_cuts)
    else:
        assert params.dataset == "dev-other", f"{params.dataset}"
        dev_other_cuts = librispeech.dev_other_cuts()
        dl = librispeech.valid_dataloaders(dev_other_cuts)

    logging.info(f"Processing {params.dataset}")

    with NumpyHdf5Writer(out_ali_filename) as ali_writer:
        cut_set = compute_alignments(
            model=model,
            dl=dl,
            ali_writer=ali_writer,
            params=params,
            sp=sp,
        )

    cut_set.to_file(out_manifest_filename)

    logging.info(
        f"For dataset {params.dataset}, its framewise token alignments are "
        f"saved to {out_ali_filename} and the cut manifest "
        f"file is {out_manifest_filename}. Number of cuts: {len(cut_set)}"
    )
    done_file.touch()


if __name__ == "__main__":
    main()
egs/librispeech/ASR/transducer_stateless/test_compute_ali.py (new executable file, 167 lines)

@@ -0,0 +1,167 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
This script shows how to get word starting times
from framewise token alignments.

Usage:
./transducer_stateless/compute_ali.py \
        --exp-dir ./transducer_stateless/exp \
        --bpe-model ./data/lang_bpe_500/bpe.model \
        --epoch 20 \
        --avg 10 \
        --max-duration 300 \
        --dataset train-clean-100 \
        --out-dir data/ali

And then you can run:

./transducer_stateless/test_compute_ali.py \
        --bpe-model ./data/lang_bpe_500/bpe.model \
        --ali-dir data/ali \
        --dataset train-clean-100
"""
import argparse
import logging
from pathlib import Path

import sentencepiece as spm
import torch
from alignment import get_word_starting_frames
from lhotse import CutSet, load_manifest
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from lhotse.dataset.collation import collate_custom_field


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    parser.add_argument(
        "--bpe-model",
        type=str,
        default="data/lang_bpe_500/bpe.model",
        help="Path to the BPE model",
    )

    parser.add_argument(
        "--ali-dir",
        type=Path,
        default="./data/ali",
        help="It specifies the directory where alignments can be found.",
    )

    parser.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="""The name of the dataset.
        Possible values are:
        - test-clean
        - test-other
        - train-clean-100
        - train-clean-360
        - train-other-500
        - dev-clean
        - dev-other
        """,
    )

    return parser


def main():
    args = get_parser().parse_args()

    sp = spm.SentencePieceProcessor()
    sp.load(args.bpe_model)

    cuts_json = args.ali_dir / f"cuts_{args.dataset}.json.gz"

    logging.info(f"Loading {cuts_json}")
    cuts = load_manifest(cuts_json)

    sampler = SingleCutSampler(
        cuts,
        max_duration=30,
        shuffle=False,
    )

    dataset = K2SpeechRecognitionDataset(return_cuts=True)

    dl = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=None,
        num_workers=1,
        persistent_workers=False,
    )

    frame_shift = 10  # ms
    subsampling_factor = 4

    frame_shift_in_second = frame_shift * subsampling_factor / 1000.0

    # key: cut.id
    # value: a list of pairs (word, time_in_second)
    word_starting_time_dict = {}
    for batch in dl:
        supervisions = batch["supervisions"]
        cuts = supervisions["cut"]

        token_alignment, token_alignment_length = collate_custom_field(
            CutSet.from_cuts(cuts), "token_alignment"
        )

        for i in range(len(cuts)):
            # The encoder subsampling reduces the number of feature frames
            # T to ((T - 1) // 2 - 1) // 2, which must match the length
            # of the stored framewise token alignment.
            assert (
                (cuts[i].features.num_frames - 1) // 2 - 1
            ) // 2 == token_alignment_length[i]

            word_starting_frames = get_word_starting_frames(
                token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp
            )
            word_starting_time = [
                "{:.2f}".format(i * frame_shift_in_second)
                for i in word_starting_frames
            ]

            words = supervisions["text"][i].split()

            assert len(word_starting_frames) == len(words)
            word_starting_time_dict[cuts[i].id] = list(
                zip(words, word_starting_time)
            )

        # This is a demo script and we exit here after processing
        # one batch.
        # You can find word starting times in the dict
        # "word_starting_time_dict".
        for cut_id, word_time in word_starting_time_dict.items():
            print(f"{cut_id}\n{word_time}\n")
        break


if __name__ == "__main__":
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )

    logging.basicConfig(format=formatter, level=logging.INFO)
    main()