Remove 'subword' option for 'otc-granularity' to avoid confusion

This commit is contained in:
Dongji Gao 2023-09-25 15:02:30 -04:00
parent c89c5a7299
commit 3da93d942a
2 changed files with 1 addition and 15 deletions

View File

@ -267,14 +267,6 @@ def get_parser():
help="OTC token", help="OTC token",
) )
parser.add_argument(
"--otc-granularity",
type=str,
choices=["word", "subword"],
default="word",
help="OTC granularity",
)
parser.add_argument( parser.add_argument(
"--allow-bypass-arc", "--allow-bypass-arc",
type=str2bool, type=str2bool,
@ -602,7 +594,6 @@ def compute_loss(
allow_self_loop_arc=params.allow_self_loop_arc, allow_self_loop_arc=params.allow_self_loop_arc,
bypass_weight=bypass_weight, bypass_weight=bypass_weight,
self_loop_weight=self_loop_weight, self_loop_weight=self_loop_weight,
otc_granularity=params.otc_granularity,
) )
dense_fsa_vec = k2.DenseFsaVec( dense_fsa_vec = k2.DenseFsaVec(

View File

@ -180,7 +180,6 @@ class OtcTrainingGraphCompiler(object):
allow_self_loop_arc: str2bool = True, allow_self_loop_arc: str2bool = True,
bypass_weight: float = 0.0, bypass_weight: float = 0.0,
self_loop_weight: float = 0.0, self_loop_weight: float = 0.0,
otc_granularity: str = "word",
): ):
otc_token_id = self.token_table[otc_token] otc_token_id = self.token_table[otc_token]
@ -190,11 +189,7 @@ class OtcTrainingGraphCompiler(object):
for word in text.split(): for word in text.split():
piece_ids = self.sp.encode(word, out_type=int) piece_ids = self.sp.encode(word, out_type=int)
if otc_granularity == "word": text_piece_ids.append(piece_ids)
text_piece_ids.append(piece_ids)
elif otc_granularity == "subword":
for piece_id in piece_ids:
text_piece_ids.append([piece_id])
arcs = [] arcs = []
start_state = 0 start_state = 0