mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-27 10:44:19 +00:00
Add OTC-related scripts using phones instead of BPE
This commit is contained in:
parent
3f62460935
commit
fa13951da5
@ -26,7 +26,6 @@ from pathlib import Path
|
|||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import k2
|
import k2
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from asr_datamodule import LibriSpeechAsrDataModule
|
from asr_datamodule import LibriSpeechAsrDataModule
|
||||||
@ -41,7 +40,6 @@ from icefall.checkpoint import (
|
|||||||
from icefall.decode import get_lattice, one_best_decoding
|
from icefall.decode import get_lattice, one_best_decoding
|
||||||
from icefall.env import get_env_info
|
from icefall.env import get_env_info
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.lexicon import Lexicon
|
||||||
from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
|
|
||||||
from icefall.utils import (
|
from icefall.utils import (
|
||||||
AttributeDict,
|
AttributeDict,
|
||||||
get_texts,
|
get_texts,
|
||||||
@ -94,7 +92,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--avg",
|
"--avg",
|
||||||
type=int,
|
type=int,
|
||||||
default=1,
|
default=5,
|
||||||
help="Number of checkpoints to average. Automatically select "
|
help="Number of checkpoints to average. Automatically select "
|
||||||
"consecutive checkpoints before the checkpoint specified by "
|
"consecutive checkpoints before the checkpoint specified by "
|
||||||
"'--epoch' and '--iter'",
|
"'--epoch' and '--iter'",
|
||||||
@ -195,7 +193,7 @@ def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
|
|||||||
def decode_one_batch(
|
def decode_one_batch(
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
HLG: Optional[k2.Fsa],
|
HLG: k2.Fsa,
|
||||||
batch: dict,
|
batch: dict,
|
||||||
word_table: k2.SymbolTable,
|
word_table: k2.SymbolTable,
|
||||||
G: Optional[k2.Fsa] = None,
|
G: Optional[k2.Fsa] = None,
|
||||||
@ -239,10 +237,7 @@ def decode_one_batch(
|
|||||||
Return the decoding result. See above description for the format of
|
Return the decoding result. See above description for the format of
|
||||||
the returned dict. Note: If it decodes to nothing, then return None.
|
the returned dict. Note: If it decodes to nothing, then return None.
|
||||||
"""
|
"""
|
||||||
if HLG is not None:
|
device = HLG.device
|
||||||
device = HLG.device
|
|
||||||
else:
|
|
||||||
device = H.device
|
|
||||||
feature = batch["inputs"]
|
feature = batch["inputs"]
|
||||||
assert feature.ndim == 3
|
assert feature.ndim == 3
|
||||||
feature = feature.to(device)
|
feature = feature.to(device)
|
||||||
@ -271,7 +266,6 @@ def decode_one_batch(
|
|||||||
1,
|
1,
|
||||||
).to(torch.int32)
|
).to(torch.int32)
|
||||||
|
|
||||||
assert HLG is not None
|
|
||||||
decoding_graph = HLG
|
decoding_graph = HLG
|
||||||
|
|
||||||
lattice = get_lattice(
|
lattice = get_lattice(
|
||||||
@ -303,7 +297,7 @@ def decode_dataset(
|
|||||||
dl: torch.utils.data.DataLoader,
|
dl: torch.utils.data.DataLoader,
|
||||||
params: AttributeDict,
|
params: AttributeDict,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
HLG: Optional[k2.Fsa],
|
HLG: k2.Fsa,
|
||||||
word_table: k2.SymbolTable,
|
word_table: k2.SymbolTable,
|
||||||
G: Optional[k2.Fsa] = None,
|
G: Optional[k2.Fsa] = None,
|
||||||
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
|
||||||
@ -452,7 +446,7 @@ def main():
|
|||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
# remove otc_token from decoding units
|
# remove otc_token from decoding units
|
||||||
max_token_id = len(lexicon.tokens) - 1
|
max_token_id = len(lexicon.tokens) - 1
|
||||||
num_classes = max_token_id + 1 # +1 for the blank
|
num_classes = max_token_id + 1 # +1 for the blank
|
||||||
|
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
@ -463,9 +457,7 @@ def main():
|
|||||||
|
|
||||||
params.num_classes = num_classes
|
params.num_classes = num_classes
|
||||||
|
|
||||||
HLG = k2.Fsa.from_dict(
|
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
|
||||||
torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
|
|
||||||
)
|
|
||||||
HLG = HLG.to(device)
|
HLG = HLG.to(device)
|
||||||
assert HLG.requires_grad is False
|
assert HLG.requires_grad is False
|
||||||
|
|
||||||
|
@ -899,15 +899,6 @@ def run(rank, world_size, args):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
device = torch.device("cuda", rank)
|
device = torch.device("cuda", rank)
|
||||||
|
|
||||||
if params.show_alignment:
|
|
||||||
HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
|
|
||||||
HLG = HLG.to(device)
|
|
||||||
assert HLG.requires_grad is False
|
|
||||||
if not hasattr(HLG, "lm_scores"):
|
|
||||||
HLG.lm_scores = HLG.scores.clone()
|
|
||||||
params.HLG = HLG
|
|
||||||
|
|
||||||
|
|
||||||
lexicon = Lexicon(params.lang_dir)
|
lexicon = Lexicon(params.lang_dir)
|
||||||
graph_compiler = OtcPhoneTrainingGraphCompiler(
|
graph_compiler = OtcPhoneTrainingGraphCompiler(
|
||||||
lexicon,
|
lexicon,
|
||||||
@ -1118,7 +1109,6 @@ def main():
|
|||||||
args.exp_dir = Path(args.exp_dir)
|
args.exp_dir = Path(args.exp_dir)
|
||||||
args.otc_token = f"{args.otc_token}"
|
args.otc_token = f"{args.otc_token}"
|
||||||
|
|
||||||
|
|
||||||
world_size = args.world_size
|
world_size = args.world_size
|
||||||
assert world_size >= 1
|
assert world_size >= 1
|
||||||
if world_size > 1:
|
if world_size > 1:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user