From 17768da017697205d7831fb42c00c3baa164fbb4 Mon Sep 17 00:00:00 2001
From: glynpu
Date: Fri, 17 Mar 2023 14:26:50 +0800
Subject: [PATCH] update comments

---
 egs/himia/wuw/ctc_tdnn/asr_datamodule.py |  4 ++--
 egs/himia/wuw/ctc_tdnn/graph.py          |  6 +++++-
 egs/himia/wuw/ctc_tdnn/tokenizer.py      | 18 +++++++++---------
 egs/himia/wuw/ctc_tdnn/train.py          |  3 +++
 4 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/egs/himia/wuw/ctc_tdnn/asr_datamodule.py b/egs/himia/wuw/ctc_tdnn/asr_datamodule.py
index 72eb2dc8b..db633f9f9 100644
--- a/egs/himia/wuw/ctc_tdnn/asr_datamodule.py
+++ b/egs/himia/wuw/ctc_tdnn/asr_datamodule.py
@@ -197,7 +197,7 @@ class HiMiaWuwDataModule:
             default="_7_01",
             help="""channel of HI_MIA train dataset.
             All channels are used if it is set "all".
-            Please refer to state 6 in prepare.sh for its meaning and other
+            Please refer to stage 6 in prepare.sh for its meaning and other
             potential values.
             Currently, Only "_7_01" is verified.
             """,
         )
@@ -207,7 +207,7 @@ class HiMiaWuwDataModule:
             default="_7_01",
             help="""channel of HI_MIA dev dataset.
             All channels are used if it is set "all".
-            Please refer to state 6 in prepare.sh for its meaning and other
+            Please refer to stage 6 in prepare.sh for its meaning and other
             potential values.
             Currently, Only "_7_01" is verified.
             """,
         )
diff --git a/egs/himia/wuw/ctc_tdnn/graph.py b/egs/himia/wuw/ctc_tdnn/graph.py
index 60e8afe2e..d1ff3114d 100644
--- a/egs/himia/wuw/ctc_tdnn/graph.py
+++ b/egs/himia/wuw/ctc_tdnn/graph.py
@@ -19,7 +19,7 @@
 from typing import List
 
 
-def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
+def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]) -> str:
     """
     A graph starts with blank/unknown and following by wakeup word.
 
@@ -27,6 +27,10 @@ def ctc_trivial_decoding_graph(wakeup_word_tokens: List[int]):
       wakeup_word_tokens:
         A sequence of token ids corresponding wakeup_word.
         It should not contain 0 and 1.
        We assume 0 is for blank and 1 is for unknown.
+    Returns:
+      Returns a finite-state transducer in string format,
+      used as a decoding graph.
+      Arcs are separated with "\n".
     """
     assert 0 not in wakeup_word_tokens
     assert 1 not in wakeup_word_tokens
diff --git a/egs/himia/wuw/ctc_tdnn/tokenizer.py b/egs/himia/wuw/ctc_tdnn/tokenizer.py
index 5bd54d2f0..b6225c66c 100644
--- a/egs/himia/wuw/ctc_tdnn/tokenizer.py
+++ b/egs/himia/wuw/ctc_tdnn/tokenizer.py
@@ -23,15 +23,15 @@ from typing import List, Tuple
 class WakeupWordTokenizer(object):
     def __init__(
         self,
-        wakeup_word: str = "",
-        wakeup_word_tokens: List[int] = None,
+        wakeup_word: str,
+        wakeup_word_tokens: List[int],
     ) -> None:
         """
         Args:
           wakeup_word: content of positive samples.
-            A sample will be treated as a negative sample unless its context
+            A sample will be treated as a negative sample unless its content
             is exactly the same to key_words.
-          wakeup_word_tokens: A list if int represents token ids of wakeup_word.
+          wakeup_word_tokens: A list of int representing token ids of wakeup_word.
             For example: the pronunciation of "你好米雅" is
             "n i h ao m i y a". Suppose we are using following lexicon:
@@ -67,7 +67,7 @@ class WakeupWordTokenizer(object):
     def texts_to_token_ids(
         self, texts: List[str]
     ) -> Tuple[torch.Tensor, torch.Tensor, int]:
-        """Convert a list of texts to a list of k2.Fsa based texts.
+        """Convert a list of texts to parameters needed by CTC loss.
 
         Args:
           texts:
@@ -76,7 +76,7 @@ class WakeupWordTokenizer(object):
         Returns:
           Return a tuple of 3 elements.
           The first one is torch.Tensor(List[List[int]]),
-          each List[int] is tokens sequence for each a reference text.
+          each List[int] is tokens sequence for each reference text.
           The second one is number of tokens for each sample,
           mainly used by CTC loss.
@@ -89,13 +89,13 @@ class WakeupWordTokenizer(object):
         number_positive_samples = 0
         for utt_text in texts:
             if utt_text == self.wakeup_word:
-                batch_token_ids.append(self.wakeup_word_tokens)
+                batch_token_ids.extend(self.wakeup_word_tokens)
                 target_lengths.append(self.positive_number_tokens)
                 number_positive_samples += 1
             else:
-                batch_token_ids.append(self.negative_word_tokens)
+                batch_token_ids.extend(self.negative_word_tokens)
                 target_lengths.append(self.negative_number_tokens)
 
-        target = torch.tensor(list(itertools.chain.from_iterable(batch_token_ids)))
+        target = torch.tensor(batch_token_ids)
         target_lengths = torch.tensor(target_lengths)
         return target, target_lengths, number_positive_samples
diff --git a/egs/himia/wuw/ctc_tdnn/train.py b/egs/himia/wuw/ctc_tdnn/train.py
index d35744e20..62d71b0bf 100755
--- a/egs/himia/wuw/ctc_tdnn/train.py
+++ b/egs/himia/wuw/ctc_tdnn/train.py
@@ -531,6 +531,9 @@ def run(rank, world_size, args):
 
     model = Tdnn(params.feature_dim, params.num_class)
 
+    num_param = sum([p.numel() for p in model.parameters()])
+    logging.info(f"Number of model parameters: {num_param}")
+
     checkpoints = load_checkpoint_if_available(params=params, model=model)
 
     model.to(device)
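
Note on the tokenizer change above: switching from append to extend leaves batch_token_ids as one flat list, so target = torch.tensor(batch_token_ids) yields the 1-D concatenated target tensor that torch.nn.CTCLoss accepts together with target_lengths. The sketch below is illustrative only and is not taken from the recipe: the token ids, the sequence length, and the assumption that a negative sample is mapped to the single "unknown" token (id 1) are made up.

# Minimal sketch (assumed, not part of the patch) of feeding the outputs of
# WakeupWordTokenizer.texts_to_token_ids() into torch.nn.CTCLoss.
import torch

# Pretend texts_to_token_ids() was called on a batch of two utterances:
# a positive sample with 8 made-up token ids and a negative sample that we
# assume is mapped to the single "unknown" token (id 1).
target = torch.tensor([2, 3, 4, 5, 6, 3, 7, 8, 1])  # 1-D concatenated targets
target_lengths = torch.tensor([8, 1])                # number of tokens per utterance

# Dummy acoustic model output: (T, N, C) log-probabilities, where C covers
# blank (0), unknown (1) and the made-up wakeup-word token ids above.
T, N, C = 50, 2, 10
log_probs = torch.randn(T, N, C).log_softmax(dim=-1)
input_lengths = torch.full((N,), T, dtype=torch.long)

# blank=0 follows the convention stated in graph.py (0 = blank, 1 = unknown).
ctc_loss = torch.nn.CTCLoss(blank=0, reduction="sum")
loss = ctc_loss(log_probs, target, input_lengths, target_lengths)
print(loss)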