mirror of https://github.com/k2-fsa/icefall.git
synced 2025-08-27 02:34:21 +00:00

add otc related scripts using phone instead of bpe

parent 3f62460935
commit fa13951da5
@@ -26,7 +26,6 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
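Note: dropping the sentencepiece import follows from the switch to phone units: token IDs now come from the lexicon rather than from a BPE model. A minimal sketch of the phone-based token source, with a hypothetical lang-dir path:

    from pathlib import Path

    from icefall.lexicon import Lexicon

    # "data/lang_phone" is an illustrative path, not taken from this commit.
    lexicon = Lexicon(Path("data/lang_phone"))
    token_ids = lexicon.tokens  # token IDs, excluding epsilon and disambig symbols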
@@ -41,7 +40,6 @@ from icefall.checkpoint import (
 from icefall.decode import get_lattice, one_best_decoding
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -94,7 +92,7 @@ def get_parser():
     parser.add_argument(
         "--avg",
         type=int,
-        default=1,
+        default=5,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch' and '--iter'",
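Note: with the new default, decoding averages the five consecutive checkpoints ending at --epoch. A hedged sketch of the usual icefall averaging step (variable names assumed from the surrounding script):

    from icefall.checkpoint import average_checkpoints

    # e.g. --epoch 30 --avg 5 averages epoch-26.pt .. epoch-30.pt
    # (the concrete numbers are illustrative only).
    start = params.epoch - params.avg + 1
    filenames = [f"{params.exp_dir}/epoch-{i}.pt" for i in range(start, params.epoch + 1)]
    model.load_state_dict(average_checkpoints(filenames, device=device))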
@@ -195,7 +193,7 @@ def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    HLG: Optional[k2.Fsa],
+    HLG: k2.Fsa,
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
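Note: HLG is no longer Optional, so every caller must pass a compiled graph. A call sketch (variable names assumed):

    results = decode_one_batch(
        params=params,
        model=model,
        HLG=HLG,  # must be a k2.Fsa; passing None is no longer supported
        batch=batch,
        word_table=word_table,
        G=G,  # still optional
    )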
@@ -239,10 +237,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict. Note: If it decodes to nothing, then return None.
     """
-    if HLG is not None:
-        device = HLG.device
-    else:
-        device = H.device
+    device = HLG.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
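Note: with HLG guaranteed to exist, the H-graph fallback is gone and the device is read straight off the graph. A tiny sketch:

    # k2.Fsa mirrors torch tensors: it exposes .device and .to().
    device = HLG.device                   # e.g. torch.device("cuda", 0)
    feature = batch["inputs"].to(device)  # keep features on the same device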
@@ -271,7 +266,6 @@ def decode_one_batch(
         1,
     ).to(torch.int32)
 
-    assert HLG is not None
     decoding_graph = HLG
 
     lattice = get_lattice(
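Note: for orientation, the get_lattice call that starts here typically takes the following form in icefall decode scripts (the beam values shown are common defaults, not necessarily this commit's):

    lattice = get_lattice(
        nnet_output=nnet_output,
        decoding_graph=decoding_graph,  # the HLG assigned above
        supervision_segments=supervision_segments,
        search_beam=20,
        output_beam=8,
        min_active_states=30,
        max_active_states=10000,
        subsampling_factor=params.subsampling_factor,
    )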
@@ -303,7 +297,7 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    HLG: Optional[k2.Fsa],
+    HLG: k2.Fsa,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
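Note: downstream of this function, the helpers imported earlier turn each lattice into word sequences; a short sketch using only names from this file:

    best_path = one_best_decoding(lattice=lattice, use_double_scores=True)
    hyp_ids = get_texts(best_path)  # List[List[int]] of word IDs, one per utterance
    hyps = [[word_table[i] for i in ids] for ids in hyp_ids]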
@@ -452,7 +446,7 @@ def main():
 
     lexicon = Lexicon(params.lang_dir)
-    # remove otc_token from decoding units
-    max_token_id = len(lexicon.tokens) - 1
+    max_token_id = len(lexicon.tokens) - 1
     num_classes = max_token_id + 1  # +1 for the blank
 
     device = torch.device("cpu")
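Note: numerically the two lines collapse to num_classes == len(lexicon.tokens); a worked sketch with a made-up token count:

    num_tokens = len(lexicon.tokens)  # say 72 non-epsilon tokens (made up)
    max_token_id = num_tokens - 1     # 71
    num_classes = max_token_id + 1    # 72, the +1 putting the blank back in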
@@ -463,9 +457,7 @@ def main():
 
     params.num_classes = num_classes
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
 
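Note: HLG graphs in icefall are serialized as plain dicts, so loading pairs torch.load with k2.Fsa.from_dict. A round-trip sketch (the path is illustrative):

    import k2
    import torch

    # Saving is done by the lang-dir preparation stage:
    # torch.save(HLG.as_dict(), "data/lang_phone/HLG.pt")

    HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt", map_location="cpu"))
    HLG = HLG.to(device)
    assert HLG.requires_grad is False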
@@ -899,15 +899,6 @@ def run(rank, world_size, args):
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
 
-    if params.show_alignment:
-        HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
-        HLG = HLG.to(device)
-        assert HLG.requires_grad is False
-        if not hasattr(HLG, "lm_scores"):
-            HLG.lm_scores = HLG.scores.clone()
-        params.HLG = HLG
-
-
     lexicon = Lexicon(params.lang_dir)
     graph_compiler = OtcPhoneTrainingGraphCompiler(
         lexicon,
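Note: the diff only shows the compiler's first argument; a hedged construction sketch, with the remaining keyword arguments assumed by analogy with the BPE-based OtcTrainingGraphCompiler rather than read from this commit:

    graph_compiler = OtcPhoneTrainingGraphCompiler(
        lexicon,                     # shown in the diff
        device=device,               # assumption
        otc_token=params.otc_token,  # assumption: the OTC wildcard token
    )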
@@ -1118,7 +1109,6 @@ def main():
     args.exp_dir = Path(args.exp_dir)
     args.otc_token = f"{args.otc_token}"
 
-
     world_size = args.world_size
     assert world_size >= 1
     if world_size > 1:
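Note: for completeness, the multi-GPU launch that usually follows this check in icefall train scripts (a sketch; the actual body lies outside this diff):

    import torch.multiprocessing as mp

    if world_size > 1:
        mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True)
    else:
        run(rank=0, world_size=1, args=args)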