add otc related scripts using phone instead of bpe

Dongji Gao 2024-04-21 16:27:22 -04:00
parent 3f62460935
commit fa13951da5
2 changed files with 6 additions and 24 deletions

View File

@@ -26,7 +26,6 @@ from pathlib import Path
 from typing import Dict, List, Optional, Tuple
 
 import k2
-import sentencepiece as spm
 import torch
 import torch.nn as nn
 from asr_datamodule import LibriSpeechAsrDataModule
@@ -41,7 +40,6 @@ from icefall.checkpoint import (
 from icefall.decode import get_lattice, one_best_decoding
 from icefall.env import get_env_info
 from icefall.lexicon import Lexicon
-from icefall.otc_graph_compiler import OtcTrainingGraphCompiler
 from icefall.utils import (
     AttributeDict,
     get_texts,
@@ -94,7 +92,7 @@ def get_parser():
     parser.add_argument(
         "--avg",
         type=int,
-        default=1,
+        default=5,
         help="Number of checkpoints to average. Automatically select "
         "consecutive checkpoints before the checkpoint specified by "
         "'--epoch' and '--iter'",
@@ -195,7 +193,7 @@ def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
 def decode_one_batch(
     params: AttributeDict,
     model: nn.Module,
-    HLG: Optional[k2.Fsa],
+    HLG: k2.Fsa,
     batch: dict,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
@@ -239,10 +237,7 @@ def decode_one_batch(
       Return the decoding result. See above description for the format of
       the returned dict. Note: If it decodes to nothing, then return None.
     """
-    if HLG is not None:
-        device = HLG.device
-    else:
-        device = H.device
+    device = HLG.device
     feature = batch["inputs"]
     assert feature.ndim == 3
     feature = feature.to(device)
@@ -271,7 +266,6 @@ def decode_one_batch(
         1,
     ).to(torch.int32)
 
-    assert HLG is not None
     decoding_graph = HLG
 
     lattice = get_lattice(
@@ -303,7 +297,7 @@ def decode_dataset(
     dl: torch.utils.data.DataLoader,
     params: AttributeDict,
     model: nn.Module,
-    HLG: Optional[k2.Fsa],
+    HLG: k2.Fsa,
     word_table: k2.SymbolTable,
     G: Optional[k2.Fsa] = None,
 ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
@@ -452,7 +446,7 @@ def main():
     lexicon = Lexicon(params.lang_dir)
     # remove otc_token from decoding units
     max_token_id = len(lexicon.tokens) - 1
     num_classes = max_token_id + 1  # +1 for the blank
 
     device = torch.device("cpu")
@@ -463,9 +457,7 @@ def main():
     params.num_classes = num_classes
 
-    HLG = k2.Fsa.from_dict(
-        torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")
-    )
+    HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
     HLG = HLG.to(device)
     assert HLG.requires_grad is False
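
Taken together, the decode-side edits assume HLG decoding is always used: the graph is loaded once from {lang_dir}/HLG.pt and the device is taken from it, so the H-only path is gone. A small sketch of that pattern, following the code shown in this commit rather than copying the script verbatim:

import k2
import torch

def load_hlg(lang_dir: str, device: torch.device) -> k2.Fsa:
    """Load a prebuilt HLG decoding graph and prepare it for decoding."""
    HLG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/HLG.pt", map_location="cpu"))
    HLG = HLG.to(device)
    assert HLG.requires_grad is False
    # Keep LM scores separate so an LM scale can be applied when rescoring.
    if not hasattr(HLG, "lm_scores"):
        HLG.lm_scores = HLG.scores.clone()
    return HLG

# In decode_one_batch the device now always comes from the graph:
#   device = HLG.device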

View File

@@ -899,15 +899,6 @@ def run(rank, world_size, args):
     if torch.cuda.is_available():
         device = torch.device("cuda", rank)
 
-    if params.show_alignment:
-        HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu"))
-        HLG = HLG.to(device)
-        assert HLG.requires_grad is False
-        if not hasattr(HLG, "lm_scores"):
-            HLG.lm_scores = HLG.scores.clone()
-        params.HLG = HLG
-
     lexicon = Lexicon(params.lang_dir)
     graph_compiler = OtcPhoneTrainingGraphCompiler(
         lexicon,
@@ -1118,7 +1109,6 @@ def main():
     args.exp_dir = Path(args.exp_dir)
-    args.otc_token = f"{args.otc_token}"
 
     world_size = args.world_size
     assert world_size >= 1
     if world_size > 1:
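
On the training side, the intent is to build OTC supervision graphs from a phone lexicon via OtcPhoneTrainingGraphCompiler instead of the BPE-based compiler, and to drop the HLG loading that was only used for --show-alignment. A hedged sketch of the new wiring; only lexicon as the first constructor argument is visible in this diff, so the import path and keyword arguments below are illustrative assumptions:

import torch
from icefall.lexicon import Lexicon
# Assumed import path; the diff only shows the class name at the call site.
from icefall.otc_phone_graph_compiler import OtcPhoneTrainingGraphCompiler

device = torch.device("cuda", 0) if torch.cuda.is_available() else torch.device("cpu")
lang_dir = "data/lang_phone"  # assumed phone-based lang dir layout

lexicon = Lexicon(lang_dir)
graph_compiler = OtcPhoneTrainingGraphCompiler(
    lexicon,
    otc_token="<star>",  # assumed: taken from --otc-token in the real script
    device=device,       # assumed keyword
)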