mirror of https://github.com/k2-fsa/icefall.git

Add BPE decoding results.

parent 4ccae509d3
commit f65854cca5
@@ -2,3 +2,120 @@
Run `./prepare.sh` to prepare the data.

Run `./xxx_train.py` (to be added) to train a model.

## Conformer-CTC

Results of the pre-trained model from
<https://huggingface.co/GuoLiyong/snowfall_bpe_model/tree/main/exp-duration-200-feat_batchnorm-bpe-lrfactor5.0-conformer-512-8-noam>
are given below.

### HLG - no LM rescoring

(output beam size is 8)

#### 1-best decoding

```
[test-clean-no_rescore] %WER 3.15% [1656 / 52576, 127 ins, 377 del, 1152 sub ]
[test-other-no_rescore] %WER 7.03% [3682 / 52343, 220 ins, 1024 del, 2438 sub ]
```

#### n-best decoding

For n=100,

```
[test-clean-no_rescore-100] %WER 3.15% [1656 / 52576, 127 ins, 377 del, 1152 sub ]
[test-other-no_rescore-100] %WER 7.14% [3737 / 52343, 275 ins, 1020 del, 2442 sub ]
```

For n=200,

```
[test-clean-no_rescore-200] %WER 3.16% [1660 / 52576, 125 ins, 378 del, 1157 sub ]
[test-other-no_rescore-200] %WER 7.04% [3684 / 52343, 228 ins, 1012 del, 2444 sub ]
```

### HLG - with LM rescoring

#### Whole lattice rescoring

```
[test-clean-lm_scale_0.8] %WER 2.77% [1456 / 52576, 150 ins, 210 del, 1096 sub ]
[test-other-lm_scale_0.8] %WER 6.23% [3262 / 52343, 246 ins, 635 del, 2381 sub ]
```

WERs of different LM scales are:

```
For test-clean, WER of different settings are:
lm_scale_0.8  2.77  best for test-clean
lm_scale_0.9  2.87
lm_scale_1.0  3.06
lm_scale_1.1  3.34
lm_scale_1.2  3.71
lm_scale_1.3  4.18
lm_scale_1.4  4.8
lm_scale_1.5  5.48
lm_scale_1.6  6.08
lm_scale_1.7  6.79
lm_scale_1.8  7.49
lm_scale_1.9  8.14
lm_scale_2.0  8.82

For test-other, WER of different settings are:
lm_scale_0.8  6.23  best for test-other
lm_scale_0.9  6.37
lm_scale_1.0  6.62
lm_scale_1.1  6.99
lm_scale_1.2  7.46
lm_scale_1.3  8.13
lm_scale_1.4  8.84
lm_scale_1.5  9.61
lm_scale_1.6  10.32
lm_scale_1.7  11.17
lm_scale_1.8  12.12
lm_scale_1.9  12.93
lm_scale_2.0  13.77
```

#### n-best LM rescoring

n = 100

```
[test-clean-lm_scale_0.8] %WER 2.79% [1469 / 52576, 149 ins, 212 del, 1108 sub ]
[test-other-lm_scale_0.8] %WER 6.36% [3329 / 52343, 259 ins, 666 del, 2404 sub ]
```

WERs of different LM scales are:

```
For test-clean, WER of different settings are:
lm_scale_0.8  2.79  best for test-clean
lm_scale_0.9  2.89
lm_scale_1.0  3.03
lm_scale_1.1  3.28
lm_scale_1.2  3.52
lm_scale_1.3  3.78
lm_scale_1.4  4.04
lm_scale_1.5  4.24
lm_scale_1.6  4.45
lm_scale_1.7  4.58
lm_scale_1.8  4.7
lm_scale_1.9  4.8
lm_scale_2.0  4.92

For test-other, WER of different settings are:
lm_scale_0.8  6.36  best for test-other
lm_scale_0.9  6.45
lm_scale_1.0  6.64
lm_scale_1.1  6.92
lm_scale_1.2  7.25
lm_scale_1.3  7.59
lm_scale_1.4  7.88
lm_scale_1.5  8.13
lm_scale_1.6  8.36
lm_scale_1.7  8.54
lm_scale_1.8  8.71
lm_scale_1.9  8.88
lm_scale_2.0  9.02
```
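The `no_rescore*` and `lm_scale_*` tags in the result blocks above follow the decoding method configured in the decoding script changed below (presumably `conformer_ctc/decode.py`, given the `exp_dir` default). A minimal sketch, assuming the module is importable as `conformer_ctc.decode`, of how the method might be selected:

```python
# Minimal sketch (assumed usage, not part of this commit): selecting the
# decoding method through the params dict defined in get_params() below.
from conformer_ctc.decode import get_params  # assumed import path

params = get_params()
params["method"] = "1best"                      # result keys like "no_rescore"
# params["method"] = "nbest"                    # keys like "no_rescore-100"
# params["method"] = "nbest-rescoring"          # keys like "lm_scale_0.8", ...
# params["method"] = "whole-lattice-rescoring"  # keys like "lm_scale_0.8", ...
params["num_paths"] = 100  # consulted only by "nbest" and "nbest-rescoring"
```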
@@ -6,13 +6,25 @@

import argparse
import logging
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import k2
import torch
import torch.nn as nn
from conformer import Conformer

from icefall.checkpoint import average_checkpoints, load_checkpoint
from icefall.dataset.librispeech import LibriSpeechAsrDataModule
from icefall.decode import (
    get_lattice,
    nbest_decoding,
    one_best_decoding,
    rescore_with_n_best_list,
    rescore_with_whole_lattice,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
    AttributeDict,
    get_texts,
@@ -22,40 +34,6 @@ from icefall.utils import (
)


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "exp_dir": Path("conformer_ctc/exp"),
            "lang_dir": Path("data/lang/bpe"),
            "lm_dir": Path("data/lm"),
            "feature_dim": 80,
            "nhead": 8,
            "attention_dim": 512,
            "num_classes": 5000,
            "subsampling_factor": 4,
            "num_decoder_layers": 6,
            "vgg_frontend": False,
            "is_espnet_structure": True,
            "mmi_loss": False,
            "use_feat_batchnorm": True,
            "search_beam": 20,
            "output_beam": 5,
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
            # Possible values for method:
            # - 1best
            # - nbest
            # - nbest-rescoring
            # - whole-lattice-rescoring
            "method": "whole-lattice-rescoring",
            # num_paths is used when method is "nbest" and "nbest-rescoring"
            "num_paths": 30,
        }
    )
    return params


def get_parser():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -79,6 +57,270 @@ def get_parser():
    return parser


def get_params() -> AttributeDict:
    params = AttributeDict(
        {
            "exp_dir": Path("conformer_ctc/exp"),
            "lang_dir": Path("data/lang/bpe"),
            "lm_dir": Path("data/lm"),
            "feature_dim": 80,
            "nhead": 8,
            "attention_dim": 512,
            "subsampling_factor": 4,
            "num_decoder_layers": 6,
            "vgg_frontend": False,
            "is_espnet_structure": True,
            "mmi_loss": False,
            "use_feat_batchnorm": True,
            "search_beam": 20,
            "output_beam": 8,
            "min_active_states": 30,
            "max_active_states": 10000,
            "use_double_scores": True,
            # Possible values for method:
            # - 1best
            # - nbest
            # - nbest-rescoring
            # - whole-lattice-rescoring
            "method": "nbest-rescoring",
            # num_paths is used when method is "nbest" and "nbest-rescoring"
            "num_paths": 100,
        }
    )
    return params
def decode_one_batch(
    params: AttributeDict,
    model: nn.Module,
    HLG: k2.Fsa,
    batch: dict,
    lexicon: Lexicon,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[List[int]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:

    - key: It indicates the setting used for decoding. For example,
      if no rescoring is used, the key is the string `no_rescore`.
      If LM rescoring is used, the key is the string `lm_scale_xxx`,
      where `xxx` is the value of `lm_scale`. An example key is
      `lm_scale_0.7`
    - value: It contains the decoding result. `len(value)` equals to
      batch size. `value[i]` is the decoding result for the i-th
      utterance in the given batch.
    Args:
      params:
        It's the return value of :func:`get_params`.

        - params.method is "1best", it uses 1best decoding without LM rescoring.
        - params.method is "nbest", it uses nbest decoding without LM rescoring.
        - params.method is "nbest-rescoring", it uses nbest LM rescoring.
        - params.method is "whole-lattice-rescoring", it uses whole lattice LM
          rescoring.

      model:
        The neural model.
      HLG:
        The decoding graph.
      batch:
        It is the return value from iterating
        `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation
        for the format of the `batch`.
      lexicon:
        It contains word symbol table.
      G:
        An LM. It is not None when params.method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return the decoding result. See above description for the format of
      the returned dict.
    """
    device = HLG.device
    feature = batch["inputs"]
    assert feature.ndim == 3
    feature = feature.to(device)
    # at entry, feature is [N, T, C]

    feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]

    supervisions = batch["supervisions"]

    nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
    # nnet_output is [N, C, T]

    nnet_output = nnet_output.permute(0, 2, 1)
    # now nnet_output is [N, T, C]

    supervision_segments = torch.stack(
        (
            supervisions["sequence_idx"],
            supervisions["start_frame"] // params.subsampling_factor,
            supervisions["num_frames"] // params.subsampling_factor,
        ),
        1,
    ).to(torch.int32)
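    # Each row of supervision_segments is [sequence_idx, start_frame, num_frames],
    # with the frame counts already divided by params.subsampling_factor so that
    # they refer to frames of nnet_output.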

    lattice = get_lattice(
        nnet_output=nnet_output,
        HLG=HLG,
        supervision_segments=supervision_segments,
        search_beam=params.search_beam,
        output_beam=params.output_beam,
        min_active_states=params.min_active_states,
        max_active_states=params.max_active_states,
        subsampling_factor=params.subsampling_factor,
    )

    if params.method in ["1best", "nbest"]:
        if params.method == "1best":
            best_path = one_best_decoding(
                lattice=lattice, use_double_scores=params.use_double_scores
            )
            key = "no_rescore"
        else:
            best_path = nbest_decoding(
                lattice=lattice,
                num_paths=params.num_paths,
                use_double_scores=params.use_double_scores,
            )
            key = f"no_rescore-{params.num_paths}"

        hyps = get_texts(best_path)
        hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
        return {key: hyps}

    assert params.method in ["nbest-rescoring", "whole-lattice-rescoring"]

    lm_scale_list = [0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
    lm_scale_list += [1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]

    if params.method == "nbest-rescoring":
        best_path_dict = rescore_with_n_best_list(
            lattice=lattice,
            G=G,
            num_paths=params.num_paths,
            lm_scale_list=lm_scale_list,
        )
    else:
        best_path_dict = rescore_with_whole_lattice(
            lattice=lattice, G_with_epsilon_loops=G, lm_scale_list=lm_scale_list
        )

    ans = dict()
    for lm_scale_str, best_path in best_path_dict.items():
        hyps = get_texts(best_path)
        hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
        ans[lm_scale_str] = hyps
    return ans
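For readers skimming the diff, a small illustrative example (not from the commit) of the dict shape `decode_one_batch` returns when LM rescoring is enabled; `decode_dataset` below iterates exactly this structure:

```python
# Illustrative only: shape of the dict returned by decode_one_batch() when
# params.method is "nbest-rescoring" or "whole-lattice-rescoring".
# Keys follow lm_scale_list above; values hold one word list per utterance.
hyps_dict = {
    "lm_scale_0.8": [["HELLO", "WORLD"], ["GOOD", "MORNING"]],
    "lm_scale_0.9": [["HELLO", "WORLD"], ["GOOD", "MORNING"]],
}
for lm_scale, hyps in hyps_dict.items():
    assert len(hyps) == 2  # equals the batch size
```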
def decode_dataset(
    dl: torch.utils.data.DataLoader,
    params: AttributeDict,
    model: nn.Module,
    HLG: k2.Fsa,
    lexicon: Lexicon,
    G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[List[int], List[int]]]]:
    """Decode dataset.

    Args:
      dl:
        PyTorch's dataloader containing the dataset to decode.
      params:
        It is returned by :func:`get_params`.
      model:
        The neural model.
      HLG:
        The decoding graph.
      lexicon:
        It contains word symbol table.
      G:
        An LM. It is not None when params.method is "nbest-rescoring"
        or "whole-lattice-rescoring". In general, the G in HLG
        is a 3-gram LM, while this G is a 4-gram LM.
    Returns:
      Return a dict, whose key may be "no-rescore" if no LM rescoring
      is used, or it may be "lm_scale_0.7" if LM rescoring is used.
      Its value is a list of tuples. Each tuple contains two elements:
      The first is the reference transcript, and the second is the
      predicted result.
    """
    results = []

    num_cuts = 0
    tot_num_cuts = len(dl.dataset.cuts)

    results = defaultdict(list)
    for batch_idx, batch in enumerate(dl):
        texts = batch["supervisions"]["text"]

        hyps_dict = decode_one_batch(
            params=params,
            model=model,
            HLG=HLG,
            batch=batch,
            lexicon=lexicon,
            G=G,
        )

        for lm_scale, hyps in hyps_dict.items():
            this_batch = []
            assert len(hyps) == len(texts)
            for hyp_words, ref_text in zip(hyps, texts):
                ref_words = ref_text.split()
                this_batch.append((ref_words, hyp_words))

            results[lm_scale].extend(this_batch)

        num_cuts += len(batch["supervisions"]["text"])

        if batch_idx % 100 == 0:
            logging.info(
                f"batch {batch_idx}, cuts processed until now is "
                f"{num_cuts}/{tot_num_cuts} "
                f"({float(num_cuts)/tot_num_cuts*100:.6f}%)"
            )
    return results
def save_results(
    params: AttributeDict,
    test_set_name: str,
    results_dict: Dict[str, List[Tuple[List[int], List[int]]]],
):
    test_set_wers = dict()
    for key, results in results_dict.items():
        recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt"
        store_transcripts(filename=recog_path, texts=results)
        logging.info(f"The transcripts are stored in {recog_path}")

        # The following prints out WERs, per-word error statistics and aligned
        # ref/hyp pairs.
        errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt"
        with open(errs_filename, "w") as f:
            wer = write_error_stats(f, f"{test_set_name}-{key}", results)
            test_set_wers[key] = wer

        logging.info("Wrote detailed error stats to {}".format(errs_filename))

    test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1])
    errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt"
    with open(errs_info, "w") as f:
        print("settings\tWER", file=f)
        for key, val in test_set_wers:
            print("{}\t{}".format(key, val), file=f)

    s = "\nFor {}, WER of different settings are:\n".format(test_set_name)
    note = "\tbest for {}".format(test_set_name)
    for key, val in test_set_wers:
        s += "{}\t{}{}\n".format(key, val, note)
        note = ""
    logging.info(s)
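The "For test-clean, WER of different settings are:" blocks quoted in the results section above are the summary strings logged at the end of this function; the same key/WER pairs are also written to `wer-summary-<test_set>.txt`. A minimal sketch of reading one of those files back, assuming the default `exp_dir` from `get_params()`:

```python
# Minimal sketch (assumed usage, not part of this commit): read a summary file
# produced by save_results(). It has a "settings\tWER" header followed by
# tab-separated key/WER pairs, sorted by WER.
from pathlib import Path

summary = Path("conformer_ctc/exp") / "wer-summary-test-clean.txt"
with open(summary) as f:
    next(f)  # skip the "settings\tWER" header line
    for line in f:
        key, wer = line.strip().split("\t")
        print(key, float(wer))
```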
@torch.no_grad()
def main():
    parser = get_parser()
@@ -92,15 +334,64 @@ def main():
    logging.info("Decoding started")
    logging.info(params)

    lexicon = Lexicon(params.lang_dir)
    max_token_id = max(lexicon.tokens)
    num_classes = max_token_id + 1  # +1 for the blank

    device = torch.device("cpu")
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    HLG = k2.Fsa.from_dict(torch.load(f"{params.lm_dir}/HLG_bpe.pt"))
    HLG = HLG.to(device)
    assert HLG.requires_grad is False

    if not hasattr(HLG, "lm_scores"):
        HLG.lm_scores = HLG.scores.clone()

    if params.method in ["nbest-rescoring", "whole-lattice-rescoring"]:
        if not (params.lm_dir / "G_4_gram.pt").is_file():
            logging.info("Loading G_4_gram.fst.txt")
            logging.warning("It may take 8 minutes.")
            with open(params.lm_dir / "G_4_gram.fst.txt") as f:
                first_word_disambig_id = lexicon.words["#0"]

                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
                # G.aux_labels is not needed in later computations, so
                # remove it here.
                del G.aux_labels
                # CAUTION: The following line is crucial.
                # Arcs entering the back-off state have label equal to #0.
                # We have to change it to 0 here.
                G.labels[G.labels >= first_word_disambig_id] = 0
                G = k2.Fsa.from_fsas([G]).to(device)
                G = k2.arc_sort(G)
                torch.save(G.as_dict(), params.lm_dir / "G_4_gram.pt")
        else:
            logging.info("Loading pre-compiled G_4_gram.pt")
            d = torch.load(params.lm_dir / "G_4_gram.pt")
            G = k2.Fsa.from_dict(d).to(device)

        if params.method == "whole-lattice-rescoring":
            # Add epsilon self-loops to G as we will compose
            # it with the whole lattice later
            G = k2.add_epsilon_self_loops(G)
            G = k2.arc_sort(G)
            G = G.to(device)

        # G.lm_scores is used to replace HLG.lm_scores during
        # LM rescoring.
        G.lm_scores = G.scores.clone()
    else:
        G = None

    model = Conformer(
        num_features=params.feature_dim,
        nhead=params.nhead,
        d_model=params.attention_dim,
        num_classes=params.num_classes,
        num_classes=num_classes,
        subsampling_factor=params.subsampling_factor,
        num_decoder_layers=params.num_decoder_layers,
        vgg_frontend=params.vgg_frontend,
@@ -122,7 +413,32 @@ def main():

    model.to(device)
    model.eval()
    token_ids_with_blank = list(range(params.num_classes))
    num_param = sum([p.numel() for p in model.parameters()])
    logging.info(f"Number of model parameters: {num_param}")

    librispeech = LibriSpeechAsrDataModule(args)
    # CAUTION: `test_sets` is for displaying only.
    # If you want to skip test-clean, you have to skip
    # it inside the for loop. That is, use
    #
    # if test_set == 'test-clean': continue
    #
    test_sets = ["test-clean", "test-other"]
    for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
        results_dict = decode_dataset(
            dl=test_dl,
            params=params,
            model=model,
            HLG=HLG,
            lexicon=lexicon,
            G=G,
        )

        save_results(
            params=params, test_set_name=test_set, results_dict=results_dict
        )

    logging.info("Done!")


if __name__ == "__main__":
@@ -39,7 +39,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa:
    if Path("data/lm/G_3_gram.pt").is_file():
        print("Loading pre-compiled G_3_gram")
        d = torch.load("data/lm/G_3_gram.pt")
        G = k2.Fsa.from_dict(d).to(device)
        G = k2.Fsa.from_dict(d)
    else:
        print("Loading G_3_gram.fst.txt")
        with open("data/lm/G_3_gram.fst.txt") as f:
@@ -114,7 +114,7 @@ def bpe_based_HLG():

    print("Compiling BPE based HLG")
    HLG = compile_HLG("data/lang/bpe")
    print("Saving HLG.pt to data/lm")
    print("Saving HLG_bpe.pt to data/lm")
    torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt")
@@ -326,6 +326,8 @@ def main():
    if torch.cuda.is_available():
        device = torch.device("cuda", 0)

    logging.info(f"device: {device}")

    HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt"))
    HLG = HLG.to(device)
    assert HLG.requires_grad is False
@@ -54,6 +54,7 @@ def get_lattice(
    output_beam: float,
    min_active_states: int,
    max_active_states: int,
    subsampling_factor: int = 1,
):
    """Get the decoding lattice from a decoding graph and neural
    network output.
@@ -87,10 +88,14 @@ def get_lattice(
        frame for any given intersection/composition task. This is advisory,
        in that it will try not to exceed that but may not always succeed.
        You can use a very large number if no constraint is needed.
      subsampling_factor:
        The subsampling factor of the model.
    Returns:
      A lattice containing the decoding result.
    """
    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
    dense_fsa_vec = k2.DenseFsaVec(
        nnet_output, supervision_segments, allow_truncate=subsampling_factor - 1
    )
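    # The allow_truncate value (subsampling_factor - 1) tolerates the small
    # frame-count mismatch introduced by the integer division with the
    # subsampling factor in decode.py's supervision_segments.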

    lattice = k2.intersect_dense_pruned(
        HLG,