Add pruned_transducer_stateless2 recipe for Aidatatang_200zh

luomingshuang 2022-05-16 20:53:19 +08:00
parent d184599ff8
commit cd3bf0a1fd
10 changed files with 739 additions and 176 deletions

View File

@ -0,0 +1,39 @@
Note: This recipe is trained with the code from this PR: https://github.com/k2-fsa/icefall/pull/355, and with the SpecAugment code from this PR: https://github.com/lhotse-speech/lhotse/pull/604.
# Pre-trained Transducer-Stateless2 models for the Aidatatang_200zh dataset with icefall
The model was trained on the full [Aidatatang_200zh](https://www.openslr.org/62) dataset with the scripts in [icefall](https://github.com/k2-fsa/icefall), based on the latest version of k2.
## Training procedure
The main repositories are listed below; the training and decoding scripts will be updated as these projects evolve.
k2: https://github.com/k2-fsa/k2
icefall: https://github.com/k2-fsa/icefall
lhotse: https://github.com/lhotse-speech/lhotse
* Install k2 and lhotse. The k2 installation guide is at https://k2.readthedocs.io/en/latest/installation/index.html and the lhotse installation guide is at https://lhotse.readthedocs.io/en/latest/getting-started.html#installation. The latest versions should work. Please also install the requirements listed in icefall.
* Clone icefall (https://github.com/k2-fsa/icefall) and check out the commit shown above.
```
git clone https://github.com/k2-fsa/icefall
cd icefall
```
* Preparing data.
```
cd egs/aidatatang_200zh/ASR
bash ./prepare.sh
```
* Training
```
export CUDA_VISIBLE_DEVICES="0,1"
./pruned_transducer_stateless2/train.py \
--world-size 2 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 250
```
## Evaluation results
The decoding results (WER%) on Aidatatang_200zh (dev and test) are listed below. The results were obtained by averaging the models from epoch 11 to epoch 29.
The WERs are:
| | dev | test | comment |
|------------------------------------|------------|------------|------------------------------------------|
| greedy search | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100 |
| modified beam search (beam size 4) | 5.28 | 6.32 | --epoch 29, --avg 19, --max-duration 100 |
| fast beam search (set as default) | 5.29 | 6.33 | --epoch 29, --avg 19, --max-duration 1500|

View File

@ -0,0 +1,72 @@
## Results
### Aidatatang_200zh Char training results (Pruned Transducer Stateless2)
#### 2022-05-16
Using the code from this PR: https://github.com/k2-fsa/icefall/pull/355.
The WERs are:
| | dev | test | comment |
|------------------------------------|------------|------------|------------------------------------------|
| greedy search | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100 |
| modified beam search (beam size 4) | 5.28 | 6.32 | --epoch 29, --avg 19, --max-duration 100 |
| fast beam search (set as default) | 5.29 | 6.33 | --epoch 29, --avg 19, --max-duration 1500|
The training command for reproducing these results is given below:
```
export CUDA_VISIBLE_DEVICES="0,1"
./pruned_transducer_stateless2/train.py \
--world-size 2 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 250 \
--save-every-n 1000
```
The tensorboard training log can be found at
https://tensorboard.dev/experiment/VpA8b7SZQ7CEjZs9WZ5HNA/#scalars
The decoding commands are:
```
epoch=29
avg=19
## greedy search
./pruned_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir ./data/lang_char \
--max-duration 100
## modified beam search
./pruned_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir ./data/lang_char \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
## fast beam search
./pruned_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir ./data/lang_char \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_aidatatang-200zh_pruned_transducer_stateless2>

View File

@ -1,72 +0,0 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached
def preprocess_aidatatang_200zh():
src_dir = Path("data/manifests/aidatatang_200zh")
output_dir = Path("data/fbank/aidatatang_200zh")
output_dir.mkdir(exist_ok=True, parents=True)
dataset_parts = (
"train",
"test",
"dev",
)
logging.info("Loading manifest")
manifests = read_manifests_if_cached(
dataset_parts=dataset_parts,
output_dir=src_dir,
)
assert len(manifests) > 0
for partition, m in manifests.items():
logging.info(f"Processing {partition}")
raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
if raw_cuts_path.is_file():
logging.info(f"{partition} already exists - skipping")
continue
for sup in m["supervisions"]:
sup.custom = {"origin": "aidatatang_200zh"}
cut_set = CutSet.from_manifests(
recordings=m["recordings"],
supervisions=m["supervisions"],
)
logging.info(f"Saving to {raw_cuts_path}")
cut_set.to_file(raw_cuts_path)
def main():
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
preprocess_aidatatang_200zh()
if __name__ == "__main__":
main()

View File

@ -186,8 +186,6 @@ def main():
for z in a:
a_flat.append("".join(z))
# a_chars = [z.replace(" ", args.space) for z in a_flat]
a_chars = [z for z in a_flat]
a_chars = "".join(a_flat)
print(a_chars)
line = f.readline()
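For context, the replaced line built a Python list of characters, so `print(a_chars)` would emit a bracketed list rather than plain text; joining yields the intended string. A minimal sketch of the difference (hypothetical input):
```
a_flat = ["你", "好", "吗"]
old = [z for z in a_flat]  # ['你', '好', '吗'] -- printing shows a bracketed list
new = "".join(a_flat)      # '你好吗' -- one plain string, as intended
print(old)
print(new)
```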

View File

@ -413,6 +413,6 @@ class Aidatatang_200zhAsrDataModule:
return load_manifest(self.args.manifest_dir / "cuts_dev.json.gz")
@lru_cache()
def test_net_cuts(self) -> List[CutSet]:
def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts")
return load_manifest(self.args.manifest_dir / "cuts_test.json.gz")

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional
@ -21,11 +22,11 @@ import k2
import torch
from model import Transducer
from icefall.decode import one_best_decoding
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts
def fast_beam_search(
def fast_beam_search_one_best(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
@ -35,7 +36,8 @@ def fast_beam_search(
max_contexts: int,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then
the shortest path within the lattice is used as the final output.
Args:
model:
An instance of `Transducer`.
@ -55,6 +57,143 @@ def fast_beam_search(
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest_oracle(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
ref_texts: List[List[int]],
use_double_scores: bool = True,
nbest_scale: float = 0.5,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then
we select `num_paths` linear paths from the lattice. The path
that has the minimum edit distance with the given reference transcript
is used as the output.
This is the best result we can achieve for any nbest based rescoring
methods.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
ref_texts:
A list-of-list of integers containing the reference transcripts.
If the decoding_graph is a trivial_graph, the integer ID is the
BPE token ID.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
hyps = nbest.build_levenshtein_graphs()
refs = k2.levenshtein_graph(ref_texts, device=hyps.device)
levenshtein_alignment = k2.levenshtein_alignment(
refs=refs,
hyps=hyps,
hyp_to_ref_map=nbest.shape.row_ids(1),
sorted_match_ref=True,
)
tot_scores = levenshtein_alignment.get_tot_scores(
use_double_scores=False, log_semiring=False
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
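The oracle selection above boils down to: among the n-best paths, pick the one with the smallest edit distance to the reference. A toy, hedged illustration in plain Python (hypothetical token IDs; the real code computes this over ragged tensors with `k2.levenshtein_alignment`):
```
def edit_distance(a, b):
    # classic dynamic-programming Levenshtein distance with a rolling row
    dp = list(range(len(b) + 1))
    for i, x in enumerate(a, 1):
        prev, dp[0] = dp[0], i
        for j, y in enumerate(b, 1):
            prev, dp[j] = dp[j], min(dp[j] + 1, dp[j - 1] + 1, prev + (x != y))
    return dp[-1]

candidates = [[2, 3, 5], [2, 4, 5], [2, 3, 4, 5]]  # stand-ins for n-best paths
ref = [2, 3, 4]                                    # reference token IDs
best = min(candidates, key=lambda h: edit_distance(h, ref))
print(best)  # the oracle hypothesis: an upper bound for n-best rescoring
```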
def fast_beam_search(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or an HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
Returns:
Return an FsaVec with axes [utt][state][arc] containing the decoded
lattice. Note: When the input graph is a TrivialGraph, the returned
lattice is actually an acceptor.
"""
assert encoder_out.ndim == 3
context_size = model.decoder.context_size
@ -103,9 +242,7 @@ def fast_beam_search(
decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist())
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
return lattice
def greedy_search(
@ -130,8 +267,9 @@ def greedy_search(
blank_id = model.decoder.blank_id
context_size = model.decoder.context_size
unk_id = getattr(model, "unk_id", blank_id)
device = model.device
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size, device=device, dtype=torch.int64
@ -170,7 +308,7 @@ def greedy_search(
# logits is (1, 1, 1, vocab_size)
y = logits.argmax().item()
if y != blank_id:
if y not in (blank_id, unk_id):
hyp.append(y)
decoder_input = torch.tensor(
[hyp[-context_size:]], device=device
@ -190,7 +328,9 @@ def greedy_search(
def greedy_search_batch(
model: Transducer, encoder_out: torch.Tensor
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args:
@ -198,6 +338,9 @@ def greedy_search_batch(
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1.
encoder_out_lens:
A 1-D tensor of shape (N,), containing the number of valid frames in
encoder_out before padding.
Returns:
Return a list-of-list of token IDs containing the decoded results.
len(ans) equals to encoder_out.size(0).
@ -205,30 +348,49 @@ def greedy_search_batch(
assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0)
device = model.device
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
device = next(model.parameters()).device
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(batch_size)]
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor(
hyps,
device=device,
dtype=torch.int64,
) # (batch_size, context_size)
) # (N, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
encoder_out = model.joiner.encoder_proj(encoder_out)
# decoder_out: (N, 1, decoder_out_dim)
# decoder_out: (batch_size, 1, decoder_out_dim)
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False
)
@ -239,12 +401,12 @@ def greedy_search_batch(
y = logits.argmax(dim=1).tolist()
emitted = False
for i, v in enumerate(y):
if v != blank_id:
if v not in (blank_id, unk_id):
hyps[i].append(v)
emitted = True
if emitted:
# update decoder output
decoder_input = [h[-context_size:] for h in hyps]
decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor(
decoder_input,
device=device,
@ -253,7 +415,12 @@ def greedy_search_batch(
decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out)
ans = [h[context_size:] for h in hyps]
sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@ -291,10 +458,8 @@ class HypothesisList(object):
def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using
`log-sum-exp` with the existing one.
Args:
hyp:
The hypothesis to be added.
@ -311,7 +476,6 @@ class HypothesisList(object):
def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with
the largest `log_prob`.
Args:
length_norm:
If True, the `log_prob` of a hypothesis is normalized by the
@ -328,10 +492,8 @@ class HypothesisList(object):
def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis.
Caution:
`self` is modified **in-place**.
Args:
hyp:
The hypothesis to be removed from `self`.
@ -344,10 +506,8 @@ class HypothesisList(object):
def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold.
Caution:
`self` is not modified. Instead, a new HypothesisList is returned.
Returns:
Return a new HypothesisList containing all hypotheses from `self`
with `log_prob` being greater than the given `threshold`.
@ -385,7 +545,6 @@ class HypothesisList(object):
def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
"""Return a ragged shape with axes [utt][num_hyps].
Args:
hyps:
len(hyps) == batch_size. It contains the current hypothesis for
@ -411,15 +570,18 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
def modified_beam_search(
model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4,
) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args:
model:
The transducer model.
encoder_out:
Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing the number of valid frames in
encoder_out before padding.
beam:
Number of active paths during the beam search.
Returns:
@ -427,15 +589,26 @@ def modified_beam_search(
for the i-th utterance.
"""
assert encoder_out.ndim == 3, encoder_out.shape
batch_size = encoder_out.size(0)
T = encoder_out.size(1)
assert encoder_out.size(0) >= 1, encoder_out.size(0)
packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size):
device = next(model.parameters()).device
batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add(
Hypothesis(
ys=[blank_id] * context_size,
@ -443,11 +616,20 @@ def modified_beam_search(
)
)
encoder_out = model.joiner.encoder_proj(encoder_out)
encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa
offset = 0
finalized_B = []
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = _get_hyps_shape(B).to(device)
@ -503,6 +685,8 @@ def modified_beam_search(
for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
@ -512,15 +696,21 @@ def modified_beam_search(
new_ys = hyp.ys[:]
new_token = topk_token_indexes[k]
if new_token != blank_id:
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp)
best_hyps = [b.get_most_probable(length_norm=True) for b in B]
ans = [h.ys[context_size:] for h in best_hyps]
B = B + finalized_B
best_hyps = [b.get_most_probable(length_norm=False) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans
@ -531,12 +721,9 @@ def _deprecated_modified_beam_search(
beam: int = 4,
) -> List[int]:
"""It limits the maximum number of symbols per frame to 1.
It decodes only one utterance at a time. We keep it only for reference.
The function :func:`modified_beam_search` should be preferred as it
supports batch decoding.
Args:
model:
An instance of `Transducer`.
@ -553,9 +740,10 @@ def _deprecated_modified_beam_search(
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
T = encoder_out.size(1)
@ -614,6 +802,8 @@ def _deprecated_modified_beam_search(
topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
@ -621,7 +811,7 @@ def _deprecated_modified_beam_search(
hyp = A[topk_hyp_indexes[i]]
new_ys = hyp.ys[:]
new_token = topk_token_indexes[i]
if new_token != blank_id:
if new_token not in (blank_id, unk_id):
new_ys.append(new_token)
new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
@ -640,9 +830,7 @@ def beam_search(
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf.
espnet/nets/beam_search_transducer.py#L247 is used as a reference.
Args:
model:
An instance of `Transducer`.
@ -658,9 +846,10 @@ def beam_search(
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size
device = model.device
device = next(model.parameters()).device
decoder_input = torch.tensor(
[blank_id] * context_size,
@ -743,7 +932,7 @@ def beam_search(
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
for i, v in zip(indices.tolist(), values.tolist()):
if i == blank_id:
if i in (blank_id, unk_id):
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v
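Both `greedy_search_batch` and `modified_beam_search` above now rely on `torch.nn.utils.rnn.pack_padded_sequence`: frames are regrouped by time step so the active batch shrinks as shorter utterances finish, and `unsorted_indices` restores the original utterance order at the end. A minimal standalone sketch (synthetic shapes):
```
import torch
from torch.nn.utils.rnn import pack_padded_sequence

N, T, C = 3, 5, 2
encoder_out = torch.randn(N, T, C)
encoder_out_lens = torch.tensor([3, 5, 2])

packed = pack_padded_sequence(
    input=encoder_out,
    lengths=encoder_out_lens.cpu(),
    batch_first=True,
    enforce_sorted=False,
)
# batch_sizes[t] = number of utterances still active at frame t; iterating
# over it is what lets the decoding loop drop finished utterances early.
print(packed.batch_sizes.tolist())       # [3, 3, 2, 1, 1]
# unsorted_indices maps sorted-by-length order back to input order, which is
# why sorted_ans is re-indexed before returning.
print(packed.unsorted_indices.tolist())  # [1, 0, 2]
```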

View File

@ -59,10 +59,10 @@ from typing import Dict, List, Optional, Tuple
import k2
import torch
import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule
from asr_datamodule import Aidatatang_200zhAsrDataModule
from beam_search import (
beam_search,
fast_beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
@ -255,7 +255,7 @@ def decode_one_batch(
hyps = []
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search(
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
@ -273,6 +273,7 @@ def decode_one_batch(
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@ -280,6 +281,7 @@ def decode_one_batch(
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
@ -359,12 +361,12 @@ def decode_dataset(
if params.decoding_method == "greedy_search":
log_interval = 100
else:
log_interval = 2
log_interval = 50
results = defaultdict(list)
for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"]
texts = [list(str(text)) for text in texts]
texts = [list(str(text).replace(" ", "")) for text in texts]
hyps_dict = decode_one_batch(
params=params,
@ -440,7 +442,7 @@ def save_results(
@torch.no_grad()
def main():
parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser)
Aidatatang_200zhAsrDataModule.add_arguments(parser)
args = parser.parse_args()
args.exp_dir = Path(args.exp_dir)
@ -506,6 +508,13 @@ def main():
model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device))
average = average_checkpoints(filenames, device=device)
checkpoint = {"model": average}
torch.save(
checkpoint,
"pruned_transducer_stateless2/pretrained_average_11_to_29.pt",
)
model.to(device)
model.eval()
model.device = device
@ -526,33 +535,26 @@ def main():
from lhotse import CutSet
from lhotse.dataset.webdataset import export_to_webdataset
wenetspeech = WenetSpeechAsrDataModule(args)
aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
dev = "dev"
test_net = "test_net"
test_meet = "test_meet"
test = "test"
if not os.path.exists(f"{dev}/shared-0.tar"):
dev_cuts = wenetspeech.valid_cuts()
os.makedirs(dev)
dev_cuts = aidatatang_200zh.valid_cuts()
export_to_webdataset(
dev_cuts,
output_path=f"{dev}/shared-%d.tar",
shard_size=300,
)
if not os.path.exists(f"{test_net}/shared-0.tar"):
test_net_cuts = wenetspeech.test_net_cuts()
if not os.path.exists(f"{test}/shared-0.tar"):
os.makedirs(test)
test_cuts = aidatatang_200zh.test_cuts()
export_to_webdataset(
test_net_cuts,
output_path=f"{test_net}/shared-%d.tar",
shard_size=300,
)
if not os.path.exists(f"{test_meet}/shared-0.tar"):
test_meeting_cuts = wenetspeech.test_meeting_cuts()
export_to_webdataset(
test_meeting_cuts,
output_path=f"{test_meet}/shared-%d.tar",
test_cuts,
output_path=f"{test}/shared-%d.tar",
shard_size=300,
)
@ -567,34 +569,22 @@ def main():
shuffle_shards=True,
)
test_net_shards = [
test_shards = [
str(path)
for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar")))
for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
]
cuts_test_net_webdataset = CutSet.from_webdataset(
test_net_shards,
cuts_test_webdataset = CutSet.from_webdataset(
test_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
test_meet_shards = [
str(path)
for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))
]
cuts_test_meet_webdataset = CutSet.from_webdataset(
test_meet_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
dev_dl = aidatatang_200zh.valid_dataloaders(cuts_dev_webdataset)
test_dl = aidatatang_200zh.test_dataloaders(cuts_test_webdataset)
dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset)
test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset)
test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
test_sets = ["dev", "test"]
test_dl = [dev_dl, test_dl]
for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset(
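Note that `main()` now also writes the averaged checkpoint to `pruned_transducer_stateless2/pretrained_average_11_to_29.pt`. A hedged sketch of reloading it later, mirroring how `pretrained.py` consumes `checkpoint["model"]` (the lang dir path is an assumption):
```
import torch
from train import get_params, get_transducer_model
from icefall.lexicon import Lexicon

params = get_params()
lexicon = Lexicon("data/lang_char")  # assumed lang dir
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1

model = get_transducer_model(params)
checkpoint = torch.load(
    "pruned_transducer_stateless2/pretrained_average_11_to_29.pt",
    map_location="cpu",
)
model.load_state_dict(checkpoint["model"], strict=False)
model.eval()
```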

View File

@ -21,8 +21,8 @@ Usage:
./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--epoch 20 \
--avg 10
--epoch 29 \
--avg 19
It will generate a file exp_dir/pretrained.pt
@ -32,7 +32,7 @@ you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/wenetspeech/ASR
cd /path/to/egs/aidatatang_200zh/ASR
./pruned_transducer_stateless2/decode.py \
--exp-dir ./pruned_transducer_stateless2/exp \
--epoch 9999 \

View File

@ -0,0 +1,347 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2022 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--decoding-method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--decoding-method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by
./pruned_transducer_stateless2/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Path to lang.
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search ",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--decoding-method is greedy_search.
""",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
with torch.no_grad():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
hyps = []
msg = f"Using {params.decoding_method}"
logging.info(msg)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()
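As a side note, the `--max-sym-per-frame 1` setting used with greedy search has a simple shape: each encoder frame contributes at most one non-blank token. A toy, hedged illustration with random logits standing in for the joiner (the real search also feeds emitted tokens back through the decoder):
```
import torch

torch.manual_seed(0)
vocab_size, blank_id, T = 5, 0, 4
logits = torch.randn(T, vocab_size)  # stand-in for per-frame joiner outputs

hyp = []
for t in range(T):
    y = logits[t].argmax().item()
    if y != blank_id:  # blank means "advance to the next frame"
        hyp.append(y)  # at most one symbol per frame by construction
print(hyp)
```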

View File

@ -103,7 +103,7 @@ def get_parser():
parser.add_argument(
"--master-port",
type=int,
default=12354,
default=12359,
help="Master port to use for DDP training.",
)