do some changes

This commit is contained in:
luomingshuang 2022-05-16 20:53:19 +08:00
parent d184599ff8
commit cd3bf0a1fd
10 changed files with 739 additions and 176 deletions

View File

@ -0,0 +1,39 @@
Note: This recipe is trained with the code from this PR https://github.com/k2-fsa/icefall/pull/355
and the SpecAugment code from this PR https://github.com/lhotse-speech/lhotse/pull/604.
# Pre-trained Transducer-Stateless2 models for the Aidatatang_200zh dataset with icefall.
The model was trained on the full [Aidatatang_200zh](https://www.openslr.org/62) dataset with the scripts in [icefall](https://github.com/k2-fsa/icefall), based on the latest version of k2.
## Training procedure
The main repositories are listed below; the training and decoding scripts will be updated as new versions are released.
k2: https://github.com/k2-fsa/k2
icefall: https://github.com/k2-fsa/icefall
lhotse: https://github.com/lhotse-speech/lhotse
* Install k2 and lhotse. The k2 installation guide is at https://k2.readthedocs.io/en/latest/installation/index.html and the lhotse installation guide is at https://lhotse.readthedocs.io/en/latest/getting-started.html#installation. The latest versions should work. Please also install the requirements listed in icefall. A quick sanity check of the installation is shown below.
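The following is a hypothetical sanity check, not part of the recipe; it only confirms that k2, lhotse and torch import in the same Python environment:
```
# Hypothetical sanity check: confirm k2, lhotse and torch are importable
# in the environment you will run the recipe from.
import torch
import k2
import lhotse

print("torch:", torch.__version__)
print("k2:", getattr(k2, "__version__", "unknown"))
print("lhotse:", getattr(lhotse, "__version__", "unknown"))
print("CUDA available:", torch.cuda.is_available())
```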
* Clone icefall (https://github.com/k2-fsa/icefall) and check out the commit shown above.
```
git clone https://github.com/k2-fsa/icefall
cd icefall
```
* Preparing data.
```
cd egs/aidatatang_200zh/ASR
bash ./prepare.sh
```
* Training
```
export CUDA_VISIBLE_DEVICES="0,1"
./pruned_transducer_stateless2/train.py \
--world-size 2 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 250
```
## Evaluation results
The decoding results (WER%) on Aidatatang_200zh (dev and test) are listed below; these results were obtained by averaging the models from epoch 11 to 29 (see the sketch after the table).
The WERs are
| | dev | test | comment |
|------------------------------------|------------|------------|------------------------------------------|
| greedy search | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100 |
| modified beam search (beam size 4) | 5.28 | 6.32 | --epoch 29, --avg 19, --max-duration 100 |
| fast beam search (set as default) | 5.29 | 6.33 | --epoch 29, --avg 19, --max-duration 1500|
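The `--epoch 29, --avg 19` settings mean the parameters of the checkpoints from epoch 11 to epoch 29 are averaged before decoding. The snippet below is a minimal, illustrative sketch of that averaging step, not the exact icefall implementation; the checkpoint file names and the `"model"` key are assumptions based on this recipe's `exp` directory layout.
```
# Illustrative sketch of checkpoint averaging (not the exact icefall code).
# Averages the "model" state_dicts of epoch-11.pt ... epoch-29.pt.
from pathlib import Path

import torch


def average_checkpoints_sketch(filenames):
    avg = None
    for f in filenames:
        state = torch.load(f, map_location="cpu")["model"]
        if avg is None:
            avg = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg:
                avg[k] += state[k].float()
    for k in avg:
        avg[k] /= len(filenames)
    return avg


exp_dir = Path("pruned_transducer_stateless2/exp")  # assumed location
files = [exp_dir / f"epoch-{i}.pt" for i in range(11, 30)]  # epochs 11..29
torch.save({"model": average_checkpoints_sketch(files)}, exp_dir / "averaged.pt")
```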

View File

@ -0,0 +1,72 @@
## Results
### Aidatatang_200zh Char training results (Pruned Transducer Stateless2)
#### 2022-05-16
Using the code from this PR: https://github.com/k2-fsa/icefall/pull/355.
The WERs are
| | dev | test | comment |
|------------------------------------|------------|------------|------------------------------------------|
| greedy search | 5.53 | 6.59 | --epoch 29, --avg 19, --max-duration 100 |
| modified beam search (beam size 4) | 5.28 | 6.32 | --epoch 29, --avg 19, --max-duration 100 |
| fast beam search (set as default) | 5.29 | 6.33 | --epoch 29, --avg 19, --max-duration 1500|
The training command for reproducing the results is given below:
```
export CUDA_VISIBLE_DEVICES="0,1"
./pruned_transducer_stateless2/train.py \
--world-size 2 \
--num-epochs 30 \
--start-epoch 0 \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \
--max-duration 250 \
--save-every-n 1000
```
The tensorboard training log can be found at
https://tensorboard.dev/experiment/VpA8b7SZQ7CEjZs9WZ5HNA/#scalars
The decoding command is:
```
epoch=29
avg=19
## greedy search
./pruned_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir ./data/lang_char \
--max-duration 100
## modified beam search
./pruned_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir pruned_transducer_stateless2/exp \
--lang-dir ./data/lang_char \
--max-duration 100 \
--decoding-method modified_beam_search \
--beam-size 4
## fast beam search
./pruned_transducer_stateless2/decode.py \
--epoch $epoch \
--avg $avg \
--exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir ./data/lang_char \
--max-duration 1500 \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8
```
A pre-trained model and decoding logs can be found at <https://huggingface.co/luomingshuang/icefall_asr_aidatatang-200zh_pruned_transducer_stateless2>
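To run the released checkpoint locally, a minimal loading sketch is shown below. It assumes `pretrained.pt` has been downloaded from the repository above and that the snippet is run from `egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/`, reusing the recipe's own `get_params`/`get_transducer_model` helpers in the same way `pretrained.py` does.
```
# Minimal sketch: load the released pretrained.pt into the recipe's model.
# Assumed to run from egs/aidatatang_200zh/ASR/pruned_transducer_stateless2/.
import torch

from train import get_params, get_transducer_model
from icefall.lexicon import Lexicon

params = get_params()
params.lang_dir = "data/lang_char"
params.context_size = 2  # must match the value used during training
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1

model = get_transducer_model(params)
checkpoint = torch.load("pretrained.pt", map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.eval()
```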

View File

@ -1,72 +0,0 @@
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (Fangjun Kuang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from pathlib import Path
from lhotse import CutSet
from lhotse.recipes.utils import read_manifests_if_cached
def preprocess_aidatatang_200zh():
    src_dir = Path("data/manifests/aidatatang_200zh")
    output_dir = Path("data/fbank/aidatatang_200zh")
    output_dir.mkdir(exist_ok=True, parents=True)

    dataset_parts = (
        "train",
        "test",
        "dev",
    )

    logging.info("Loading manifest")
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
    )
    assert len(manifests) > 0

    for partition, m in manifests.items():
        logging.info(f"Processing {partition}")
        raw_cuts_path = output_dir / f"cuts_{partition}_raw.jsonl.gz"
        if raw_cuts_path.is_file():
            logging.info(f"{partition} already exists - skipping")
            continue

        for sup in m["supervisions"]:
            sup.custom = {"origin": "aidatatang_200zh"}

        cut_set = CutSet.from_manifests(
            recordings=m["recordings"],
            supervisions=m["supervisions"],
        )

        logging.info(f"Saving to {raw_cuts_path}")
        cut_set.to_file(raw_cuts_path)


def main():
    formatter = (
        "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    )
    logging.basicConfig(format=formatter, level=logging.INFO)

    preprocess_aidatatang_200zh()


if __name__ == "__main__":
    main()

View File

@ -186,8 +186,6 @@ def main():
for z in a: for z in a:
a_flat.append("".join(z)) a_flat.append("".join(z))
# a_chars = [z.replace(" ", args.space) for z in a_flat]
a_chars = [z for z in a_flat]
a_chars = "".join(a_flat) a_chars = "".join(a_flat)
print(a_chars) print(a_chars)
line = f.readline() line = f.readline()

View File

@ -413,6 +413,6 @@ class Aidatatang_200zhAsrDataModule:
return load_manifest(self.args.manifest_dir / "cuts_dev.json.gz") return load_manifest(self.args.manifest_dir / "cuts_dev.json.gz")
@lru_cache() @lru_cache()
def test_net_cuts(self) -> List[CutSet]: def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts") logging.info("About to get test cuts")
return load_manifest(self.args.manifest_dir / "cuts_test.json.gz") return load_manifest(self.args.manifest_dir / "cuts_test.json.gz")

View File

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import warnings
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional from typing import Dict, List, Optional
@ -21,11 +22,11 @@ import k2
import torch import torch
from model import Transducer from model import Transducer
from icefall.decode import one_best_decoding from icefall.decode import Nbest, one_best_decoding
from icefall.utils import get_texts from icefall.utils import get_texts
def fast_beam_search( def fast_beam_search_one_best(
model: Transducer, model: Transducer,
decoding_graph: k2.Fsa, decoding_graph: k2.Fsa,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
@ -35,7 +36,8 @@ def fast_beam_search(
max_contexts: int, max_contexts: int,
) -> List[List[int]]: ) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then
the shortest path within the lattice is used as the final output.
Args: Args:
model: model:
An instance of `Transducer`. An instance of `Transducer`.
@ -55,6 +57,143 @@ def fast_beam_search(
Returns: Returns:
Return the decoded result. Return the decoded result.
""" """
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
best_path = one_best_decoding(lattice)
hyps = get_texts(best_path)
return hyps
def fast_beam_search_nbest_oracle(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
num_paths: int,
ref_texts: List[List[int]],
use_double_scores: bool = True,
nbest_scale: float = 0.5,
) -> List[List[int]]:
"""It limits the maximum number of symbols per frame to 1.
A lattice is first obtained using fast beam search, and then
we select `num_paths` linear paths from the lattice. The path
that has the minimum edit distance with the given reference transcript
is used as the output.
This is the best result we can achieve for any nbest based rescoring
methods.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
num_paths:
Number of paths to extract from the decoded lattice.
ref_texts:
A list-of-list of integers containing the reference transcripts.
If the decoding_graph is a trivial_graph, the integer ID is the
BPE token ID.
use_double_scores:
True to use double precision for computation. False to use
single precision.
nbest_scale:
It's the scale applied to the lattice.scores. A smaller value
yields more unique paths.
Returns:
Return the decoded result.
"""
lattice = fast_beam_search(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=beam,
max_states=max_states,
max_contexts=max_contexts,
)
nbest = Nbest.from_lattice(
lattice=lattice,
num_paths=num_paths,
use_double_scores=use_double_scores,
nbest_scale=nbest_scale,
)
hyps = nbest.build_levenshtein_graphs()
refs = k2.levenshtein_graph(ref_texts, device=hyps.device)
levenshtein_alignment = k2.levenshtein_alignment(
refs=refs,
hyps=hyps,
hyp_to_ref_map=nbest.shape.row_ids(1),
sorted_match_ref=True,
)
tot_scores = levenshtein_alignment.get_tot_scores(
use_double_scores=False, log_semiring=False
)
ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores)
max_indexes = ragged_tot_scores.argmax()
best_path = k2.index_fsa(nbest.fsa, max_indexes)
hyps = get_texts(best_path)
return hyps
def fast_beam_search(
model: Transducer,
decoding_graph: k2.Fsa,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: float,
max_states: int,
max_contexts: int,
) -> k2.Fsa:
"""It limits the maximum number of symbols per frame to 1.
Args:
model:
An instance of `Transducer`.
decoding_graph:
Decoding graph used for decoding, may be a TrivialGraph or a HLG.
encoder_out:
A tensor of shape (N, T, C) from the encoder.
encoder_out_lens:
A tensor of shape (N,) containing the number of frames in `encoder_out`
before padding.
beam:
Beam value, similar to the beam used in Kaldi.
max_states:
Max states per stream per frame.
max_contexts:
Max contexts per stream per frame.
Returns:
Return an FsaVec with axes [utt][state][arc] containing the decoded
lattice. Note: When the input graph is a TrivialGraph, the returned
lattice is actually an acceptor.
"""
assert encoder_out.ndim == 3 assert encoder_out.ndim == 3
context_size = model.decoder.context_size context_size = model.decoder.context_size
@ -103,9 +242,7 @@ def fast_beam_search(
decoding_streams.terminate_and_flush_to_streams() decoding_streams.terminate_and_flush_to_streams()
lattice = decoding_streams.format_output(encoder_out_lens.tolist()) lattice = decoding_streams.format_output(encoder_out_lens.tolist())
best_path = one_best_decoding(lattice) return lattice
hyps = get_texts(best_path)
return hyps
def greedy_search( def greedy_search(
@ -130,8 +267,9 @@ def greedy_search(
blank_id = model.decoder.blank_id blank_id = model.decoder.blank_id
context_size = model.decoder.context_size context_size = model.decoder.context_size
unk_id = getattr(model, "unk_id", blank_id)
device = model.device device = next(model.parameters()).device
decoder_input = torch.tensor( decoder_input = torch.tensor(
[blank_id] * context_size, device=device, dtype=torch.int64 [blank_id] * context_size, device=device, dtype=torch.int64
@ -170,7 +308,7 @@ def greedy_search(
# logits is (1, 1, 1, vocab_size) # logits is (1, 1, 1, vocab_size)
y = logits.argmax().item() y = logits.argmax().item()
if y != blank_id: if y not in (blank_id, unk_id):
hyp.append(y) hyp.append(y)
decoder_input = torch.tensor( decoder_input = torch.tensor(
[hyp[-context_size:]], device=device [hyp[-context_size:]], device=device
@ -190,7 +328,9 @@ def greedy_search(
def greedy_search_batch( def greedy_search_batch(
model: Transducer, encoder_out: torch.Tensor model: Transducer,
encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
) -> List[List[int]]: ) -> List[List[int]]:
"""Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1.
Args: Args:
@ -198,6 +338,9 @@ def greedy_search_batch(
The transducer model. The transducer model.
encoder_out: encoder_out:
Output from the encoder. Its shape is (N, T, C), where N >= 1. Output from the encoder. Its shape is (N, T, C), where N >= 1.
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
Returns: Returns:
Return a list-of-list of token IDs containing the decoded results. Return a list-of-list of token IDs containing the decoded results.
len(ans) equals to encoder_out.size(0). len(ans) equals to encoder_out.size(0).
@ -205,30 +348,49 @@ def greedy_search_batch(
assert encoder_out.ndim == 3 assert encoder_out.ndim == 3
assert encoder_out.size(0) >= 1, encoder_out.size(0) assert encoder_out.size(0) >= 1, encoder_out.size(0)
device = model.device packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
batch_size = encoder_out.size(0) device = next(model.parameters()).device
T = encoder_out.size(1)
blank_id = model.decoder.blank_id blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size context_size = model.decoder.context_size
hyps = [[blank_id] * context_size for _ in range(batch_size)] batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
hyps = [[blank_id] * context_size for _ in range(N)]
decoder_input = torch.tensor( decoder_input = torch.tensor(
hyps, hyps,
device=device, device=device,
dtype=torch.int64, dtype=torch.int64,
) # (batch_size, context_size) ) # (N, context_size)
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out) decoder_out = model.joiner.decoder_proj(decoder_out)
encoder_out = model.joiner.encoder_proj(encoder_out) # decoder_out: (N, 1, decoder_out_dim)
# decoder_out: (batch_size, 1, decoder_out_dim) encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
for t in range(T):
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa offset = 0
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim)
offset = end
decoder_out = decoder_out[:batch_size]
logits = model.joiner( logits = model.joiner(
current_encoder_out, decoder_out.unsqueeze(1), project_input=False current_encoder_out, decoder_out.unsqueeze(1), project_input=False
) )
@ -239,12 +401,12 @@ def greedy_search_batch(
y = logits.argmax(dim=1).tolist() y = logits.argmax(dim=1).tolist()
emitted = False emitted = False
for i, v in enumerate(y): for i, v in enumerate(y):
if v != blank_id: if v not in (blank_id, unk_id):
hyps[i].append(v) hyps[i].append(v)
emitted = True emitted = True
if emitted: if emitted:
# update decoder output # update decoder output
decoder_input = [h[-context_size:] for h in hyps] decoder_input = [h[-context_size:] for h in hyps[:batch_size]]
decoder_input = torch.tensor( decoder_input = torch.tensor(
decoder_input, decoder_input,
device=device, device=device,
@ -253,7 +415,12 @@ def greedy_search_batch(
decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.decoder(decoder_input, need_pad=False)
decoder_out = model.joiner.decoder_proj(decoder_out) decoder_out = model.joiner.decoder_proj(decoder_out)
ans = [h[context_size:] for h in hyps] sorted_ans = [h[context_size:] for h in hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans return ans
@ -291,10 +458,8 @@ class HypothesisList(object):
def add(self, hyp: Hypothesis) -> None: def add(self, hyp: Hypothesis) -> None:
"""Add a Hypothesis to `self`. """Add a Hypothesis to `self`.
If `hyp` already exists in `self`, its probability is updated using If `hyp` already exists in `self`, its probability is updated using
`log-sum-exp` with the existed one. `log-sum-exp` with the existed one.
Args: Args:
hyp: hyp:
The hypothesis to be added. The hypothesis to be added.
@ -311,7 +476,6 @@ class HypothesisList(object):
def get_most_probable(self, length_norm: bool = False) -> Hypothesis: def get_most_probable(self, length_norm: bool = False) -> Hypothesis:
"""Get the most probable hypothesis, i.e., the one with """Get the most probable hypothesis, i.e., the one with
the largest `log_prob`. the largest `log_prob`.
Args: Args:
length_norm: length_norm:
If True, the `log_prob` of a hypothesis is normalized by the If True, the `log_prob` of a hypothesis is normalized by the
@ -328,10 +492,8 @@ class HypothesisList(object):
def remove(self, hyp: Hypothesis) -> None: def remove(self, hyp: Hypothesis) -> None:
"""Remove a given hypothesis. """Remove a given hypothesis.
Caution: Caution:
`self` is modified **in-place**. `self` is modified **in-place**.
Args: Args:
hyp: hyp:
The hypothesis to be removed from `self`. The hypothesis to be removed from `self`.
@ -344,10 +506,8 @@ class HypothesisList(object):
def filter(self, threshold: torch.Tensor) -> "HypothesisList": def filter(self, threshold: torch.Tensor) -> "HypothesisList":
"""Remove all Hypotheses whose log_prob is less than threshold. """Remove all Hypotheses whose log_prob is less than threshold.
Caution: Caution:
`self` is not modified. Instead, a new HypothesisList is returned. `self` is not modified. Instead, a new HypothesisList is returned.
Returns: Returns:
Return a new HypothesisList containing all hypotheses from `self` Return a new HypothesisList containing all hypotheses from `self`
with `log_prob` being greater than the given `threshold`. with `log_prob` being greater than the given `threshold`.
@ -385,7 +545,6 @@ class HypothesisList(object):
def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
"""Return a ragged shape with axes [utt][num_hyps]. """Return a ragged shape with axes [utt][num_hyps].
Args: Args:
hyps: hyps:
len(hyps) == batch_size. It contains the current hypothesis for len(hyps) == batch_size. It contains the current hypothesis for
@ -411,15 +570,18 @@ def _get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape:
def modified_beam_search( def modified_beam_search(
model: Transducer, model: Transducer,
encoder_out: torch.Tensor, encoder_out: torch.Tensor,
encoder_out_lens: torch.Tensor,
beam: int = 4, beam: int = 4,
) -> List[List[int]]: ) -> List[List[int]]:
"""Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
Args: Args:
model: model:
The transducer model. The transducer model.
encoder_out: encoder_out:
Output from the encoder. Its shape is (N, T, C). Output from the encoder. Its shape is (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,), containing number of valid frames in
encoder_out before padding.
beam: beam:
Number of active paths during the beam search. Number of active paths during the beam search.
Returns: Returns:
@ -427,15 +589,26 @@ def modified_beam_search(
for the i-th utterance. for the i-th utterance.
""" """
assert encoder_out.ndim == 3, encoder_out.shape assert encoder_out.ndim == 3, encoder_out.shape
assert encoder_out.size(0) >= 1, encoder_out.size(0)
batch_size = encoder_out.size(0) packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
T = encoder_out.size(1) input=encoder_out,
lengths=encoder_out_lens.cpu(),
batch_first=True,
enforce_sorted=False,
)
blank_id = model.decoder.blank_id blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size context_size = model.decoder.context_size
device = model.device device = next(model.parameters()).device
B = [HypothesisList() for _ in range(batch_size)]
for i in range(batch_size): batch_size_list = packed_encoder_out.batch_sizes.tolist()
N = encoder_out.size(0)
assert torch.all(encoder_out_lens > 0), encoder_out_lens
assert N == batch_size_list[0], (N, batch_size_list)
B = [HypothesisList() for _ in range(N)]
for i in range(N):
B[i].add( B[i].add(
Hypothesis( Hypothesis(
ys=[blank_id] * context_size, ys=[blank_id] * context_size,
@ -443,11 +616,20 @@ def modified_beam_search(
) )
) )
encoder_out = model.joiner.encoder_proj(encoder_out) encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)
for t in range(T): offset = 0
current_encoder_out = encoder_out[:, t : t + 1, :].unsqueeze(2) # noqa finalized_B = []
for batch_size in batch_size_list:
start = offset
end = offset + batch_size
current_encoder_out = encoder_out.data[start:end]
current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
# current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
offset = end
finalized_B = B[batch_size:] + finalized_B
B = B[:batch_size]
hyps_shape = _get_hyps_shape(B).to(device) hyps_shape = _get_hyps_shape(B).to(device)
@ -503,8 +685,10 @@ def modified_beam_search(
for i in range(batch_size): for i in range(batch_size):
topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)
topk_hyp_indexes = (topk_indexes // vocab_size).tolist() with warnings.catch_warnings():
topk_token_indexes = (topk_indexes % vocab_size).tolist() warnings.simplefilter("ignore")
topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
topk_token_indexes = (topk_indexes % vocab_size).tolist()
for k in range(len(topk_hyp_indexes)): for k in range(len(topk_hyp_indexes)):
hyp_idx = topk_hyp_indexes[k] hyp_idx = topk_hyp_indexes[k]
@ -512,15 +696,21 @@ def modified_beam_search(
new_ys = hyp.ys[:] new_ys = hyp.ys[:]
new_token = topk_token_indexes[k] new_token = topk_token_indexes[k]
if new_token != blank_id: if new_token not in (blank_id, unk_id):
new_ys.append(new_token) new_ys.append(new_token)
new_log_prob = topk_log_probs[k] new_log_prob = topk_log_probs[k]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
B[i].add(new_hyp) B[i].add(new_hyp)
best_hyps = [b.get_most_probable(length_norm=True) for b in B] B = B + finalized_B
ans = [h.ys[context_size:] for h in best_hyps] best_hyps = [b.get_most_probable(length_norm=False) for b in B]
sorted_ans = [h.ys[context_size:] for h in best_hyps]
ans = []
unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
for i in range(N):
ans.append(sorted_ans[unsorted_indices[i]])
return ans return ans
@ -531,12 +721,9 @@ def _deprecated_modified_beam_search(
beam: int = 4, beam: int = 4,
) -> List[int]: ) -> List[int]:
"""It limits the maximum number of symbols per frame to 1. """It limits the maximum number of symbols per frame to 1.
It decodes only one utterance at a time. We keep it only for reference. It decodes only one utterance at a time. We keep it only for reference.
The function :func:`modified_beam_search` should be preferred as it The function :func:`modified_beam_search` should be preferred as it
supports batch decoding. supports batch decoding.
Args: Args:
model: model:
An instance of `Transducer`. An instance of `Transducer`.
@ -553,9 +740,10 @@ def _deprecated_modified_beam_search(
# support only batch_size == 1 for now # support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0) assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size context_size = model.decoder.context_size
device = model.device device = next(model.parameters()).device
T = encoder_out.size(1) T = encoder_out.size(1)
@ -614,14 +802,16 @@ def _deprecated_modified_beam_search(
topk_hyp_indexes = topk_indexes // logits.size(-1) topk_hyp_indexes = topk_indexes // logits.size(-1)
topk_token_indexes = topk_indexes % logits.size(-1) topk_token_indexes = topk_indexes % logits.size(-1)
topk_hyp_indexes = topk_hyp_indexes.tolist() with warnings.catch_warnings():
topk_token_indexes = topk_token_indexes.tolist() warnings.simplefilter("ignore")
topk_hyp_indexes = topk_hyp_indexes.tolist()
topk_token_indexes = topk_token_indexes.tolist()
for i in range(len(topk_hyp_indexes)): for i in range(len(topk_hyp_indexes)):
hyp = A[topk_hyp_indexes[i]] hyp = A[topk_hyp_indexes[i]]
new_ys = hyp.ys[:] new_ys = hyp.ys[:]
new_token = topk_token_indexes[i] new_token = topk_token_indexes[i]
if new_token != blank_id: if new_token not in (blank_id, unk_id):
new_ys.append(new_token) new_ys.append(new_token)
new_log_prob = topk_log_probs[i] new_log_prob = topk_log_probs[i]
new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob)
@ -640,9 +830,7 @@ def beam_search(
) -> List[int]: ) -> List[int]:
""" """
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
espnet/nets/beam_search_transducer.py#L247 is used as a reference. espnet/nets/beam_search_transducer.py#L247 is used as a reference.
Args: Args:
model: model:
An instance of `Transducer`. An instance of `Transducer`.
@ -658,9 +846,10 @@ def beam_search(
# support only batch_size == 1 for now # support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0) assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id blank_id = model.decoder.blank_id
unk_id = getattr(model, "unk_id", blank_id)
context_size = model.decoder.context_size context_size = model.decoder.context_size
device = model.device device = next(model.parameters()).device
decoder_input = torch.tensor( decoder_input = torch.tensor(
[blank_id] * context_size, [blank_id] * context_size,
@ -743,7 +932,7 @@ def beam_search(
# Second, process other non-blank labels # Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1) values, indices = log_prob.topk(beam + 1)
for i, v in zip(indices.tolist(), values.tolist()): for i, v in zip(indices.tolist(), values.tolist()):
if i == blank_id: if i in (blank_id, unk_id):
continue continue
new_ys = y_star.ys + [i] new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + v new_log_prob = y_star.log_prob + v

View File

@ -59,10 +59,10 @@ from typing import Dict, List, Optional, Tuple
import k2 import k2
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import WenetSpeechAsrDataModule from asr_datamodule import Aidatatang_200zhAsrDataModule
from beam_search import ( from beam_search import (
beam_search, beam_search,
fast_beam_search, fast_beam_search_one_best,
greedy_search, greedy_search,
greedy_search_batch, greedy_search_batch,
modified_beam_search, modified_beam_search,
@ -255,7 +255,7 @@ def decode_one_batch(
hyps = [] hyps = []
if params.decoding_method == "fast_beam_search": if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search( hyp_tokens = fast_beam_search_one_best(
model=model, model=model,
decoding_graph=decoding_graph, decoding_graph=decoding_graph,
encoder_out=encoder_out, encoder_out=encoder_out,
@ -273,6 +273,7 @@ def decode_one_batch(
hyp_tokens = greedy_search_batch( hyp_tokens = greedy_search_batch(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
@ -280,6 +281,7 @@ def decode_one_batch(
hyp_tokens = modified_beam_search( hyp_tokens = modified_beam_search(
model=model, model=model,
encoder_out=encoder_out, encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size, beam=params.beam_size,
) )
for i in range(encoder_out.size(0)): for i in range(encoder_out.size(0)):
@ -359,12 +361,12 @@ def decode_dataset(
if params.decoding_method == "greedy_search": if params.decoding_method == "greedy_search":
log_interval = 100 log_interval = 100
else: else:
log_interval = 2 log_interval = 50
results = defaultdict(list) results = defaultdict(list)
for batch_idx, batch in enumerate(dl): for batch_idx, batch in enumerate(dl):
texts = batch["supervisions"]["text"] texts = batch["supervisions"]["text"]
texts = [list(str(text)) for text in texts] texts = [list(str(text).replace(" ", "")) for text in texts]
hyps_dict = decode_one_batch( hyps_dict = decode_one_batch(
params=params, params=params,
@ -440,7 +442,7 @@ def save_results(
@torch.no_grad() @torch.no_grad()
def main(): def main():
parser = get_parser() parser = get_parser()
WenetSpeechAsrDataModule.add_arguments(parser) Aidatatang_200zhAsrDataModule.add_arguments(parser)
args = parser.parse_args() args = parser.parse_args()
args.exp_dir = Path(args.exp_dir) args.exp_dir = Path(args.exp_dir)
@ -506,6 +508,13 @@ def main():
model.to(device) model.to(device)
model.load_state_dict(average_checkpoints(filenames, device=device)) model.load_state_dict(average_checkpoints(filenames, device=device))
average = average_checkpoints(filenames, device=device)
checkpoint = {"model": average}
torch.save(
checkpoint,
"pruned_transducer_stateless2/pretrained_average_11_to_29.pt",
)
model.to(device) model.to(device)
model.eval() model.eval()
model.device = device model.device = device
@ -526,33 +535,26 @@ def main():
from lhotse import CutSet from lhotse import CutSet
from lhotse.dataset.webdataset import export_to_webdataset from lhotse.dataset.webdataset import export_to_webdataset
wenetspeech = WenetSpeechAsrDataModule(args) aidatatang_200zh = Aidatatang_200zhAsrDataModule(args)
dev = "dev" dev = "dev"
test_net = "test_net" test = "test"
test_meet = "test_meet"
if not os.path.exists(f"{dev}/shared-0.tar"): if not os.path.exists(f"{dev}/shared-0.tar"):
dev_cuts = wenetspeech.valid_cuts() os.makedirs(dev)
dev_cuts = aidatatang_200zh.valid_cuts()
export_to_webdataset( export_to_webdataset(
dev_cuts, dev_cuts,
output_path=f"{dev}/shared-%d.tar", output_path=f"{dev}/shared-%d.tar",
shard_size=300, shard_size=300,
) )
if not os.path.exists(f"{test_net}/shared-0.tar"): if not os.path.exists(f"{test}/shared-0.tar"):
test_net_cuts = wenetspeech.test_net_cuts() os.makedirs(test)
test_cuts = aidatatang_200zh.test_cuts()
export_to_webdataset( export_to_webdataset(
test_net_cuts, test_cuts,
output_path=f"{test_net}/shared-%d.tar", output_path=f"{test}/shared-%d.tar",
shard_size=300,
)
if not os.path.exists(f"{test_meet}/shared-0.tar"):
test_meeting_cuts = wenetspeech.test_meeting_cuts()
export_to_webdataset(
test_meeting_cuts,
output_path=f"{test_meet}/shared-%d.tar",
shard_size=300, shard_size=300,
) )
@ -567,34 +569,22 @@ def main():
shuffle_shards=True, shuffle_shards=True,
) )
test_net_shards = [ test_shards = [
str(path) str(path)
for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) for path in sorted(glob.glob(os.path.join(test, "shared-*.tar")))
] ]
cuts_test_net_webdataset = CutSet.from_webdataset( cuts_test_webdataset = CutSet.from_webdataset(
test_net_shards, test_shards,
split_by_worker=True, split_by_worker=True,
split_by_node=True, split_by_node=True,
shuffle_shards=True, shuffle_shards=True,
) )
test_meet_shards = [ dev_dl = aidatatang_200zh.valid_dataloaders(cuts_dev_webdataset)
str(path) test_dl = aidatatang_200zh.test_dataloaders(cuts_test_webdataset)
for path in sorted(glob.glob(os.path.join(test_meet, "shared-*.tar")))
]
cuts_test_meet_webdataset = CutSet.from_webdataset(
test_meet_shards,
split_by_worker=True,
split_by_node=True,
shuffle_shards=True,
)
dev_dl = wenetspeech.valid_dataloaders(cuts_dev_webdataset) test_sets = ["dev", "test"]
test_net_dl = wenetspeech.test_dataloaders(cuts_test_net_webdataset) test_dl = [dev_dl, test_dl]
test_meeting_dl = wenetspeech.test_dataloaders(cuts_test_meet_webdataset)
test_sets = ["DEV", "TEST_NET", "TEST_MEETING"]
test_dl = [dev_dl, test_net_dl, test_meeting_dl]
for test_set, test_dl in zip(test_sets, test_dl): for test_set, test_dl in zip(test_sets, test_dl):
results_dict = decode_dataset( results_dict = decode_dataset(

View File

@ -21,8 +21,8 @@ Usage:
./pruned_transducer_stateless2/export.py \ ./pruned_transducer_stateless2/export.py \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--lang-dir data/lang_char \ --lang-dir data/lang_char \
--epoch 20 \ --epoch 29 \
--avg 10 --avg 19
It will generate a file exp_dir/pretrained.pt It will generate a file exp_dir/pretrained.pt
@ -32,7 +32,7 @@ you can do:
cd /path/to/exp_dir cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt ln -s pretrained.pt epoch-9999.pt
cd /path/to/egs/wenetspeech/ASR cd /path/to/egs/aidatatang_200zh/ASR
./pruned_transducer_stateless2/decode.py \ ./pruned_transducer_stateless2/decode.py \
--exp-dir ./pruned_transducer_stateless2/exp \ --exp-dir ./pruned_transducer_stateless2/exp \
--epoch 9999 \ --epoch 9999 \

View File

@ -0,0 +1,347 @@
#!/usr/bin/env python3
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang)
# 2022 Xiaomi Corp. (authors: Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Usage:
(1) greedy search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--decoding-method greedy_search \
--max-sym-per-frame 1 \
/path/to/foo.wav \
/path/to/bar.wav
(2) modified beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--decoding-method modified_beam_search \
--beam-size 4 \
/path/to/foo.wav \
/path/to/bar.wav
(3) fast beam search
./pruned_transducer_stateless2/pretrained.py \
--checkpoint ./pruned_transducer_stateless2/exp/pretrained.pt \
--lang-dir ./data/lang_char \
--decoding-method fast_beam_search \
--beam 4 \
--max-contexts 4 \
--max-states 8 \
/path/to/foo.wav \
/path/to/bar.wav
You can also use `./pruned_transducer_stateless2/exp/epoch-xx.pt`.
Note: ./pruned_transducer_stateless2/exp/pretrained.pt is generated by
./pruned_transducer_stateless2/export.py
"""
import argparse
import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio
from beam_search import (
beam_search,
fast_beam_search_one_best,
greedy_search,
greedy_search_batch,
modified_beam_search,
)
from torch.nn.utils.rnn import pad_sequence
from train import get_params, get_transducer_model
from icefall.lexicon import Lexicon
def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
)
parser.add_argument(
"--checkpoint",
type=str,
required=True,
help="Path to the checkpoint. "
"The checkpoint is assumed to be saved by "
"icefall.checkpoint.save_checkpoint().",
)
parser.add_argument(
"--lang-dir",
type=str,
help="""Path to lang.
""",
)
parser.add_argument(
"--decoding-method",
type=str,
default="greedy_search",
help="""Possible values are:
- greedy_search
- beam_search
- modified_beam_search
- fast_beam_search
""",
)
parser.add_argument(
"sound_files",
type=str,
nargs="+",
help="The input sound file(s) to transcribe. "
"Supported formats are those supported by torchaudio.load(). "
"For example, wav and flac are supported. "
"The sample rate has to be 16kHz.",
)
parser.add_argument(
"--sample-rate",
type=int,
default=16000,
help="The sample rate of the input sound file",
)
parser.add_argument(
"--beam-size",
type=int,
default=4,
help="Used only when --method is beam_search and modified_beam_search ",
)
parser.add_argument(
"--beam",
type=float,
default=4,
help="""A floating point value to calculate the cutoff score during beam
search (i.e., `cutoff = max-score - beam`), which is the same as the
`beam` in Kaldi.
Used only when --decoding-method is fast_beam_search""",
)
parser.add_argument(
"--max-contexts",
type=int,
default=4,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--max-states",
type=int,
default=8,
help="""Used only when --decoding-method is
fast_beam_search""",
)
parser.add_argument(
"--context-size",
type=int,
default=2,
help="The context size in the decoder. 1 means bigram; "
"2 means tri-gram",
)
parser.add_argument(
"--max-sym-per-frame",
type=int,
default=1,
help="""Maximum number of symbols per frame. Used only when
--decoding-method is greedy_search.
""",
)
return parser
def read_sound_files(
filenames: List[str], expected_sample_rate: float
) -> List[torch.Tensor]:
"""Read a list of sound files into a list 1-D float32 torch tensors.
Args:
filenames:
A list of sound filenames.
expected_sample_rate:
The expected sample rate of the sound files.
Returns:
Return a list of 1-D float32 torch tensors.
"""
ans = []
for f in filenames:
wave, sample_rate = torchaudio.load(f)
assert sample_rate == expected_sample_rate, (
f"expected sample rate: {expected_sample_rate}. "
f"Given: {sample_rate}"
)
# We use only the first channel
ans.append(wave[0])
return ans
@torch.no_grad()
def main():
parser = get_parser()
args = parser.parse_args()
params = get_params()
params.update(vars(args))
lexicon = Lexicon(params.lang_dir)
params.blank_id = lexicon.token_table["<blk>"]
params.vocab_size = max(lexicon.tokens) + 1
logging.info(f"{params}")
device = torch.device("cpu")
if torch.cuda.is_available():
device = torch.device("cuda", 0)
logging.info(f"device: {device}")
logging.info("Creating model")
model = get_transducer_model(params)
checkpoint = torch.load(args.checkpoint, map_location="cpu")
model.load_state_dict(checkpoint["model"], strict=False)
model.to(device)
model.eval()
model.device = device
if params.decoding_method == "fast_beam_search":
decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device)
else:
decoding_graph = None
logging.info("Constructing Fbank computer")
opts = kaldifeat.FbankOptions()
opts.device = device
opts.frame_opts.dither = 0
opts.frame_opts.snip_edges = False
opts.frame_opts.samp_freq = params.sample_rate
opts.mel_opts.num_bins = params.feature_dim
fbank = kaldifeat.Fbank(opts)
logging.info(f"Reading sound files: {params.sound_files}")
waves = read_sound_files(
filenames=params.sound_files, expected_sample_rate=params.sample_rate
)
waves = [w.to(device) for w in waves]
logging.info("Decoding started")
features = fbank(waves)
feature_lengths = [f.size(0) for f in features]
features = pad_sequence(
features, batch_first=True, padding_value=math.log(1e-10)
)
feature_lengths = torch.tensor(feature_lengths, device=device)
with torch.no_grad():
encoder_out, encoder_out_lens = model.encoder(
x=features, x_lens=feature_lengths
)
hyps = []
msg = f"Using {params.decoding_method}"
logging.info(msg)
if params.decoding_method == "fast_beam_search":
hyp_tokens = fast_beam_search_one_best(
model=model,
decoding_graph=decoding_graph,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam,
max_contexts=params.max_contexts,
max_states=params.max_states,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif (
params.decoding_method == "greedy_search"
and params.max_sym_per_frame == 1
):
hyp_tokens = greedy_search_batch(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
elif params.decoding_method == "modified_beam_search":
hyp_tokens = modified_beam_search(
model=model,
encoder_out=encoder_out,
encoder_out_lens=encoder_out_lens,
beam=params.beam_size,
)
for i in range(encoder_out.size(0)):
hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]])
else:
batch_size = encoder_out.size(0)
for i in range(batch_size):
# fmt: off
encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]]
# fmt: on
if params.decoding_method == "greedy_search":
hyp = greedy_search(
model=model,
encoder_out=encoder_out_i,
max_sym_per_frame=params.max_sym_per_frame,
)
elif params.decoding_method == "beam_search":
hyp = beam_search(
model=model,
encoder_out=encoder_out_i,
beam=params.beam_size,
)
else:
raise ValueError(
f"Unsupported decoding method: {params.decoding_method}"
)
hyps.append([lexicon.token_table[idx] for idx in hyp])
s = "\n"
for filename, hyp in zip(params.sound_files, hyps):
words = " ".join(hyp)
s += f"{filename}:\n{words}\n\n"
logging.info(s)
logging.info("Decoding Done")
if __name__ == "__main__":
formatter = (
"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
)
logging.basicConfig(format=formatter, level=logging.INFO)
main()

View File

@ -103,7 +103,7 @@ def get_parser():
parser.add_argument( parser.add_argument(
"--master-port", "--master-port",
type=int, type=int,
default=12354, default=12359,
help="Master port to use for DDP training.", help="Master port to use for DDP training.",
) )