diff --git a/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py b/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py index a9359a690..b6b1cb020 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py +++ b/egs/librispeech/WSASR/conformer_ctc2/decode_phone.py @@ -26,7 +26,6 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple import k2 -import sentencepiece as spm import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule @@ -41,7 +40,6 @@ from icefall.checkpoint import ( from icefall.decode import get_lattice, one_best_decoding from icefall.env import get_env_info from icefall.lexicon import Lexicon -from icefall.otc_graph_compiler import OtcTrainingGraphCompiler from icefall.utils import ( AttributeDict, get_texts, @@ -94,7 +92,7 @@ def get_parser(): parser.add_argument( "--avg", type=int, - default=1, + default=5, help="Number of checkpoints to average. Automatically select " "consecutive checkpoints before the checkpoint specified by " "'--epoch' and '--iter'", @@ -195,7 +193,7 @@ def remove_duplicates_and_blank(hyp: List[int]) -> List[int]: def decode_one_batch( params: AttributeDict, model: nn.Module, - HLG: Optional[k2.Fsa], + HLG: k2.Fsa, batch: dict, word_table: k2.SymbolTable, G: Optional[k2.Fsa] = None, @@ -239,10 +237,7 @@ def decode_one_batch( Return the decoding result. See above description for the format of the returned dict. Note: If it decodes to nothing, then return None. """ - if HLG is not None: - device = HLG.device - else: - device = H.device + device = HLG.device feature = batch["inputs"] assert feature.ndim == 3 feature = feature.to(device) @@ -271,7 +266,6 @@ def decode_one_batch( 1, ).to(torch.int32) - assert HLG is not None decoding_graph = HLG lattice = get_lattice( @@ -303,7 +297,7 @@ def decode_dataset( dl: torch.utils.data.DataLoader, params: AttributeDict, model: nn.Module, - HLG: Optional[k2.Fsa], + HLG: k2.Fsa, word_table: k2.SymbolTable, G: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: @@ -452,7 +446,7 @@ def main(): lexicon = Lexicon(params.lang_dir) # remove otc_token from decoding units - max_token_id = len(lexicon.tokens) - 1 + max_token_id = len(lexicon.tokens) - 1 num_classes = max_token_id + 1 # +1 for the blank device = torch.device("cpu") @@ -463,9 +457,7 @@ def main(): params.num_classes = num_classes - HLG = k2.Fsa.from_dict( - torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") - ) + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) HLG = HLG.to(device) assert HLG.requires_grad is False diff --git a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py index 68d8f9919..d4f8fc657 100755 --- a/egs/librispeech/WSASR/conformer_ctc2/train_phone.py +++ b/egs/librispeech/WSASR/conformer_ctc2/train_phone.py @@ -899,15 +899,6 @@ def run(rank, world_size, args): if torch.cuda.is_available(): device = torch.device("cuda", rank) - if params.show_alignment: - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) - HLG = HLG.to(device) - assert HLG.requires_grad is False - if not hasattr(HLG, "lm_scores"): - HLG.lm_scores = HLG.scores.clone() - params.HLG = HLG - - lexicon = Lexicon(params.lang_dir) graph_compiler = OtcPhoneTrainingGraphCompiler( lexicon, @@ -1118,7 +1109,6 @@ def main(): args.exp_dir = Path(args.exp_dir) args.otc_token = f"{args.otc_token}" - world_size = args.world_size assert world_size >= 1 if world_size > 1: