= += ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -= +
= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,3 +318,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/fluent.ini b/egs/csj/ASR/local/conf/fluent.ini index f7f27f5bc..5d22f9eb8 100644 --- a/egs/csj/ASR/local/conf/fluent.ini +++ b/egs/csj/ASR/local/conf/fluent.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "
" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was " " in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = += ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -= +
= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,3 +318,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/number.ini b/egs/csj/ASR/local/conf/number.ini index cf9038f62..2613c3409 100644 --- a/egs/csj/ASR/local/conf/number.ini +++ b/egs/csj/ASR/local/conf/number.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "
" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was " " in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -115,59 +115,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = += ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -= +
= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -318,3 +318,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/conf/symbol.ini b/egs/csj/ASR/local/conf/symbol.ini index f9801284b..8ba451dd5 100644 --- a/egs/csj/ASR/local/conf/symbol.ini +++ b/egs/csj/ASR/local/conf/symbol.ini @@ -1,17 +1,17 @@ ; # This section is ignored if this file is not supplied as the first config file to -; # lhotse prepare csj +; # lhotse prepare csj [SEGMENTS] ; # Allowed period of nonverbal noise. If exceeded, a new segment is created. gap = 0.5 ; # Maximum length of segment (s). maxlen = 10 -; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. +; # Minimum length of segment (s). Segments shorter than `minlen` will be dropped silently. minlen = 0.02 -; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. -; # Pass an empty string to avoid adding any symbol. It was "
" in kaldi. -; # If you intend to use a multicharacter string for gap_sym, remember to register the -; # multicharacter string as part of userdef-string in prepare_lang_char.py. -gap_sym = +; # Use this symbol to represent a period of allowed nonverbal noise, i.e. `gap`. +; # Pass an empty string to avoid adding any symbol. It was " " in kaldi. +; # If you intend to use a multicharacter string for gap_sym, remember to register the +; # multicharacter string as part of userdef-string in prepare_lang_char.py. +gap_sym = [CONSTANTS] ; # Name of this mode @@ -116,59 +116,59 @@ B^ = 0 ; # 0 to remain, 1 to delete ; # Example: '(笑 ナニガ)', '(笑 (F エー)+ソー+イッ+タ+ヨー+ナ)' 笑 = 0 -; # Example: 'コク(笑 サイ+(D オン))', +; # Example: 'コク(笑 サイ+(D オン))', 笑^ = 0 ; # 泣きながら発話 ; # 0 to remain, 1 to delete -; # Example: '(泣 ドンナニ)' +; # Example: '(泣 ドンナニ)' 泣 = 0 泣^ = 0 ; # 咳をしながら発話 ; # 0 to remain, 1 to delete -; # Example: 'シャ(咳 リン) ノ' +; # Example: 'シャ(咳 リン) ノ' 咳 = 0 ; # Example: 'イッ(咳 パン)', 'ワズ(咳 カ)' 咳^ = 0 ; # ささやき声や独り言などの小さな声 ; # 0 to remain, 1 to delete -; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' +; # Example: '(L アレコレナンダッケ)', '(L (W コデ;(? コレ,ココデ))+(? セツメー+シ+タ+ホー+ガ+イー+カ+ナ))' L = 0 ; # Example: 'デ(L ス)', 'ッ(L テ+コ)ト' L^ = 0 [REPLACEMENTS] ; # ボーカルフライなどで母音が同定できない場合 - = + = ; # 「うん/うーん/ふーん」の音の特定が困難な場合 - = + = ; # 非語彙的な母音の引き延ばし - = + = ; # 非語彙的な子音の引き延ばし - = += ; # 言語音と独立に講演者の笑いが生じている場合 -<笑> = +<笑> = ; # 言語音と独立に講演者の咳が生じている場合 -<咳> = +<咳> = ; # 言語音と独立に講演者の息が生じている場合 -<息> = +<息> = ; # 講演者の泣き声 -<泣> = +<泣> = ; # 聴衆(司会者なども含む)の発話 -<フロア発話> = +<フロア発話> = ; # 聴衆の笑い -<フロア笑> = +<フロア笑> = ; # 聴衆の拍手 -<拍手> = +<拍手> = ; # 講演者が発表中に用いたデモンストレーションの音声 -<デモ> = +<デモ> = ; # 学会講演に発表時間を知らせるためにならすベルの音 -<ベル> = +<ベル> = ; # 転記単位全体が再度読み直された場合 -<朗読間違い> = +<朗読間違い> = ; # 上記以外の音で特に目立った音 -<雑音> = +<雑音> = ; # 0.2秒以上のポーズ -= +
= ; # Redacted information, for R ; # It is \x00D7 multiplication sign, not your normal 'x' × = × @@ -319,3 +319,4 @@ spk_id = 2 ャ = ǐa ュ = ǐu ョ = ǐo + diff --git a/egs/csj/ASR/local/display_manifest_statistics.py b/egs/csj/ASR/local/display_manifest_statistics.py index c043cf853..c9de21073 100644 --- a/egs/csj/ASR/local/display_manifest_statistics.py +++ b/egs/csj/ASR/local/display_manifest_statistics.py @@ -37,7 +37,9 @@ def get_parser(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) - parser.add_argument("--manifest-dir", type=Path, help="Path to cutset manifests") + parser.add_argument( + "--manifest-dir", type=Path, help="Path to cutset manifests" + ) return parser.parse_args() diff --git a/egs/csj/ASR/local/prepare_lang_char.py b/egs/csj/ASR/local/prepare_lang_char.py index f0078421b..e4d996871 100644 --- a/egs/csj/ASR/local/prepare_lang_char.py +++ b/egs/csj/ASR/local/prepare_lang_char.py @@ -68,7 +68,8 @@ def get_args(): type=Path, default=None, help=( - "Name of lang dir. If not set, this will default to lang_char_{trans-mode}" + "Name of lang dir. " + "If not set, this will default to lang_char_{trans-mode}" ), ) @@ -86,7 +87,9 @@ def main(): args = get_args() logging.basicConfig( - format="%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s", + format=( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] " "%(message)s" + ), level=logging.INFO, ) @@ -108,7 +111,8 @@ def main(): words = set() logging.info( - f"Creating vocabulary from {args.train_cut.name} at {args.trans_mode} mode." + f"Creating vocabulary from {args.train_cut.name}" + f" at {args.trans_mode} mode." ) for cut in train_set: try: @@ -119,7 +123,8 @@ def main(): ) except KeyError: raise KeyError( - f"Could not find {args.trans_mode} in {cut.supervisions[0].custom}" + f"Could not find {args.trans_mode} in " + f"{cut.supervisions[0].custom}" ) for t in text.split(): if t in args.userdef_string: @@ -138,7 +143,9 @@ def main(): (args.lang_dir / "words_len").write_text(f"{len(words)}") - (args.lang_dir / "userdef_string").write_text("\n".join(args.userdef_string)) + (args.lang_dir / "userdef_string").write_text( + "\n".join(args.userdef_string) + ) (args.lang_dir / "trans_mode").write_text(args.trans_mode) logging.info("Done.") diff --git a/egs/csj/ASR/local/validate_manifest.py b/egs/csj/ASR/local/validate_manifest.py index 89448a49c..0c4c6c1ea 100644 --- a/egs/csj/ASR/local/validate_manifest.py +++ b/egs/csj/ASR/local/validate_manifest.py @@ -68,7 +68,8 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger " + f"than cut end time {c.end}" ) @@ -88,7 +89,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py index c3e3e84bf..d78e26240 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py +++ b/egs/gigaspeech/ASR/conformer_ctc/asr_datamodule.py @@ -61,12 +61,10 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -78,91 +76,75 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -176,22 +158,18 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it " - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", ) # GigaSpeech specific arguments @@ -205,25 +183,30 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", + help="Should we use only 1000 utterances for dev " + "(speeds up training)", ) def train_dataloaders(self, cuts_train: CutSet) -> DataLoader: logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -238,7 +221,9 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -271,7 +256,9 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -317,7 +304,9 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -373,7 +362,9 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "cuts_DEV.jsonl.gz" + ) if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/conformer_ctc/conformer.py b/egs/gigaspeech/ASR/conformer_ctc/conformer.py index 1153a814c..6fac07f93 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/conformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/conformer.py @@ -160,7 +160,9 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -180,14 +182,18 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,7 +227,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -340,7 +348,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -356,7 +366,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -626,9 +638,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -696,25 +708,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -751,7 +771,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -763,7 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -797,9 +821,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/gigaspeech/ASR/conformer_ctc/decode.py b/egs/gigaspeech/ASR/conformer_ctc/decode.py index b38ae9c8c..9c1418baa 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/decode.py +++ b/egs/gigaspeech/ASR/conformer_ctc/decode.py @@ -62,19 +62,16 @@ def get_parser(): "--epoch", type=int, default=0, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=1, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -479,7 +476,9 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert len(results) > 0, "It should not decode to empty in the first batch!" + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for cut_id, ref_text in zip(cut_ids, texts): @@ -494,7 +493,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -527,7 +528,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -702,7 +705,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py index 880aa76e2..ef53b77f8 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py +++ b/egs/gigaspeech/ASR/conformer_ctc/gigaspeech_scoring.py @@ -73,7 +73,8 @@ def asr_text_post_processing(text: str) -> str: if __name__ == "__main__": parser = argparse.ArgumentParser( - description="This script evaluates GigaSpeech ASR result viaSCTK's tool sclite" + description="This script evaluates GigaSpeech ASR result via" + "SCTK's tool sclite" ) parser.add_argument( "ref", diff --git a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py index 3b94f0c4b..cdc85ce9a 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/gigaspeech/ASR/conformer_ctc/label_smoothing.py @@ -78,10 +78,13 @@ class LabelSmoothingLoss(torch.nn.Module): ignored = target == self.ignore_index target[ignored] = 0 - true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + true_dist = torch.nn.functional.one_hot( + target, num_classes=num_classes + ).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 true_dist[ignored] = 0 diff --git a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py index 8e0f73d05..542fb0364 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/subsampling.py +++ b/egs/gigaspeech/ASR/conformer_ctc/subsampling.py @@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -128,13 +132,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/gigaspeech/ASR/conformer_ctc/train.py b/egs/gigaspeech/ASR/conformer_ctc/train.py index 4883d04d8..2965cde18 100755 --- a/egs/gigaspeech/ASR/conformer_ctc/train.py +++ b/egs/gigaspeech/ASR/conformer_ctc/train.py @@ -386,7 +386,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -519,7 +521,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -637,7 +641,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/gigaspeech/ASR/conformer_ctc/transformer.py b/egs/gigaspeech/ASR/conformer_ctc/transformer.py index 0566cfc81..00ca027a7 100644 --- a/egs/gigaspeech/ASR/conformer_ctc/transformer.py +++ b/egs/gigaspeech/ASR/conformer_ctc/transformer.py @@ -151,7 +151,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss() else: @@ -179,13 +181,18 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ - if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: + if ( + isinstance(self.use_feat_batchnorm, bool) + and self.use_feat_batchnorm + ): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,17 +273,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -337,17 +350,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -620,7 +639,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -822,7 +843,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -843,7 +866,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 07beeb1f0..8209ee3ec 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -77,7 +77,9 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py index 0ee845ec8..6410249db 100755 --- a/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/gigaspeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -47,10 +47,8 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help=( - "The maximum number of audio seconds in a batch." - "Determines batch size dynamically." - ), + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", ) parser.add_argument( @@ -136,7 +134,9 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py index 31abe7fff..48d10a157 100755 --- a/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/gigaspeech/ASR/local/preprocess_gigaspeech.py @@ -98,13 +98,19 @@ def preprocess_giga_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 9ae3f071e..c87686e1e 100644 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -73,12 +73,10 @@ class GigaSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -90,91 +88,75 @@ class GigaSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -188,22 +170,18 @@ class GigaSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it " - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", ) # GigaSpeech specific arguments @@ -217,7 +195,8 @@ class GigaSpeechAsrDataModule: "--small-dev", type=str2bool, default=False, - help="Should we use only 1000 utterances for dev (speeds up training)", + help="Should we use only 1000 utterances for dev " + "(speeds up training)", ) def train_dataloaders( @@ -237,16 +216,20 @@ class GigaSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -261,7 +244,9 @@ class GigaSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -304,7 +289,9 @@ class GigaSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -360,7 +347,9 @@ class GigaSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -416,7 +405,9 @@ class GigaSpeechAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - cuts_valid = load_manifest_lazy(self.args.manifest_dir / "cuts_DEV.jsonl.gz") + cuts_valid = load_manifest_lazy( + self.args.manifest_dir / "cuts_DEV.jsonl.gz" + ) if self.args.small_dev: return cuts_valid.subset(first=1000) else: diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py index 9f5d4711b..5849a3471 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/decode.py @@ -77,7 +77,11 @@ from beam_search import ( from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -114,11 +118,9 @@ def get_parser(): "--avg", type=int, default=8, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -186,7 +188,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -255,7 +258,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -270,7 +275,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -316,7 +324,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -386,7 +398,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -420,7 +434,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -496,7 +511,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py index 17f8614dc..cff9c7377 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/export.py @@ -51,7 +51,11 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -157,7 +160,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -205,7 +209,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py index 4d1a2356d..83ae25561 100755 --- a/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/gigaspeech/ASR/pruned_transducer_stateless2/train.py @@ -77,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -176,45 +178,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -554,16 +553,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -726,7 +732,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/librispeech/ASR/conformer_ctc/ali.py b/egs/librispeech/ASR/conformer_ctc/ali.py index 0169d0f82..2828e309e 100755 --- a/egs/librispeech/ASR/conformer_ctc/ali.py +++ b/egs/librispeech/ASR/conformer_ctc/ali.py @@ -61,19 +61,16 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -234,7 +231,9 @@ def compute_alignments( labels_ali = get_alignments(best_path, kind="labels") aux_labels_ali = get_alignments(best_path, kind="aux_labels") assert len(labels_ali) == len(aux_labels_ali) == len(cut_list) - for cut, labels, aux_labels in zip(cut_list, labels_ali, aux_labels_ali): + for cut, labels, aux_labels in zip( + cut_list, labels_ali, aux_labels_ali + ): cut.labels_alignment = labels_writer.store_array( key=cut.id, value=np.asarray(labels, dtype=np.int32), @@ -259,7 +258,9 @@ def compute_alignments( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return CutSet.from_cuts(cuts) @@ -288,7 +289,9 @@ def main(): out_labels_ali_filename = out_dir / f"labels_{params.dataset}.h5" out_aux_labels_ali_filename = out_dir / f"aux_labels_{params.dataset}.h5" - out_manifest_filename = out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + out_manifest_filename = ( + out_dir / f"librispeech_cuts_{params.dataset}.jsonl.gz" + ) for f in ( out_labels_ali_filename, diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 1153a814c..6fac07f93 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -160,7 +160,9 @@ class ConformerEncoderLayer(nn.Module): use_conv_batchnorm: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -180,14 +182,18 @@ class ConformerEncoderLayer(nn.Module): d_model, cnn_module_kernel, use_batchnorm=use_conv_batchnorm ) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -221,7 +227,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -340,7 +348,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -356,7 +366,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -626,9 +638,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -696,25 +708,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -751,7 +771,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -763,7 +785,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -797,9 +821,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index 66fdf82d9..3f3b1acda 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -64,19 +64,16 @@ def get_parser(): "--epoch", type=int, default=77, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=55, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -554,7 +551,9 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert len(results) > 0, "It should not decode to empty in the first batch!" + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -569,7 +568,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -601,7 +602,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -806,7 +809,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc/export.py b/egs/librispeech/ASR/conformer_ctc/export.py index bdb8a85e5..28c28df01 100755 --- a/egs/librispeech/ASR/conformer_ctc/export.py +++ b/egs/librispeech/ASR/conformer_ctc/export.py @@ -40,20 +40,17 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -160,7 +157,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py index cb0d6e04d..1f2f3b137 100644 --- a/egs/librispeech/ASR/conformer_ctc/label_smoothing.py +++ b/egs/librispeech/ASR/conformer_ctc/label_smoothing.py @@ -82,10 +82,13 @@ class LabelSmoothingLoss(torch.nn.Module): # for why we don't use target[ignored] = 0 here target = torch.where(ignored, torch.zeros_like(target), target) - true_dist = torch.nn.functional.one_hot(target, num_classes=num_classes).to(x) + true_dist = torch.nn.functional.one_hot( + target, num_classes=num_classes + ).to(x) true_dist = ( - true_dist * (1 - self.label_smoothing) + self.label_smoothing / num_classes + true_dist * (1 - self.label_smoothing) + + self.label_smoothing / num_classes ) # Set the value of ignored indexes to 0 diff --git a/egs/librispeech/ASR/conformer_ctc/pretrained.py b/egs/librispeech/ASR/conformer_ctc/pretrained.py index 8cabf1a53..a2c0a5486 100755 --- a/egs/librispeech/ASR/conformer_ctc/pretrained.py +++ b/egs/librispeech/ASR/conformer_ctc/pretrained.py @@ -48,11 +48,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -191,12 +189,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -240,9 +236,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -303,7 +300,9 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -428,7 +427,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py index 8e0f73d05..542fb0364 100644 --- a/egs/librispeech/ASR/conformer_ctc/subsampling.py +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -42,9 +42,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -128,13 +132,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 1a1c2f4c5..6419f6816 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -393,7 +393,9 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + raise ValueError( + f"Unsupported type of graph compiler: {type(graph_compiler)}" + ) dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -420,7 +422,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -449,7 +453,9 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)) + .sum() + .item() ) return loss, info @@ -562,7 +568,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -652,7 +660,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - "Unsupported type of lang dir (we expected it to have " + f"Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) @@ -725,7 +733,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 0566cfc81..00ca027a7 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -151,7 +151,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss() else: @@ -179,13 +181,18 @@ class Transformer(nn.Module): memory_key_padding_mask for the decoder. Its shape is (N, T). It is None if `supervision` is None. """ - if isinstance(self.use_feat_batchnorm, bool) and self.use_feat_batchnorm: + if ( + isinstance(self.use_feat_batchnorm, bool) + and self.use_feat_batchnorm + ): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) if isinstance(self.use_feat_batchnorm, float): x *= self.use_feat_batchnorm - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -266,17 +273,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -337,17 +350,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -620,7 +639,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -822,7 +843,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -843,7 +866,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_ctc2/attention.py b/egs/librispeech/ASR/conformer_ctc2/attention.py index 356d3f21b..1375d7245 100644 --- a/egs/librispeech/ASR/conformer_ctc2/attention.py +++ b/egs/librispeech/ASR/conformer_ctc2/attention.py @@ -18,10 +18,11 @@ from typing import Optional, Tuple import torch import torch.nn as nn -from scaling import ScaledLinear from torch import Tensor from torch.nn.init import xavier_normal_ +from scaling import ScaledLinear + class MultiheadAttention(nn.Module): r"""Allows the model to jointly attend to information @@ -75,7 +76,9 @@ class MultiheadAttention(nn.Module): self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim self.vdim = vdim if vdim is not None else embed_dim - self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + self._qkv_same_embed_dim = ( + self.kdim == embed_dim and self.vdim == embed_dim + ) self.num_heads = num_heads self.dropout = dropout @@ -91,7 +94,9 @@ class MultiheadAttention(nn.Module): self.v_proj_weight = ScaledLinear(self.vdim, embed_dim, bias=bias) self.register_parameter("in_proj_weight", None) else: - self.in_proj_weight = ScaledLinear(embed_dim, 3 * embed_dim, bias=bias) + self.in_proj_weight = ScaledLinear( + embed_dim, 3 * embed_dim, bias=bias + ) self.register_parameter("q_proj_weight", None) self.register_parameter("k_proj_weight", None) self.register_parameter("v_proj_weight", None) @@ -102,8 +107,12 @@ class MultiheadAttention(nn.Module): self.out_proj = ScaledLinear(embed_dim, embed_dim, bias=bias) if add_bias_kv: - self.bias_k = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) - self.bias_v = nn.Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_k = nn.Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) + self.bias_v = nn.Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) else: self.bias_k = self.bias_v = None diff --git a/egs/librispeech/ASR/conformer_ctc2/conformer.py b/egs/librispeech/ASR/conformer_ctc2/conformer.py index a6f1679ef..b906d2650 100644 --- a/egs/librispeech/ASR/conformer_ctc2/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/conformer.py @@ -29,8 +29,9 @@ from scaling import ( ScaledConv1d, ScaledLinear, ) -from subsampling import Conv2dSubsampling from torch import Tensor, nn +from subsampling import Conv2dSubsampling + from transformer import Supervisions, Transformer, encoder_padding_mask @@ -181,7 +182,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -353,7 +356,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -368,7 +373,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -643,9 +650,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -714,25 +721,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -769,7 +784,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -777,9 +794,13 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -813,9 +834,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -838,7 +863,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_ctc2/decode.py b/egs/librispeech/ASR/conformer_ctc2/decode.py index 934177b1f..97f2f2d39 100755 --- a/egs/librispeech/ASR/conformer_ctc2/decode.py +++ b/egs/librispeech/ASR/conformer_ctc2/decode.py @@ -90,11 +90,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -132,13 +130,11 @@ def get_parser(): "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -662,7 +658,9 @@ def decode_dataset( results[lm_scale].extend(this_batch) else: - assert len(results) > 0, "It should not decode to empty in the first batch!" + assert ( + len(results) > 0 + ), "It should not decode to empty in the first batch!" this_batch = [] hyp_words = [] for ref_text in texts: @@ -677,7 +675,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -709,7 +709,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -850,12 +852,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -878,12 +881,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -911,7 +915,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -981,7 +985,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_ctc2/export.py b/egs/librispeech/ASR/conformer_ctc2/export.py index 0e1841d8d..584b3c3fc 100755 --- a/egs/librispeech/ASR/conformer_ctc2/export.py +++ b/egs/librispeech/ASR/conformer_ctc2/export.py @@ -47,7 +47,6 @@ import logging from pathlib import Path import torch -from conformer import Conformer from decode import get_params from icefall.checkpoint import ( @@ -56,8 +55,10 @@ from icefall.checkpoint import ( find_checkpoints, load_checkpoint, ) -from icefall.lexicon import Lexicon +from conformer import Conformer + from icefall.utils import str2bool +from icefall.lexicon import Lexicon def get_parser(): @@ -88,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -180,12 +177,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -208,12 +206,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -241,7 +240,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -274,7 +273,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conformer_ctc2/train.py b/egs/librispeech/ASR/conformer_ctc2/train.py index 4d7137ad7..18fa3e69f 100755 --- a/egs/librispeech/ASR/conformer_ctc2/train.py +++ b/egs/librispeech/ASR/conformer_ctc2/train.py @@ -69,8 +69,8 @@ from torch.cuda.amp import GradScaler from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.tensorboard import SummaryWriter -from icefall import diagnostics from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler +from icefall import diagnostics from icefall.checkpoint import load_checkpoint, remove_checkpoints from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.checkpoint import ( @@ -89,7 +89,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -496,7 +498,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -525,7 +531,9 @@ def compute_loss( # Works with a phone lexicon decoding_graph = graph_compiler.compile(texts) else: - raise ValueError(f"Unsupported type of graph compiler: {type(graph_compiler)}") + raise ValueError( + f"Unsupported type of graph compiler: {type(graph_compiler)}" + ) dense_fsa_vec = k2.DenseFsaVec( nnet_output, @@ -552,7 +560,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -570,7 +580,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) info["ctc_loss"] = ctc_loss.detach().cpu().item() if params.att_rate != 0.0: info["att_loss"] = att_loss.detach().cpu().item() @@ -708,7 +720,8 @@ def train_one_epoch( except RuntimeError as e: if "CUDA out of memory" in str(e): logging.error( - f"failing batch size:{batch_size} failing batch names {batch_name}" + f"failing batch size:{batch_size} " + f"failing batch names {batch_name}" ) raise @@ -763,9 +776,9 @@ def train_one_epoch( f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}" ) - if loss_info["ctc_loss"] == float("inf") or loss_info["att_loss"] == float( - "inf" - ): + if loss_info["ctc_loss"] == float("inf") or loss_info[ + "att_loss" + ] == float("inf"): logging.error( "Your loss contains inf, something goes wrong" f"failing batch names {batch_name}" @@ -778,7 +791,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -870,7 +885,7 @@ def run(rank, world_size, args): graph_compiler.eos_id = 1 else: raise ValueError( - "Unsupported type of lang dir (we expected it to have " + f"Unsupported type of lang dir (we expected it to have " f"'lang_bpe' or 'lang_phone' in its name): {params.lang_dir}" ) diff --git a/egs/librispeech/ASR/conformer_ctc2/transformer.py b/egs/librispeech/ASR/conformer_ctc2/transformer.py index d3443dc94..3ef7edc23 100644 --- a/egs/librispeech/ASR/conformer_ctc2/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc2/transformer.py @@ -21,17 +21,19 @@ from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn -from attention import MultiheadAttention from label_smoothing import LabelSmoothingLoss +from subsampling import Conv2dSubsampling +from attention import MultiheadAttention +from torch.nn.utils.rnn import pad_sequence + from scaling import ( ActivationBalancer, BasicNorm, DoubleSwish, - ScaledEmbedding, ScaledLinear, + ScaledEmbedding, ) -from subsampling import Conv2dSubsampling -from torch.nn.utils.rnn import pad_sequence + # Note: TorchScript requires Dict/List/etc. to be fully typed. Supervisions = Dict[str, torch.Tensor] @@ -208,7 +210,9 @@ class Transformer(nn.Module): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) mask = mask.to(x.device) if mask is not None else None - x = self.encoder(x, src_key_padding_mask=mask, warmup=warmup) # (T, N, C) + x = self.encoder( + x, src_key_padding_mask=mask, warmup=warmup + ) # (T, N, C) return x, mask @@ -257,17 +261,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -328,17 +338,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -943,7 +959,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -964,7 +982,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conformer_mmi/conformer.py b/egs/librispeech/ASR/conformer_mmi/conformer.py index 4d9ddaea9..97c8d83a2 100644 --- a/egs/librispeech/ASR/conformer_mmi/conformer.py +++ b/egs/librispeech/ASR/conformer_mmi/conformer.py @@ -156,7 +156,9 @@ class ConformerEncoderLayer(nn.Module): normalize_before: bool = True, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -174,14 +176,18 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -215,7 +221,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -334,7 +342,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -350,7 +360,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -620,9 +632,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -690,25 +702,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -745,7 +765,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -757,7 +779,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -791,9 +815,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -816,7 +844,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/conformer_mmi/decode.py b/egs/librispeech/ASR/conformer_mmi/decode.py index e8390ded9..fc9861489 100755 --- a/egs/librispeech/ASR/conformer_mmi/decode.py +++ b/egs/librispeech/ASR/conformer_mmi/decode.py @@ -60,19 +60,16 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -481,7 +478,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -513,7 +512,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -652,7 +653,9 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) return model.to(device) @@ -684,7 +687,9 @@ def main(): eos_id=eos_id, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/conformer_mmi/subsampling.py b/egs/librispeech/ASR/conformer_mmi/subsampling.py index ad9415987..5c3e1222e 100644 --- a/egs/librispeech/ASR/conformer_mmi/subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/subsampling.py @@ -25,9 +25,13 @@ class Conv2dSubsampling(nn.Module): assert idim >= 7 super().__init__() self.conv = nn.Sequential( - nn.Conv2d(in_channels=1, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), - nn.Conv2d(in_channels=odim, out_channels=odim, kernel_size=3, stride=2), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), nn.ReLU(), ) self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) @@ -111,13 +115,17 @@ class VggSubsampling(nn.Module): ) ) layers.append( - torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0, ceil_mode=True) + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) ) cur_channels = block_dim self.layers = nn.Sequential(*layers) - self.out = nn.Linear(block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim) + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. diff --git a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py index d0bb017dd..937845d77 100755 --- a/egs/librispeech/ASR/conformer_mmi/test_subsampling.py +++ b/egs/librispeech/ASR/conformer_mmi/test_subsampling.py @@ -1,7 +1,8 @@ #!/usr/bin/env python3 +from subsampling import Conv2dSubsampling +from subsampling import VggSubsampling import torch -from subsampling import Conv2dSubsampling, VggSubsampling def test_conv2d_subsampling(): diff --git a/egs/librispeech/ASR/conformer_mmi/test_transformer.py b/egs/librispeech/ASR/conformer_mmi/test_transformer.py index 25d18076d..08e680607 100644 --- a/egs/librispeech/ASR/conformer_mmi/test_transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/test_transformer.py @@ -1,16 +1,17 @@ #!/usr/bin/env python3 import torch -from torch.nn.utils.rnn import pad_sequence from transformer import ( Transformer, - add_eos, - add_sos, - decoder_padding_mask, encoder_padding_mask, generate_square_subsequent_mask, + decoder_padding_mask, + add_sos, + add_eos, ) +from torch.nn.utils.rnn import pad_sequence + def test_encoder_padding_mask(): supervisions = { diff --git a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py index f8c94cff9..011dadd73 100755 --- a/egs/librispeech/ASR/conformer_mmi/train-with-attention.py +++ b/egs/librispeech/ASR/conformer_mmi/train-with-attention.py @@ -36,14 +36,23 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments +from icefall.ali import ( + convert_alignments_to_tensor, + load_alignments, + lookup_alignments, +) from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + encode_supervisions, + setup_logger, + str2bool, +) def get_parser(): @@ -361,7 +370,10 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: + if ( + params.batch_idx_train > params.use_ali_until + and params.beam_size < 8 + ): # logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -750,14 +762,19 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): train_dl.sampler.set_epoch(epoch) - if params.batch_idx_train >= params.use_ali_until and train_ali is not None: + if ( + params.batch_idx_train >= params.use_ali_until + and train_ali is not None + ): # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/train.py b/egs/librispeech/ASR/conformer_mmi/train.py index 5cfb2bfc7..9a5bdcce2 100755 --- a/egs/librispeech/ASR/conformer_mmi/train.py +++ b/egs/librispeech/ASR/conformer_mmi/train.py @@ -36,14 +36,23 @@ from torch.nn.utils import clip_grad_norm_ from torch.utils.tensorboard import SummaryWriter from transformer import Noam -from icefall.ali import convert_alignments_to_tensor, load_alignments, lookup_alignments +from icefall.ali import ( + convert_alignments_to_tensor, + load_alignments, + lookup_alignments, +) from icefall.checkpoint import load_checkpoint from icefall.checkpoint import save_checkpoint as save_checkpoint_impl from icefall.dist import cleanup_dist, setup_dist from icefall.lexicon import Lexicon from icefall.mmi import LFMMILoss from icefall.mmi_graph_compiler import MmiTrainingGraphCompiler -from icefall.utils import AttributeDict, encode_supervisions, setup_logger, str2bool +from icefall.utils import ( + AttributeDict, + encode_supervisions, + setup_logger, + str2bool, +) def get_parser(): @@ -368,7 +377,10 @@ def compute_loss( nnet_output = nnet_output.clone() nnet_output[:, :min_len, :] += ali_scale * mask[:, :min_len, :] - if params.batch_idx_train > params.use_ali_until and params.beam_size < 8: + if ( + params.batch_idx_train > params.use_ali_until + and params.beam_size < 8 + ): logging.info("Change beam size to 8") params.beam_size = 8 else: @@ -758,14 +770,19 @@ def run(rank, world_size, args): for epoch in range(params.start_epoch, params.num_epochs): fix_random_seed(params.seed + epoch) train_dl.sampler.set_epoch(epoch) - if params.batch_idx_train >= params.use_ali_until and train_ali is not None: + if ( + params.batch_idx_train >= params.use_ali_until + and train_ali is not None + ): # Delete the alignments to save memory train_ali = None valid_ali = None cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/conformer_mmi/transformer.py b/egs/librispeech/ASR/conformer_mmi/transformer.py index 2542d9abe..68a4ff65c 100644 --- a/egs/librispeech/ASR/conformer_mmi/transformer.py +++ b/egs/librispeech/ASR/conformer_mmi/transformer.py @@ -148,7 +148,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss(self.decoder_num_class) else: @@ -180,7 +182,9 @@ class Transformer(nn.Module): x = x.permute(0, 2, 1) # (N, T, C) -> (N, C, T) x = self.feat_batchnorm(x) x = x.permute(0, 2, 1) # (N, C, T) -> (N, T, C) - encoder_memory, memory_key_padding_mask = self.run_encoder(x, supervision) + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) x = self.ctc_output(encoder_memory) return x, encoder_memory, memory_key_padding_mask @@ -270,7 +274,9 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -335,7 +341,9 @@ class Transformer(nn.Module): ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -608,7 +616,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -877,7 +887,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -898,7 +910,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py index a1c43f7f5..620d69a19 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/decode.py @@ -135,24 +135,20 @@ def get_parser(): "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -219,7 +215,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -302,7 +301,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -348,7 +350,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -421,7 +427,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -454,7 +462,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -497,7 +506,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -529,12 +540,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -557,12 +569,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -590,7 +603,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py index 0639ba746..8ca7d5568 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/emformer.py @@ -35,6 +35,7 @@ from scaling import ( from icefall.utils import make_pad_mask + LOG_EPSILON = math.log(1e-10) @@ -126,7 +127,9 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) + attn_caches[li][si] = torch.stack( + attn_caches[li][si], dim=1 + ) conv_caches = [] for layer in state_list[0][1]: @@ -265,7 +268,9 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) + first = torch.arange( + self.chunk_length, self.chunk_length + self.cache_size + ) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -279,7 +284,9 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: + def _merge_right_context( + self, right_context: torch.Tensor, B: int + ) -> torch.Tensor: """ Args: right_context: @@ -330,8 +337,12 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) - pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) + cache = torch.zeros( + B, D, self.cache_size, device=x.device, dtype=x.dtype + ) + pad_utterance = torch.cat( + [cache, utterance], dim=2 + ) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -344,7 +355,9 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context(right_context, B) # (B, D, R) + right_context = self._merge_right_context( + right_context, B + ) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -445,7 +458,8 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" + f"nhead ({nhead})." ) self.embed_dim = embed_dim @@ -455,7 +469,9 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) + self.emb_to_key_value = ScaledLinear( + embed_dim, 2 * embed_dim, bias=True + ) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -497,7 +513,9 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) + attention_weights_float = attention_weights_float.view( + B, self.nhead, Q, -1 + ) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -533,7 +551,9 @@ class EmformerAttention(nn.Module): scaling = float(self.head_dim) ** -0.5 # compute query with [right_context, utterance, summary]. - query = self.emb_to_query(torch.cat([right_context, utterance, summary])) + query = self.emb_to_query( + torch.cat([right_context, utterance, summary]) + ) # compute key and value with [memory, right_context, utterance]. key, value = self.emb_to_key_value( torch.cat([memory, right_context, utterance]) @@ -544,12 +564,16 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) + value = torch.cat( + [value[: M + R], left_context_val, value[M + R :]] + ) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) + tensor.contiguous() + .view(-1, B * self.nhead, self.head_dim) + .transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -564,7 +588,9 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + ) # apply output projection outputs = self.out_proj(attention) @@ -646,7 +672,12 @@ class EmformerAttention(nn.Module): - output of right context and utterance, with shape (R + U, B, D). - memory output, with shape (M, B, D), where M = S - 1 or M = 0. """ - (output_right_context_utterance, output_memory, _, _,) = self._forward_impl( + ( + output_right_context_utterance, + output_memory, + _, + _, + ) = self._forward_impl( utterance, right_context, summary, @@ -916,9 +947,13 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) else: - summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + summary = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) output_right_context_utterance, output_memory = self.attention( utterance=utterance, right_context=right_context, @@ -957,10 +992,14 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - summary = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1) + summary = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + ) summary = summary[:1] else: - summary = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + summary = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) ( output_right_context_utterance, output_memory, @@ -975,7 +1014,9 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) + attn_cache = self._update_attn_cache( + next_key, next_val, memory, attn_cache + ) return output_right_context_utterance, output_memory, attn_cache def forward( @@ -1110,7 +1151,11 @@ class EmformerEncoderLayer(nn.Module): src = src + self.dropout(self.feed_forward_macaron(src)) # emformer attention module - (src_att, output_memory, attn_cache,) = self._apply_attention_module_infer( + ( + src_att, + output_memory, + attn_cache, + ) = self._apply_attention_module_infer( src, R, memory, attn_cache, padding_mask=padding_mask ) src = src + self.dropout(src_att) @@ -1250,7 +1295,9 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) + num_chunks = math.ceil( + (T - self.right_context_length) / self.chunk_length + ) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1269,7 +1316,9 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: + def _gen_attention_mask_col_widths( + self, chunk_idx: int, U: int + ) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1430,7 +1479,9 @@ class EmformerEncoder(nn.Module): output_lengths = torch.clamp(lengths - self.right_context_length, min=0) attention_mask = self._gen_attention_mask(utterance) memory = ( - self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[:-1] + self.init_memory_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ + :-1 + ] if self.use_memory else torch.empty(0).to(dtype=x.dtype, device=x.device) ) @@ -1592,8 +1643,12 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), ] for _ in range(self.num_encoder_layers) ] @@ -1638,11 +1693,17 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if left_context_length != 0 and left_context_length % subsampling_factor != 0: + if ( + left_context_length != 0 + and left_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if right_context_length != 0 and right_context_length % subsampling_factor != 0: + if ( + right_context_length != 0 + and right_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1705,7 +1766,9 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) + output, output_lengths = self.encoder( + x, x_lens, warmup=warmup + ) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py index 59105e286..4930881ea 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/export.py @@ -103,11 +103,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -138,20 +136,19 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) add_model_arguments(parser) @@ -184,12 +181,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,12 +210,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -245,7 +244,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -280,7 +279,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py index c211b215e..9494e1fc1 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/stream.py @@ -68,12 +68,14 @@ class Stream(object): elif params.decoding_method == "fast_beam_search": # feature_len is needed to get partial results. # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) ) self.hyp: Optional[List[int]] = None else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) self.ground_truth: str = "" diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py index abe83732a..61dbe8658 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/streaming_decode.py @@ -113,9 +113,8 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( @@ -132,24 +131,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -216,7 +211,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -375,7 +371,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -392,7 +390,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,10 +551,14 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor(num_processed_frames_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 + tail_length = ( + 3 * params.subsampling_factor + params.right_context_length + 3 + ) if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -601,7 +605,9 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) # Update cached states of each stream state_list = unstack_states(states) @@ -776,7 +782,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -824,7 +831,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -858,12 +867,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -886,12 +896,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -919,7 +930,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py index a76417e5f..c07d8f76b 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless/train.py @@ -95,7 +95,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -263,45 +265,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -637,7 +636,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,16 +668,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -861,7 +871,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -969,7 +981,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py index 9cb4a5afc..98b8290b5 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/decode.py @@ -135,24 +135,20 @@ def get_parser(): "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -219,7 +215,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -302,7 +301,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -348,7 +350,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -421,7 +427,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -454,7 +462,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -497,7 +506,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -529,12 +540,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -557,12 +569,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -590,7 +603,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py index 09200f2e1..f16f5acc7 100644 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/emformer.py @@ -35,6 +35,7 @@ from scaling import ( from icefall.utils import make_pad_mask + LOG_EPSILON = math.log(1e-10) @@ -126,7 +127,9 @@ def stack_states( for si, s in enumerate(layer): attn_caches[li][si].append(s) if b == batch_size - 1: - attn_caches[li][si] = torch.stack(attn_caches[li][si], dim=1) + attn_caches[li][si] = torch.stack( + attn_caches[li][si], dim=1 + ) conv_caches = [] for layer in state_list[0][1]: @@ -265,7 +268,9 @@ class ConvolutionModule(nn.Module): intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length ) - first = torch.arange(self.chunk_length, self.chunk_length + self.cache_size) + first = torch.arange( + self.chunk_length, self.chunk_length + self.cache_size + ) indexes = intervals.unsqueeze(1) + first.unsqueeze(0) indexes = torch.cat( [indexes, torch.arange(U_ - self.cache_size, U_).unsqueeze(0)] @@ -279,7 +284,9 @@ class ConvolutionModule(nn.Module): # (num_chunks * B, cache_size + right_context_length, D) return pad_right_context.permute(0, 2, 1) - def _merge_right_context(self, right_context: torch.Tensor, B: int) -> torch.Tensor: + def _merge_right_context( + self, right_context: torch.Tensor, B: int + ) -> torch.Tensor: """ Args: right_context: @@ -330,8 +337,12 @@ class ConvolutionModule(nn.Module): right_context = x[:, :, :R] # (B, D, R) # make causal convolution - cache = torch.zeros(B, D, self.cache_size, device=x.device, dtype=x.dtype) - pad_utterance = torch.cat([cache, utterance], dim=2) # (B, D, cache + U) + cache = torch.zeros( + B, D, self.cache_size, device=x.device, dtype=x.dtype + ) + pad_utterance = torch.cat( + [cache, utterance], dim=2 + ) # (B, D, cache + U) # depth-wise conv on utterance utterance = self.depthwise_conv(pad_utterance) # (B, D, U) @@ -344,7 +355,9 @@ class ConvolutionModule(nn.Module): right_context = self.depthwise_conv( pad_right_context ) # (num_segs * B, D, right_context_length) - right_context = self._merge_right_context(right_context, B) # (B, D, R) + right_context = self._merge_right_context( + right_context, B + ) # (B, D, R) x = torch.cat([right_context, utterance], dim=2) # (B, D, R + U) x = self.deriv_balancer2(x) @@ -445,7 +458,8 @@ class EmformerAttention(nn.Module): if embed_dim % nhead != 0: raise ValueError( - f"embed_dim ({embed_dim}) is not a multiple ofnhead ({nhead})." + f"embed_dim ({embed_dim}) is not a multiple of" + f"nhead ({nhead})." ) self.embed_dim = embed_dim @@ -455,7 +469,9 @@ class EmformerAttention(nn.Module): self.head_dim = embed_dim // nhead self.dropout = dropout - self.emb_to_key_value = ScaledLinear(embed_dim, 2 * embed_dim, bias=True) + self.emb_to_key_value = ScaledLinear( + embed_dim, 2 * embed_dim, bias=True + ) self.emb_to_query = ScaledLinear(embed_dim, embed_dim, bias=True) self.out_proj = ScaledLinear( embed_dim, embed_dim, bias=True, initial_scale=0.25 @@ -497,7 +513,9 @@ class EmformerAttention(nn.Module): if padding_mask is not None: Q = attention_weights.size(1) B = attention_weights.size(0) // self.nhead - attention_weights_float = attention_weights_float.view(B, self.nhead, Q, -1) + attention_weights_float = attention_weights_float.view( + B, self.nhead, Q, -1 + ) attention_weights_float = attention_weights_float.masked_fill( padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), self.negative_inf, @@ -543,12 +561,16 @@ class EmformerAttention(nn.Module): # [memory, right context, left context, uttrance] # this is used in inference mode key = torch.cat([key[: M + R], left_context_key, key[M + R :]]) - value = torch.cat([value[: M + R], left_context_val, value[M + R :]]) + value = torch.cat( + [value[: M + R], left_context_val, value[M + R :]] + ) Q = query.size(0) # KV = key.size(0) reshaped_query, reshaped_key, reshaped_value = [ - tensor.contiguous().view(-1, B * self.nhead, self.head_dim).transpose(0, 1) + tensor.contiguous() + .view(-1, B * self.nhead, self.head_dim) + .transpose(0, 1) for tensor in [query, key, value] ] # (B * nhead, Q or KV, head_dim) attention_weights = torch.bmm( @@ -563,7 +585,9 @@ class EmformerAttention(nn.Module): # compute attention outputs attention = torch.bmm(attention_probs, reshaped_value) assert attention.shape == (B * self.nhead, Q, self.head_dim) - attention = attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + attention = ( + attention.transpose(0, 1).contiguous().view(Q, B, self.embed_dim) + ) # apply output projection output_right_context_utterance = self.out_proj(attention) @@ -881,11 +905,13 @@ class EmformerEncoderLayer(nn.Module): right_context = right_context_utterance[:R] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :-1, :, : - ] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + )[:-1, :, :] else: - memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) + memory = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) output_right_context_utterance = self.attention( utterance=utterance, right_context=right_context, @@ -922,12 +948,18 @@ class EmformerEncoderLayer(nn.Module): left_context_val = attn_cache[2] if self.use_memory: - memory = self.summary_op(utterance.permute(1, 2, 0)).permute(2, 0, 1)[ - :1, :, : - ] + memory = self.summary_op(utterance.permute(1, 2, 0)).permute( + 2, 0, 1 + )[:1, :, :] else: - memory = torch.empty(0).to(dtype=utterance.dtype, device=utterance.device) - (output_right_context_utterance, next_key, next_val,) = self.attention.infer( + memory = torch.empty(0).to( + dtype=utterance.dtype, device=utterance.device + ) + ( + output_right_context_utterance, + next_key, + next_val, + ) = self.attention.infer( utterance=utterance, right_context=right_context, memory=pre_memory, @@ -935,7 +967,9 @@ class EmformerEncoderLayer(nn.Module): left_context_val=left_context_val, padding_mask=padding_mask, ) - attn_cache = self._update_attn_cache(next_key, next_val, memory, attn_cache) + attn_cache = self._update_attn_cache( + next_key, next_val, memory, attn_cache + ) return output_right_context_utterance, attn_cache def forward( @@ -1192,7 +1226,9 @@ class EmformerEncoder(nn.Module): def _gen_right_context(self, x: torch.Tensor) -> torch.Tensor: """Hard copy each chunk's right context and concat them.""" T = x.shape[0] - num_chunks = math.ceil((T - self.right_context_length) / self.chunk_length) + num_chunks = math.ceil( + (T - self.right_context_length) / self.chunk_length + ) # first (num_chunks - 1) right context block intervals = torch.arange( 0, self.chunk_length * (num_chunks - 1), self.chunk_length @@ -1211,7 +1247,9 @@ class EmformerEncoder(nn.Module): right_context_blocks = x[indexes.reshape(-1)] return right_context_blocks - def _gen_attention_mask_col_widths(self, chunk_idx: int, U: int) -> List[int]: + def _gen_attention_mask_col_widths( + self, chunk_idx: int, U: int + ) -> List[int]: """Calculate column widths (key, value) in attention mask for the chunk_idx chunk.""" num_chunks = math.ceil(U / self.chunk_length) @@ -1511,8 +1549,12 @@ class EmformerEncoder(nn.Module): attn_caches = [ [ torch.zeros(self.memory_size, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), - torch.zeros(self.left_context_length, self.d_model, device=device), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), + torch.zeros( + self.left_context_length, self.d_model, device=device + ), ] for _ in range(self.num_encoder_layers) ] @@ -1557,11 +1599,17 @@ class Emformer(EncoderInterface): raise NotImplementedError( "chunk_length must be a mutiple of subsampling_factor." ) - if left_context_length != 0 and left_context_length % subsampling_factor != 0: + if ( + left_context_length != 0 + and left_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "left_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) - if right_context_length != 0 and right_context_length % subsampling_factor != 0: + if ( + right_context_length != 0 + and right_context_length % subsampling_factor != 0 + ): raise NotImplementedError( "right_context_length must be 0 or a mutiple of subsampling_factor." # noqa ) @@ -1624,7 +1672,9 @@ class Emformer(EncoderInterface): x_lens = (((x_lens - 1) >> 1) - 1) >> 1 assert x.size(0) == x_lens.max().item() - output, output_lengths = self.encoder(x, x_lens, warmup=warmup) # (T, N, C) + output, output_lengths = self.encoder( + x, x_lens, warmup=warmup + ) # (T, N, C) output = output.permute(1, 0, 2) # (T, N, C) -> (N, T, C) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py index 4d05b367c..ab15e0241 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/export.py @@ -103,11 +103,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -138,20 +136,19 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) add_model_arguments(parser) @@ -184,12 +181,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,12 +210,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -245,7 +244,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -280,7 +279,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py index 0486ac2eb..71150392d 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/streaming_decode.py @@ -113,9 +113,8 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( @@ -132,24 +131,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -216,7 +211,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -375,7 +371,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -392,7 +390,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -551,10 +551,14 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor(num_processed_frames_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) # Make sure it has at least 1 frame after subsampling, first-and-last-frame cutting, and right context cutting # noqa - tail_length = 3 * params.subsampling_factor + params.right_context_length + 3 + tail_length = ( + 3 * params.subsampling_factor + params.right_context_length + 3 + ) if features.size(1) < tail_length: pad_length = tail_length - features.size(1) feature_lens += pad_length @@ -601,7 +605,9 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) # Update cached states of each stream state_list = unstack_states(states) @@ -776,7 +782,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -824,7 +831,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -858,12 +867,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -886,12 +896,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -919,7 +930,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py index 2c2593b56..2bbc45d78 100755 --- a/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/conv_emformer_transducer_stateless2/train.py @@ -95,7 +95,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -263,45 +265,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -637,7 +636,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -665,16 +668,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -861,7 +871,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -969,7 +981,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/local/add_alignment_librispeech.py b/egs/librispeech/ASR/local/add_alignment_librispeech.py index cc34a72d8..fe6a26c51 100755 --- a/egs/librispeech/ASR/local/add_alignment_librispeech.py +++ b/egs/librispeech/ASR/local/add_alignment_librispeech.py @@ -157,7 +157,9 @@ def add_alignment( for ali_path in part_ali_dir.rglob("*.alignment.txt"): ali = parse_alignments(ali_path) alignments.update(ali) - logging.info(f"{part} has {len(alignments.keys())} cuts with alignments.") + logging.info( + f"{part} has {len(alignments.keys())} cuts with alignments." + ) # add alignment attribute and write out cuts_in = load_manifest_lazy(cuts_in_path) @@ -168,14 +170,18 @@ def add_alignment( if origin_id in alignments: ali = alignments[origin_id] else: - logging.info(f"Warning: {origin_id} does not have alignment.") + logging.info( + f"Warning: {origin_id} does not have alignment." + ) ali = [] subcut.alignment = {"word": ali} writer.write(cut, flush=True) def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) parser = get_parser() diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index df6c609bb..c628dfd53 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -57,7 +57,7 @@ def get_args(): return parser.parse_args() -def compile_HLG(lang_dir: str, lm: str = "G_3_gram") -> k2.Fsa: +def compile_HLG(lang_dir: str, lm: str="G_3_gram") -> k2.Fsa: """ Args: lang_dir: @@ -159,7 +159,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compile_lg.py b/egs/librispeech/ASR/local/compile_lg.py index 19bf3bff4..45c4b7f5f 100755 --- a/egs/librispeech/ASR/local/compile_lg.py +++ b/egs/librispeech/ASR/local/compile_lg.py @@ -132,7 +132,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py index 97750f3ea..c0c7ef8c5 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_dev_test.py @@ -80,7 +80,9 @@ def compute_fbank_gigaspeech_dev_test(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_gigaspeech_dev_test() diff --git a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py index 37fce11f4..5587106e5 100644 --- a/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py +++ b/egs/librispeech/ASR/local/compute_fbank_gigaspeech_splits.py @@ -48,10 +48,8 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help=( - "The maximum number of audio seconds in a batch." - "Determines batch size dynamically." - ), + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", ) parser.add_argument( @@ -146,7 +144,9 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_gigaspeech_splits" - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 9f8503814..ce7d087f0 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -112,7 +112,9 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -126,7 +128,9 @@ def compute_fbank_librispeech(bpe_model: Optional[str] = None): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index 4a4093ae4..056da29e5 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -83,7 +83,9 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) + recordings=combine( + part["recordings"] for part in manifests.values() + ) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -99,7 +101,9 @@ def compute_fbank_musan(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py index f149b7871..133499c8b 100755 --- a/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py +++ b/egs/librispeech/ASR/local/convert_transcript_words_to_tokens.py @@ -46,19 +46,21 @@ def get_args(): parser.add_argument( "--transcript", type=str, - help=( - "The input transcript file." - "We assume that the transcript file consists of " - "lines. Each line consists of space separated words." - ), + help="The input transcript file." + "We assume that the transcript file consists of " + "lines. Each line consists of space separated words.", ) parser.add_argument("--lexicon", type=str, help="The input lexicon file.") - parser.add_argument("--oov", type=str, default="
", help="The OOV word.") + parser.add_argument( + "--oov", type=str, default=" ", help="The OOV word." + ) return parser.parse_args() -def process_line(lexicon: Dict[str, List[str]], line: str, oov_token: str) -> None: +def process_line( + lexicon: Dict[str, List[str]], line: str, oov_token: str +) -> None: """ Args: lexicon: diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 3518db524..030122aa7 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -87,7 +87,9 @@ def main(out_dir: str): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/filter_cuts.py b/egs/librispeech/ASR/local/filter_cuts.py index fbcc9e24a..dff98a954 100644 --- a/egs/librispeech/ASR/local/filter_cuts.py +++ b/egs/librispeech/ASR/local/filter_cuts.py @@ -79,7 +79,8 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): total += 1 if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) removed += 1 return False @@ -124,7 +125,8 @@ def filter_cuts(cut_set: CutSet, sp: spm.SentencePieceProcessor): ans = cut_set.filter(remove_short_and_long_utterances).to_eager() ratio = removed / total * 100 logging.info( - f"Removed {removed} cuts from {total} cuts. {ratio:.3f}% data is removed." + f"Removed {removed} cuts from {total} cuts. " + f"{ratio:.3f}% data is removed." ) return ans @@ -153,7 +155,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/generate_unique_lexicon.py b/egs/librispeech/ASR/local/generate_unique_lexicon.py index 3459c2f5a..566c0743d 100755 --- a/egs/librispeech/ASR/local/generate_unique_lexicon.py +++ b/egs/librispeech/ASR/local/generate_unique_lexicon.py @@ -91,7 +91,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index e121aefa9..dec8a7442 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -150,7 +150,9 @@ def generate_lexicon( words_pieces_ids: List[List[int]] = sp.encode(words, out_type=int) # Now convert word piece IDs back to word piece strings. - words_pieces: List[List[str]] = [sp.id_to_piece(ids) for ids in words_pieces_ids] + words_pieces: List[List[str]] = [ + sp.id_to_piece(ids) for ids in words_pieces_ids + ] lexicon = [] for word, pieces in zip(words, words_pieces): diff --git a/egs/librispeech/ASR/local/prepare_lm_training_data.py b/egs/librispeech/ASR/local/prepare_lm_training_data.py index 70343fef7..5070341f1 100755 --- a/egs/librispeech/ASR/local/prepare_lm_training_data.py +++ b/egs/librispeech/ASR/local/prepare_lm_training_data.py @@ -137,7 +137,8 @@ def main(): for i in range(num_sentences): if step and i % step == 0: logging.info( - f"Processed number of lines: {i} ({i/num_sentences*100: .3f}%)" + f"Processed number of lines: {i} " + f"({i/num_sentences*100: .3f}%)" ) word_ids = sentences[i] @@ -153,14 +154,18 @@ def main(): sentence_lengths[i] = token_ids.numel() - output["sentence_lengths"] = torch.tensor(sentence_lengths, dtype=torch.int32) + output["sentence_lengths"] = torch.tensor( + sentence_lengths, dtype=torch.int32 + ) torch.save(output, args.lm_archive) logging.info(f"Saved to {args.lm_archive}") if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/local/preprocess_gigaspeech.py b/egs/librispeech/ASR/local/preprocess_gigaspeech.py index 8aa5e461d..077f23039 100644 --- a/egs/librispeech/ASR/local/preprocess_gigaspeech.py +++ b/egs/librispeech/ASR/local/preprocess_gigaspeech.py @@ -119,7 +119,9 @@ def preprocess_giga_speech(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) preprocess_giga_speech() diff --git a/egs/librispeech/ASR/local/test_prepare_lang.py b/egs/librispeech/ASR/local/test_prepare_lang.py index 74e025ad7..d4cf62bba 100755 --- a/egs/librispeech/ASR/local/test_prepare_lang.py +++ b/egs/librispeech/ASR/local/test_prepare_lang.py @@ -88,7 +88,9 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) + fsa_disambig = lexicon_to_fst( + lexicon_disambig, phone2id=phone2id, word2id=word2id + ) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/librispeech/ASR/local/validate_manifest.py b/egs/librispeech/ASR/local/validate_manifest.py index 807aaf891..7c57d629a 100755 --- a/egs/librispeech/ASR/local/validate_manifest.py +++ b/egs/librispeech/ASR/local/validate_manifest.py @@ -64,7 +64,8 @@ def validate_supervision_and_cut_time_bounds(c: Cut): if s.end > c.end: raise ValueError( - f"{c.id}: Supervision end time {s.end} is larger than cut end time {c.end}" + f"{c.id}: Supervision end time {s.end} is larger " + f"than cut end time {c.end}" ) @@ -84,7 +85,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py old mode 100644 new mode 100755 index e69de29bb..27414d717 --- a/egs/librispeech/ASR/lstm_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/decode.py @@ -0,0 +1,818 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method greedy_search + +(2) beam search (not recommended) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method beam_search \ + --beam-size 4 + +(3) modified beam search +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method modified_beam_search \ + --beam-size 4 + +(4) fast beam search (one best) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 + +(5) fast beam search (nbest) +./lstm_transducer_stateless/decode.py \ + --epoch 30 \ + --avg 15 \ + --exp-dir ./pruned_transducer_stateless3/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(6) fast beam search (nbest oracle WER) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_oracle \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 \ + --num-paths 200 \ + --nbest-scale 0.5 + +(7) fast beam search (with LG) +./lstm_transducer_stateless/decode.py \ + --epoch 35 \ + --avg 15 \ + --exp-dir ./lstm_transducer_stateless/exp \ + --max-duration 600 \ + --decoding-method fast_beam_search_nbest_LG \ + --beam 20.0 \ + --max-contexts 8 \ + --max-states 64 +""" + + +import argparse +import logging +import math +from collections import defaultdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import ( + beam_search, + fast_beam_search_nbest, + fast_beam_search_nbest_LG, + fast_beam_search_nbest_oracle, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.lexicon import Lexicon +from icefall.utils import ( + AttributeDict, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + +LOG_EPS = math.log(1e-10) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=30, + help="""It specifies the checkpoint to use for decoding. + Note: Epoch counts from 1. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--lang-dir", + type=Path, + default="data/lang_bpe_500", + help="The lang dir containing word table and LG graph", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + - fast_beam_search_nbest + - fast_beam_search_nbest_oracle + - fast_beam_search_nbest_LG + If you use fast_beam_search_nbest_LG, you have to specify + `--lang-dir`, which should contain `LG.pt`. + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search, + fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle + """, + ) + + parser.add_argument( + "--ngram-lm-scale", + type=float, + default=0.01, + help=""" + Used only when --decoding_method is fast_beam_search_nbest_LG. + It specifies the scale for n-gram LM scores. + """, + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search, fast_beam_search_nbest, fast_beam_search_nbest_LG, + and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--num-paths", + type=int, + default=200, + help="""Number of paths for nbest decoding. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + parser.add_argument( + "--nbest-scale", + type=float, + default=0.5, + help="""Scale applied to lattice scores when computing nbest paths. + Used only when the decoding method is fast_beam_search_nbest, + fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""", + ) + + add_model_arguments(parser) + + return parser + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + batch: dict, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[List[str]]]: + """Decode one batch and return the result in a dict. The dict has the + following format: + + - key: It indicates the setting used for decoding. For example, + if greedy_search is used, it would be "greedy_search" + If beam search with a beam size of 7 is used, it would be + "beam_7" + - value: It contains the decoding result. `len(value)` equals to + batch size. `value[i]` is the decoding result for the i-th + utterance in the given batch. + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return the decoding result. See above description for the format of + the returned dict. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + # tail padding here to alleviate the tail deletion problem + num_tail_padded_frames = 35 + feature = torch.nn.functional.pad( + feature, + (0, 0, 0, num_tail_padded_frames), + mode="constant", + value=LOG_EPS, + ) + feature_lens += num_tail_padded_frames + + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) + + hyps = [] + + if params.decoding_method == "fast_beam_search": + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_LG": + hyp_tokens = fast_beam_search_nbest_LG( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in hyp_tokens: + hyps.append([word_table[i] for i in hyp]) + elif params.decoding_method == "fast_beam_search_nbest": + hyp_tokens = fast_beam_search_nbest( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "fast_beam_search_nbest_oracle": + hyp_tokens = fast_beam_search_nbest_oracle( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + num_paths=params.num_paths, + ref_texts=sp.encode(supervisions["text"]), + nbest_scale=params.nbest_scale, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.decoding_method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + batch_size = encoder_out.size(0) + + for i in range(batch_size): + # fmt: off + encoder_out_i = encoder_out[i:i + 1, :encoder_out_lens[i]] + # fmt: on + if params.decoding_method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.decoding_method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + hyps.append(sp.decode(hyp).split()) + + if params.decoding_method == "greedy_search": + return {"greedy_search": hyps} + elif "fast_beam_search" in params.decoding_method: + key = f"beam_{params.beam}_" + key += f"max_contexts_{params.max_contexts}_" + key += f"max_states_{params.max_states}" + if "nbest" in params.decoding_method: + key += f"_num_paths_{params.num_paths}_" + key += f"nbest_scale_{params.nbest_scale}" + if "LG" in params.decoding_method: + key += f"_ngram_lm_scale_{params.ngram_lm_scale}" + + return {key: hyps} + else: + return {f"beam_size_{params.beam_size}": hyps} + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + word_table: Optional[k2.SymbolTable] = None, + decoding_graph: Optional[k2.Fsa] = None, +) -> Dict[str, List[Tuple[str, List[str], List[str]]]]: + """Decode dataset. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + word_table: + The word symbol table. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search, fast_beam_search_nbest, + fast_beam_search_nbest_oracle, and fast_beam_search_nbest_LG. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + num_cuts = 0 + + try: + num_batches = len(dl) + except TypeError: + num_batches = "?" + + if params.decoding_method == "greedy_search": + log_interval = 50 + else: + log_interval = 20 + + results = defaultdict(list) + for batch_idx, batch in enumerate(dl): + texts = batch["supervisions"]["text"] + cut_ids = [cut.id for cut in batch["supervisions"]["cut"]] + + hyps_dict = decode_one_batch( + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + word_table=word_table, + batch=batch, + ) + + for name, hyps in hyps_dict.items(): + this_batch = [] + assert len(hyps) == len(texts) + for cut_id, hyp_words, ref_text in zip(cut_ids, hyps, texts): + ref_words = ref_text.split() + this_batch.append((cut_id, ref_words, hyp_words)) + + results[name].extend(this_batch) + + num_cuts += len(texts) + + if batch_idx % log_interval == 0: + batch_str = f"{batch_idx}/{num_batches}" + + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) + return results + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[str, List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + results = sorted(results) + store_transcripts(filename=recog_path, texts=results) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "beam_search", + "fast_beam_search", + "fast_beam_search_nbest", + "fast_beam_search_nbest_LG", + "fast_beam_search_nbest_oracle", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + if "nbest" in params.decoding_method: + params.suffix += f"-nbest-scale-{params.nbest_scale}" + params.suffix += f"-num-paths-{params.num_paths}" + if "LG" in params.decoding_method: + params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-decode-{params.suffix}") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id(" ") + params.unk_id = sp.piece_to_id(" ") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to(device) + model.eval() + + if "fast_beam_search" in params.decoding_method: + if params.decoding_method == "fast_beam_search_nbest_LG": + lexicon = Lexicon(params.lang_dir) + word_table = lexicon.word_table + lg_filename = params.lang_dir / "LG.pt" + logging.info(f"Loading {lg_filename}") + decoding_graph = k2.Fsa.from_dict( + torch.load(lg_filename, map_location=device) + ) + decoding_graph.scores *= params.ngram_lm_scale + else: + word_table = None + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) + else: + decoding_graph = None + word_table = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + # we need cut ids to display recognition results. + args.return_cuts = True + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_clean_dl = librispeech.test_dataloaders(test_clean_cuts) + test_other_dl = librispeech.test_dataloaders(test_other_cuts) + + test_sets = ["test-clean", "test-other"] + test_dl = [test_clean_dl, test_other_dl] + + for test_set, test_dl in zip(test_sets, test_dl): + results_dict = decode_dataset( + dl=test_dl, + params=params, + model=model, + sp=sp, + word_table=word_table, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/export.py b/egs/librispeech/ASR/lstm_transducer_stateless/export.py old mode 100644 new mode 100755 index e69de29bb..13dac6009 --- a/egs/librispeech/ASR/lstm_transducer_stateless/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/export.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This script converts several saved checkpoints +# to a single one using model averaging. +""" + +Usage: + +(1) Export to torchscript model using torch.jit.trace() + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 35 \ + --avg 10 \ + --jit-trace 1 + +It will generate 3 files: `encoder_jit_trace.pt`, +`decoder_jit_trace.pt`, and `joiner_jit_trace.pt`. + +(2) Export `model.state_dict()` + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 35 \ + --avg 10 + +It will generate a file `pretrained.pt` in the given `exp_dir`. You can later +load it by `icefall.checkpoint.load_checkpoint()`. + +To use the generated file with `lstm_transducer_stateless/decode.py`, +you can do: + + cd /path/to/exp_dir + ln -s pretrained.pt epoch-9999.pt + + cd /path/to/egs/librispeech/ASR + ./lstm_transducer_stateless/decode.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --epoch 9999 \ + --avg 1 \ + --max-duration 600 \ + --decoding-method greedy_search \ + --bpe-model data/lang_bpe_500/bpe.model + +Check ./pretrained.py for its usage. + +Note: If you don't want to train a model from scratch, we have +provided one for you. You can get it at + +https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 + +with the following commands: + + sudo apt-get install git-lfs + git lfs install + git clone https://huggingface.co/Zengwei/icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18 + # You will find the pre-trained model in icefall-asr-librispeech-lstm-transducer-stateless-2022-08-18/exp +""" + +import argparse +import logging +from pathlib import Path + +import sentencepiece as spm +import torch +import torch.nn as nn +from scaling_converter import convert_scaled_to_non_scaled +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.utils import str2bool + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="""It specifies the checkpoint to use for averaging. + Note: Epoch counts from 0. + You can specify --avg to use more checkpoints for model averaging.""", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=True, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="pruned_transducer_stateless3/exp", + help="""It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--jit-trace", + type=str2bool, + default=False, + help="""True to save a model after applying torch.jit.trace. + It will generate 3 files: + - encoder_jit_trace.pt + - decoder_jit_trace.pt + - joiner_jit_trace.pt + + Check ./jit_pretrained.py for how to use them. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + add_model_arguments(parser) + + return parser + + +def export_encoder_model_jit_trace( + encoder_model: nn.Module, + encoder_filename: str, +) -> None: + """Export the given encoder model with torch.jit.trace() + + Note: The warmup argument is fixed to 1. + + Args: + encoder_model: + The input encoder model + encoder_filename: + The filename to save the exported model. + """ + x = torch.zeros(1, 100, 80, dtype=torch.float32) + x_lens = torch.tensor([100], dtype=torch.int64) + states = encoder_model.get_init_states() + + traced_model = torch.jit.trace(encoder_model, (x, x_lens, states)) + traced_model.save(encoder_filename) + logging.info(f"Saved to {encoder_filename}") + + +def export_decoder_model_jit_trace( + decoder_model: nn.Module, + decoder_filename: str, +) -> None: + """Export the given decoder model with torch.jit.trace() + + Note: The argument need_pad is fixed to False. + + Args: + decoder_model: + The input decoder model + decoder_filename: + The filename to save the exported model. + """ + y = torch.zeros(10, decoder_model.context_size, dtype=torch.int64) + need_pad = torch.tensor([False]) + + traced_model = torch.jit.trace(decoder_model, (y, need_pad)) + traced_model.save(decoder_filename) + logging.info(f"Saved to {decoder_filename}") + + +def export_joiner_model_jit_trace( + joiner_model: nn.Module, + joiner_filename: str, +) -> None: + """Export the given joiner model with torch.jit.trace() + + Note: The argument project_input is fixed to True. A user should not + project the encoder_out/decoder_out by himself/herself. The exported joiner + will do that for the user. + + Args: + joiner_model: + The input joiner model + joiner_filename: + The filename to save the exported model. + + """ + encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] + decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] + encoder_out = torch.rand(1, encoder_out_dim, dtype=torch.float32) + decoder_out = torch.rand(1, decoder_out_dim, dtype=torch.float32) + + traced_model = torch.jit.trace(joiner_model, (encoder_out, decoder_out)) + traced_model.save(joiner_filename) + logging.info(f"Saved to {joiner_filename}") + + +@torch.no_grad() +def main(): + args = get_parser().parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id(" ") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.to("cpu") + model.eval() + + if params.jit_trace is True: + convert_scaled_to_non_scaled(model, inplace=True) + logging.info("Using torch.jit.trace()") + encoder_filename = params.exp_dir / "encoder_jit_trace.pt" + export_encoder_model_jit_trace(model.encoder, encoder_filename) + + decoder_filename = params.exp_dir / "decoder_jit_trace.pt" + export_decoder_model_jit_trace(model.decoder, decoder_filename) + + joiner_filename = params.exp_dir / "joiner_jit_trace.pt" + export_joiner_model_jit_trace(model.joiner, joiner_filename) + else: + logging.info("Not using torchscript") + # Save it using a format so that it can be loaded + # by :func:`load_checkpoint` + filename = params.exp_dir / "pretrained.pt" + torch.save({"model": model.state_dict()}, str(filename)) + logging.info(f"Saved to {filename}") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py old mode 100644 new mode 100755 index e69de29bb..594c33e4f --- a/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/jit_pretrained.py @@ -0,0 +1,322 @@ +#!/usr/bin/env python3 +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, either exported by `torch.jit.trace()` +or by `torch.jit.script()`, and uses them to decode waves. +You can use the following command to get the exported models: + +./lstm_transducer_stateless/export.py \ + --exp-dir ./lstm_transducer_stateless/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit-trace 1 + +Usage of this script: + +./lstm_transducer_stateless/jit_pretrained.py \ + --encoder-model-filename ./lstm_transducer_stateless/exp/encoder_jit_trace.pt \ + --decoder-model-filename ./lstm_transducer_stateless/exp/decoder_jit_trace.pt \ + --joiner-model-filename ./lstm_transducer_stateless/exp/joiner_jit_trace.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + /path/to/foo.wav \ + /path/to/bar.wav +""" + +import argparse +import logging +import math +from typing import List + +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from torch.nn.utils.rnn import pad_sequence + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--encoder-model-filename", + type=str, + required=True, + help="Path to the encoder torchscript model. ", + ) + + parser.add_argument( + "--decoder-model-filename", + type=str, + required=True, + help="Path to the decoder torchscript model. ", + ) + + parser.add_argument( + "--joiner-model-filename", + type=str, + required=True, + help="Path to the joiner torchscript model. ", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="Context size of the decoder model", + ) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +def greedy_search( + decoder: torch.jit.ScriptModule, + joiner: torch.jit.ScriptModule, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + context_size: int, +) -> List[List[int]]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + decoder: + The decoder model. + joiner: + The joiner model. + encoder_out: + A 3-D tensor of shape (N, T, C) + encoder_out_lens: + A 1-D tensor of shape (N,). + context_size: + The context size of the decoder model. + Returns: + Return the decoded results for each utterance. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = encoder_out.device + blank_id = 0 # hard-code to 0 + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[blank_id] * context_size for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ).squeeze(1) + + offset = 0 + for batch_size in batch_size_list: + start = offset + end = offset + batch_size + current_encoder_out = packed_encoder_out.data[start:end] + current_encoder_out = current_encoder_out + # current_encoder_out's shape: (batch_size, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = joiner( + current_encoder_out, + decoder_out, + ) + # logits'shape (batch_size, vocab_size) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + hyps[i].append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = decoder( + decoder_input, + need_pad=torch.tensor([False]), + ) + decoder_out = decoder_out.squeeze(1) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + logging.info(vars(args)) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + encoder = torch.jit.load(args.encoder_model_filename) + decoder = torch.jit.load(args.decoder_model_filename) + joiner = torch.jit.load(args.joiner_model_filename) + + encoder.eval() + decoder.eval() + joiner.eval() + + encoder.to(device) + decoder.to(device) + joiner.to(device) + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = args.sample_rate + opts.mel_opts.num_bins = 80 + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {args.sound_files}") + waves = read_sound_files( + filenames=args.sound_files, + expected_sample_rate=args.sample_rate, + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, + batch_first=True, + padding_value=math.log(1e-10), + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + states = encoder.get_init_states(batch_size=features.size(0), device=device) + + encoder_out, encoder_out_lens, _ = encoder( + x=features, + x_lens=feature_lengths, + states=states, + ) + + hyps = greedy_search( + decoder=decoder, + joiner=joiner, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + context_size=args.context_size, + ) + s = "\n" + for filename, hyp in zip(args.sound_files, hyps): + words = sp.decode(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py index e69de29bb..c54a4c478 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/lstm.py @@ -0,0 +1,871 @@ +# Copyright 2022 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +from typing import List, Optional, Tuple + +import torch +from encoder_interface import EncoderInterface +from scaling import ( + ActivationBalancer, + BasicNorm, + DoubleSwish, + ScaledConv2d, + ScaledLinear, + ScaledLSTM, +) +from torch import nn + +LOG_EPSILON = math.log(1e-10) + + +def unstack_states( + states: Tuple[torch.Tensor, torch.Tensor] +) -> List[Tuple[torch.Tensor, torch.Tensor]]: + """ + Unstack the lstm states corresponding to a batch of utterances into a list + of states, where the i-th entry is the state from the i-th utterance. + + Args: + states: + A tuple of 2 elements. + ``states[0]`` is the lstm hidden states, of a batch of utterance. + ``states[1]`` is the lstm cell states, of a batch of utterances. + + Returns: + A list of states. + ``states[i]`` is a tuple of 2 elememts of i-th utterance. + ``states[i][0]`` is the lstm hidden states of i-th utterance. + ``states[i][1]`` is the lstm cell states of i-th utterance. + """ + hidden_states, cell_states = states + + list_hidden_states = hidden_states.unbind(dim=1) + list_cell_states = cell_states.unbind(dim=1) + + ans = [ + (h.unsqueeze(1), c.unsqueeze(1)) + for (h, c) in zip(list_hidden_states, list_cell_states) + ] + return ans + + +def stack_states( + states_list: List[Tuple[torch.Tensor, torch.Tensor]] +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Stack list of lstm states corresponding to separate utterances into a single + lstm state so that it can be used as an input for lstm when those utterances + are formed into a batch. + + Args: + state_list: + Each element in state_list corresponds to the lstm state for a single + utterance. + ``states[i]`` is a tuple of 2 elememts of i-th utterance. + ``states[i][0]`` is the lstm hidden states of i-th utterance. + ``states[i][1]`` is the lstm cell states of i-th utterance. + + + Returns: + A new state corresponding to a batch of utterances. + It is a tuple of 2 elements. + ``states[0]`` is the lstm hidden states, of a batch of utterance. + ``states[1]`` is the lstm cell states, of a batch of utterances. + """ + hidden_states = torch.cat([s[0] for s in states_list], dim=1) + cell_states = torch.cat([s[1] for s in states_list], dim=1) + ans = (hidden_states, cell_states) + return ans + + +class RNN(EncoderInterface): + """ + Args: + num_features (int): + Number of input features. + subsampling_factor (int): + Subsampling factor of encoder (convolution layers before lstm layers) (default=4). # noqa + d_model (int): + Output dimension (default=512). + dim_feedforward (int): + Feedforward dimension (default=2048). + rnn_hidden_size (int): + Hidden dimension for lstm layers (default=1024). + num_encoder_layers (int): + Number of encoder layers (default=12). + dropout (float): + Dropout rate (default=0.1). + layer_dropout (float): + Dropout value for model-level warmup (default=0.075). + aux_layer_period (int): + Period of auxiliary layers used for random combiner during training. + If set to 0, will not use the random combiner (Default). + You can set a positive integer to use the random combiner, e.g., 3. + is_pnnx: + True to make this class exportable via PNNX. + """ + + def __init__( + self, + num_features: int, + subsampling_factor: int = 4, + d_model: int = 512, + dim_feedforward: int = 2048, + rnn_hidden_size: int = 1024, + num_encoder_layers: int = 12, + dropout: float = 0.1, + layer_dropout: float = 0.075, + aux_layer_period: int = 0, + is_pnnx: bool = False, + ) -> None: + super(RNN, self).__init__() + + self.num_features = num_features + self.subsampling_factor = subsampling_factor + if subsampling_factor != 4: + raise NotImplementedError("Support only 'subsampling_factor=4'.") + + # self.encoder_embed converts the input of shape (N, T, num_features) + # to the shape (N, T//subsampling_factor, d_model). + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_features -> d_model + self.encoder_embed = Conv2dSubsampling( + num_features, + d_model, + is_pnnx=is_pnnx, + ) + + self.is_pnnx = is_pnnx + + self.num_encoder_layers = num_encoder_layers + self.d_model = d_model + self.rnn_hidden_size = rnn_hidden_size + + encoder_layer = RNNEncoderLayer( + d_model=d_model, + dim_feedforward=dim_feedforward, + rnn_hidden_size=rnn_hidden_size, + dropout=dropout, + layer_dropout=layer_dropout, + ) + self.encoder = RNNEncoder( + encoder_layer, + num_encoder_layers, + aux_layers=list( + range( + num_encoder_layers // 3, + num_encoder_layers - 1, + aux_layer_period, + ) + ) + if aux_layer_period > 0 + else None, + ) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Args: + x: + The input tensor. Its shape is (N, T, C), where N is the batch size, + T is the sequence length, C is the feature dimension. + x_lens: + A tensor of shape (N,), containing the number of frames in `x` + before padding. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + warmup: + A floating point value that gradually increases from 0 throughout + training; when it is >= 1.0 we are "fully warmed up". It is used + to turn modules on sequentially. + + Returns: + A tuple of 3 tensors: + - embeddings: its shape is (N, T', d_model), where T' is the output + sequence lengths. + - lengths: a tensor of shape (batch_size,) containing the number of + frames in `embeddings` before padding. + - updated states, whose shape is the same as the input states. + """ + x = self.encoder_embed(x) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + + # lengths = ((x_lens - 3) // 2 - 1) // 2 # issue an warning + # + # Note: rounding_mode in torch.div() is available only in torch >= 1.8.0 + if not self.is_pnnx: + lengths = (((x_lens - 3) >> 1) - 1) >> 1 + else: + lengths1 = torch.floor((x_lens - 3) / 2) + lengths = torch.floor((lengths1 - 1) / 2) + lengths = lengths.to(x_lens) + + if not torch.jit.is_tracing(): + assert x.size(0) == lengths.max().item() + + if states is None: + x = self.encoder(x, warmup=warmup)[0] + # torch.jit.trace requires returned types to be the same as annotated # noqa + new_states = (torch.empty(0), torch.empty(0)) + else: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == ( + self.num_encoder_layers, + x.size(1), + self.d_model, + ) + # for cell state + assert states[1].shape == ( + self.num_encoder_layers, + x.size(1), + self.rnn_hidden_size, + ) + x, new_states = self.encoder(x, states) + + x = x.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + return x, lengths, new_states + + @torch.jit.export + def get_init_states( + self, batch_size: int = 1, device: torch.device = torch.device("cpu") + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Get model initial states.""" + # for rnn hidden states + hidden_states = torch.zeros( + (self.num_encoder_layers, batch_size, self.d_model), device=device + ) + cell_states = torch.zeros( + (self.num_encoder_layers, batch_size, self.rnn_hidden_size), + device=device, + ) + return (hidden_states, cell_states) + + +class RNNEncoderLayer(nn.Module): + """ + RNNEncoderLayer is made up of lstm and feedforward networks. + + Args: + d_model: + The number of expected features in the input (required). + dim_feedforward: + The dimension of feedforward network model (default=2048). + rnn_hidden_size: + The hidden dimension of rnn layer. + dropout: + The dropout value (default=0.1). + layer_dropout: + The dropout value for model-level warmup (default=0.075). + """ + + def __init__( + self, + d_model: int, + dim_feedforward: int, + rnn_hidden_size: int, + dropout: float = 0.1, + layer_dropout: float = 0.075, + ) -> None: + super(RNNEncoderLayer, self).__init__() + self.layer_dropout = layer_dropout + self.d_model = d_model + self.rnn_hidden_size = rnn_hidden_size + + assert rnn_hidden_size >= d_model, (rnn_hidden_size, d_model) + self.lstm = ScaledLSTM( + input_size=d_model, + hidden_size=rnn_hidden_size, + proj_size=d_model if rnn_hidden_size > d_model else 0, + num_layers=1, + dropout=0.0, + ) + self.feed_forward = nn.Sequential( + ScaledLinear(d_model, dim_feedforward), + ActivationBalancer(channel_dim=-1), + DoubleSwish(), + nn.Dropout(dropout), + ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), + ) + self.norm_final = BasicNorm(d_model) + + # try to ensure the output is close to zero-mean (or at least, zero-median). # noqa + self.balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0 + ) + self.dropout = nn.Dropout(dropout) + + def forward( + self, + src: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Pass the input through the encoder layer. + + Args: + src: + The sequence to the encoder layer (required). + Its shape is (S, N, E), where S is the sequence length, + N is the batch size, and E is the feature number. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (1, N, d_model); + states[1] is the cell states of all layers, + with shape of (1, N, rnn_hidden_size). + warmup: + It controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + """ + src_orig = src + + warmup_scale = min(0.1 + warmup, 1.0) + # alpha = 1.0 means fully use this encoder layer, 0.0 would mean + # completely bypass it. + if self.training: + alpha = ( + warmup_scale + if torch.rand(()).item() <= (1.0 - self.layer_dropout) + else 0.1 + ) + else: + alpha = 1.0 + + # lstm module + if states is None: + src_lstm = self.lstm(src)[0] + # torch.jit.trace requires returned types be the same as annotated + new_states = (torch.empty(0), torch.empty(0)) + else: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == (1, src.size(1), self.d_model) + # for cell state + assert states[1].shape == (1, src.size(1), self.rnn_hidden_size) + src_lstm, new_states = self.lstm(src, states) + src = self.dropout(src_lstm) + src + + # feed forward module + src = src + self.dropout(self.feed_forward(src)) + + src = self.norm_final(self.balancer(src)) + + if alpha != 1.0: + src = alpha * src + (1 - alpha) * src_orig + + return src, new_states + + +class RNNEncoder(nn.Module): + """ + RNNEncoder is a stack of N encoder layers. + + Args: + encoder_layer: + An instance of the RNNEncoderLayer() class (required). + num_layers: + The number of sub-encoder-layers in the encoder (required). + """ + + def __init__( + self, + encoder_layer: nn.Module, + num_layers: int, + aux_layers: Optional[List[int]] = None, + ) -> None: + super(RNNEncoder, self).__init__() + self.layers = nn.ModuleList( + [copy.deepcopy(encoder_layer) for i in range(num_layers)] + ) + self.num_layers = num_layers + self.d_model = encoder_layer.d_model + self.rnn_hidden_size = encoder_layer.rnn_hidden_size + + self.aux_layers: List[int] = [] + self.combiner: Optional[nn.Module] = None + if aux_layers is not None: + assert len(set(aux_layers)) == len(aux_layers) + assert num_layers - 1 not in aux_layers + self.aux_layers = aux_layers + [num_layers - 1] + self.combiner = RandomCombine( + num_inputs=len(self.aux_layers), + final_weight=0.5, + pure_prob=0.333, + stddev=2.0, + ) + + def forward( + self, + src: torch.Tensor, + states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + warmup: float = 1.0, + ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + Pass the input through the encoder layer in turn. + + Args: + src: + The sequence to the encoder layer (required). + Its shape is (S, N, E), where S is the sequence length, + N is the batch size, and E is the feature number. + states: + A tuple of 2 tensors (optional). It is for streaming inference. + states[0] is the hidden states of all layers, + with shape of (num_layers, N, d_model); + states[1] is the cell states of all layers, + with shape of (num_layers, N, rnn_hidden_size). + warmup: + It controls selective bypass of of layers; if < 1.0, we will + bypass layers more frequently. + """ + if states is not None: + assert not self.training + assert len(states) == 2 + if not torch.jit.is_tracing(): + # for hidden state + assert states[0].shape == ( + self.num_layers, + src.size(1), + self.d_model, + ) + # for cell state + assert states[1].shape == ( + self.num_layers, + src.size(1), + self.rnn_hidden_size, + ) + + output = src + + outputs = [] + + new_hidden_states = [] + new_cell_states = [] + + for i, mod in enumerate(self.layers): + if states is None: + output = mod(output, warmup=warmup)[0] + else: + layer_state = ( + states[0][i : i + 1, :, :], # h: (1, N, d_model) + states[1][i : i + 1, :, :], # c: (1, N, rnn_hidden_size) + ) + output, (h, c) = mod(output, layer_state) + new_hidden_states.append(h) + new_cell_states.append(c) + + if self.combiner is not None and i in self.aux_layers: + outputs.append(output) + + if self.combiner is not None: + output = self.combiner(outputs) + + if states is None: + new_states = (torch.empty(0), torch.empty(0)) + else: + new_states = ( + torch.cat(new_hidden_states, dim=0), + torch.cat(new_cell_states, dim=0), + ) + + return output, new_states + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape (N, T, idim) to an output + with shape (N, T', odim), where + T' = ((T-3)//2-1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + layer1_channels: int = 8, + layer2_channels: int = 32, + layer3_channels: int = 128, + is_pnnx: bool = False, + ) -> None: + """ + Args: + in_channels: + Number of channels in. The input shape is (N, T, in_channels). + Caution: It requires: T >= 9, in_channels >= 9. + out_channels + Output dim. The output shape is (N, ((T-3)//2-1)//2, out_channels) + layer1_channels: + Number of channels in layer1 + layer1_channels: + Number of channels in layer2 + is_pnnx: + True if we are converting the model to PNNX format. + False otherwise. + """ + assert in_channels >= 9 + super().__init__() + + self.conv = nn.Sequential( + ScaledConv2d( + in_channels=1, + out_channels=layer1_channels, + kernel_size=3, + padding=0, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer1_channels, + out_channels=layer2_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ScaledConv2d( + in_channels=layer2_channels, + out_channels=layer3_channels, + kernel_size=3, + stride=2, + ), + ActivationBalancer(channel_dim=1), + DoubleSwish(), + ) + self.out = ScaledLinear( + layer3_channels * (((in_channels - 3) // 2 - 1) // 2), out_channels + ) + # set learn_eps=False because out_norm is preceded by `out`, and `out` + # itself has learned scale, so the extra degree of freedom is not + # needed. + self.out_norm = BasicNorm(out_channels, learn_eps=False) + # constrain median of output to be close to zero. + self.out_balancer = ActivationBalancer( + channel_dim=-1, min_positive=0.45, max_positive=0.55 + ) + + # ncnn supports only batch size == 1 + self.is_pnnx = is_pnnx + self.conv_out_dim = self.out.weight.shape[1] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is (N, T, idim). + + Returns: + Return a tensor of shape (N, ((T-3)//2-1)//2, odim) + """ + # On entry, x is (N, T, idim) + x = x.unsqueeze(1) # (N, T, idim) -> (N, 1, T, idim) i.e., (N, C, H, W) + x = self.conv(x) + + if torch.jit.is_tracing() and self.is_pnnx: + x = x.permute(0, 2, 1, 3).reshape(1, -1, self.conv_out_dim) + x = self.out(x) + else: + # Now x is of shape (N, odim, ((T-3)//2-1)//2, ((idim-3)//2-1)//2) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + + # Now x is of shape (N, ((T-3)//2-1))//2, odim) + x = self.out_norm(x) + x = self.out_balancer(x) + return x + + +class RandomCombine(nn.Module): + """ + This module combines a list of Tensors, all with the same shape, to + produce a single output of that same shape which, in training time, + is a random combination of all the inputs; but which in test time + will be just the last input. + + The idea is that the list of Tensors will be a list of outputs of multiple + conformer layers. This has a similar effect as iterated loss. (See: + DEJA-VU: DOUBLE FEATURE PRESENTATION AND ITERATED LOSS IN DEEP TRANSFORMER + NETWORKS). + """ + + def __init__( + self, + num_inputs: int, + final_weight: float = 0.5, + pure_prob: float = 0.5, + stddev: float = 2.0, + ) -> None: + """ + Args: + num_inputs: + The number of tensor inputs, which equals the number of layers' + outputs that are fed into this module. E.g. in an 18-layer neural + net if we output layers 16, 12, 18, num_inputs would be 3. + final_weight: + The amount of weight or probability we assign to the + final layer when randomly choosing layers or when choosing + continuous layer weights. + pure_prob: + The probability, on each frame, with which we choose + only a single layer to output (rather than an interpolation) + stddev: + A standard deviation that we add to log-probs for computing + randomized weights. + + The method of choosing which layers, or combinations of layers, to use, + is conceptually as follows:: + + With probability `pure_prob`:: + With probability `final_weight`: choose final layer, + Else: choose random non-final layer. + Else:: + Choose initial log-weights that correspond to assigning + weight `final_weight` to the final layer and equal + weights to other layers; then add Gaussian noise + with variance `stddev` to these log-weights, and normalize + to weights (note: the average weight assigned to the + final layer here will not be `final_weight` if stddev>0). + """ + super().__init__() + assert 0 <= pure_prob <= 1, pure_prob + assert 0 < final_weight < 1, final_weight + assert num_inputs >= 1 + + self.num_inputs = num_inputs + self.final_weight = final_weight + self.pure_prob = pure_prob + self.stddev = stddev + + self.final_log_weight = ( + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) + .log() + .item() + ) + + def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor: + """Forward function. + Args: + inputs: + A list of Tensor, e.g. from various layers of a transformer. + All must be the same shape, of (*, num_channels) + Returns: + A Tensor of shape (*, num_channels). In test mode + this is just the final input. + """ + num_inputs = self.num_inputs + assert len(inputs) == num_inputs + if not self.training or torch.jit.is_scripting(): + return inputs[-1] + + # Shape of weights: (*, num_inputs) + num_channels = inputs[0].shape[-1] + num_frames = inputs[0].numel() // num_channels + + ndim = inputs[0].ndim + # stacked_inputs: (num_frames, num_channels, num_inputs) + stacked_inputs = torch.stack(inputs, dim=ndim).reshape( + (num_frames, num_channels, num_inputs) + ) + + # weights: (num_frames, num_inputs) + weights = self._get_random_weights( + inputs[0].dtype, inputs[0].device, num_frames + ) + + weights = weights.reshape(num_frames, num_inputs, 1) + # ans: (num_frames, num_channels, 1) + ans = torch.matmul(stacked_inputs, weights) + # ans: (*, num_channels) + + ans = ans.reshape(inputs[0].shape[:-1] + (num_channels,)) + + # The following if causes errors for torch script in torch 1.6.0 + # if __name__ == "__main__": + # # for testing only... + # print("Weights = ", weights.reshape(num_frames, num_inputs)) + return ans + + def _get_random_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ) -> torch.Tensor: + """Return a tensor of random weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired + Returns: + A tensor of shape (num_frames, self.num_inputs), such that + `ans.sum(dim=1)` is all ones. + """ + pure_prob = self.pure_prob + if pure_prob == 0.0: + return self._get_random_mixed_weights(dtype, device, num_frames) + elif pure_prob == 1.0: + return self._get_random_pure_weights(dtype, device, num_frames) + else: + p = self._get_random_pure_weights(dtype, device, num_frames) + m = self._get_random_mixed_weights(dtype, device, num_frames) + return torch.where( + torch.rand(num_frames, 1, device=device) < self.pure_prob, p, m + ) + + def _get_random_pure_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A one-hot tensor of shape `(num_frames, self.num_inputs)`, with + exactly one weight equal to 1.0 on each frame. + """ + final_prob = self.final_weight + + # final contains self.num_inputs - 1 in all elements + final = torch.full((num_frames,), self.num_inputs - 1, device=device) + # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) + + indexes = torch.where( + torch.rand(num_frames, device=device) < final_prob, final, nonfinal + ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) + return ans + + def _get_random_mixed_weights( + self, dtype: torch.dtype, device: torch.device, num_frames: int + ): + """Return a tensor of random one-hot weights, of shape + `(num_frames, self.num_inputs)`, + Args: + dtype: + The data-type desired for the answer, e.g. float, double. + device: + The device needed for the answer. + num_frames: + The number of sets of weights desired. + Returns: + A tensor of shape (num_frames, self.num_inputs), which elements + in [0..1] that sum to one over the second axis, i.e. + `ans.sum(dim=1)` is all ones. + """ + logprobs = ( + torch.randn(num_frames, self.num_inputs, dtype=dtype, device=device) + * self.stddev # noqa + ) + logprobs[:, -1] += self.final_log_weight + return logprobs.softmax(dim=1) + + +def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): + print( + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" # noqa + ) + num_inputs = 3 + num_channels = 50 + m = RandomCombine( + num_inputs=num_inputs, + final_weight=final_weight, + pure_prob=pure_prob, + stddev=stddev, + ) + + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] + + y = m(x) + assert y.shape == x[0].shape + assert torch.allclose(y, x[0]) # .. since actually all ones. + + +def _test_random_combine_main(): + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.0) + _test_random_combine(0.999, 0, 0.0) + _test_random_combine(0.5, 0, 0.3) + _test_random_combine(0.5, 1, 0.3) + _test_random_combine(0.5, 0.5, 0.3) + + feature_dim = 50 + c = RNN(num_features=feature_dim, d_model=128) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = c( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + ) + f # to remove flake8 warnings + + +if __name__ == "__main__": + feature_dim = 80 + m = RNN( + num_features=feature_dim, + d_model=512, + rnn_hidden_size=1024, + dim_feedforward=2048, + num_encoder_layers=12, + ) + batch_size = 5 + seq_len = 20 + # Just make sure the forward pass runs. + f = m( + torch.randn(batch_size, seq_len, feature_dim), + torch.full((batch_size,), seq_len, dtype=torch.int64), + warmup=0.5, + ) + num_param = sum([p.numel() for p in m.parameters()]) + print(f"Number of model parameters: {num_param}") + + _test_random_combine_main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py index e69de29bb..d71132b4a 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py @@ -0,0 +1,210 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Tuple + +import k2 +import torch +import torch.nn as nn +from encoder_interface import EncoderInterface +from scaling import ScaledLinear + +from icefall.utils import add_sos + + +class Transducer(nn.Module): + """It implements https://arxiv.org/pdf/1211.3711.pdf + "Sequence Transduction with Recurrent Neural Networks" + """ + + def __init__( + self, + encoder: EncoderInterface, + decoder: nn.Module, + joiner: nn.Module, + encoder_dim: int, + decoder_dim: int, + joiner_dim: int, + vocab_size: int, + ): + """ + Args: + encoder: + It is the transcription network in the paper. Its accepts + two inputs: `x` of (N, T, encoder_dim) and `x_lens` of shape (N,). + It returns two tensors: `logits` of shape (N, T, encoder_dm) and + `logit_lens` of shape (N,). + decoder: + It is the prediction network in the paper. Its input shape + is (N, U) and its output shape is (N, U, decoder_dim). + It should contain one attribute: `blank_id`. + joiner: + It has two inputs with shapes: (N, T, encoder_dim) and + (N, U, decoder_dim). + Its output shape is (N, T, U, vocab_size). Note that its output + contains unnormalized probs, i.e., not processed by log-softmax. + """ + super().__init__() + assert isinstance(encoder, EncoderInterface), type(encoder) + assert hasattr(decoder, "blank_id") + + self.encoder = encoder + self.decoder = decoder + self.joiner = joiner + + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) + self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: k2.RaggedTensor, + prune_range: int = 5, + am_scale: float = 0.0, + lm_scale: float = 0.0, + warmup: float = 1.0, + reduction: str = "sum", + delay_penalty: float = 0.0, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: + A 3-D tensor of shape (N, T, C). + x_lens: + A 1-D tensor of shape (N,). It contains the number of frames in `x` + before padding. + y: + A ragged tensor with 2 axes [utt][label]. It contains labels of each + utterance. + prune_range: + The prune range for rnnt loss, it means how many symbols(context) + we are considering for each frame to compute the loss. + am_scale: + The scale to smooth the loss with am (output of encoder network) + part + lm_scale: + The scale to smooth the loss with lm (output of predictor network) + part + warmup: + A value warmup >= 0 that determines which modules are active, values + warmup > 1 "are fully warmed up" and all modules will be active. + reduction: + "sum" to sum the losses over all utterances in the batch. + "none" to return the loss in a 1-D tensor for each utterance + in the batch. + delay_penalty: + A constant value used to penalize symbol delay, to encourage + streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details. + Returns: + Return the transducer loss. + + Note: + Regarding am_scale & lm_scale, it will make the loss-function one of + the form: + lm_scale * lm_probs + am_scale * am_probs + + (1-lm_scale-am_scale) * combined_probs + """ + assert reduction in ("sum", "none"), reduction + assert x.ndim == 3, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.num_axes == 2, y.num_axes + + assert x.size(0) == x_lens.size(0) == y.dim0 + + encoder_out, x_lens, _ = self.encoder(x, x_lens, warmup=warmup) + assert torch.all(x_lens > 0) + + # Now for the decoder, i.e., the prediction network + row_splits = y.shape.row_splits(1) + y_lens = row_splits[1:] - row_splits[:-1] + + blank_id = self.decoder.blank_id + sos_y = add_sos(y, sos_id=blank_id) + + # sos_y_padded: [B, S + 1], start with SOS. + sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id) + + # decoder_out: [B, S + 1, decoder_dim] + decoder_out = self.decoder(sos_y_padded) + + # Note: y does not start with SOS + # y_padded : [B, S] + y_padded = y.pad(mode="constant", padding_value=0) + + y_padded = y_padded.to(torch.int64) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) + boundary[:, 2] = y_lens + boundary[:, 3] = x_lens + + lm = self.simple_lm_proj(decoder_out) + am = self.simple_am_proj(encoder_out) + + with torch.cuda.amp.autocast(enabled=False): + simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed( + lm=lm.float(), + am=am.float(), + symbols=y_padded, + termination_symbol=blank_id, + lm_only_scale=lm_scale, + am_only_scale=am_scale, + boundary=boundary, + reduction=reduction, + delay_penalty=delay_penalty, + return_grad=True, + ) + + # ranges : [B, T, prune_range] + ranges = k2.get_rnnt_prune_ranges( + px_grad=px_grad, + py_grad=py_grad, + boundary=boundary, + s_range=prune_range, + ) + + # am_pruned : [B, T, prune_range, encoder_dim] + # lm_pruned : [B, T, prune_range, decoder_dim] + am_pruned, lm_pruned = k2.do_rnnt_pruning( + am=self.joiner.encoder_proj(encoder_out), + lm=self.joiner.decoder_proj(decoder_out), + ranges=ranges, + ) + + # logits : [B, T, prune_range, vocab_size] + + # project_input=False since we applied the decoder's input projections + # prior to do_rnnt_pruning (this is an optimization for speed). + logits = self.joiner(am_pruned, lm_pruned, project_input=False) + + with torch.cuda.amp.autocast(enabled=False): + pruned_loss = k2.rnnt_loss_pruned( + logits=logits.float(), + symbols=y_padded, + ranges=ranges, + termination_symbol=blank_id, + boundary=boundary, + delay_penalty=delay_penalty, + reduction=reduction, + ) + + return (simple_loss, pruned_loss) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py old mode 100644 new mode 100755 index e69de29bb..2a6e2adc6 --- a/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/pretrained.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +(1) greedy search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method greedy_search \ + /path/to/foo.wav \ + /path/to/bar.wav + +(2) beam search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(3) modified beam search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method modified_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +(4) fast beam search +./lstm_transducer_stateless/pretrained.py \ + --checkpoint ./lstm_transducer_stateless/exp/pretrained.pt \ + --bpe-model ./data/lang_bpe_500/bpe.model \ + --method fast_beam_search \ + --beam-size 4 \ + /path/to/foo.wav \ + /path/to/bar.wav + +You can also use `./lstm_transducer_stateless/exp/epoch-xx.pt`. + +Note: ./lstm_transducer_stateless/exp/pretrained.pt is generated by +./lstm_transducer_stateless/export.py +""" + + +import argparse +import logging +import math +from typing import List + +import k2 +import kaldifeat +import sentencepiece as spm +import torch +import torchaudio +from beam_search import ( + beam_search, + fast_beam_search_one_best, + greedy_search, + greedy_search_batch, + modified_beam_search, +) +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", + ) + + parser.add_argument( + "--bpe-model", + type=str, + help="""Path to bpe.model.""", + ) + + parser.add_argument( + "--method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - beam_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "sound_files", + type=str, + nargs="+", + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", + ) + + parser.add_argument( + "--sample-rate", + type=int, + default=16000, + help="The sample rate of the input sound file", + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An integer indicating how many candidates we will keep for each + frame. Used only when --method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=4, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=4, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=8, + help="""Used only when --method is fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. Used only when + --method is greedy_search. + """, + ) + + add_model_arguments(parser) + + return parser + + +def read_sound_files( + filenames: List[str], expected_sample_rate: float +) -> List[torch.Tensor]: + """Read a list of sound files into a list 1-D float32 torch tensors. + Args: + filenames: + A list of sound filenames. + expected_sample_rate: + The expected sample rate of the sound files. + Returns: + Return a list of 1-D float32 torch tensors. + """ + ans = [] + for f in filenames: + wave, sample_rate = torchaudio.load(f) + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) + # We use only the first channel + ans.append(wave[0]) + return ans + + +@torch.no_grad() +def main(): + parser = get_parser() + args = parser.parse_args() + + params = get_params() + + params.update(vars(args)) + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id(" ") + params.unk_id = sp.piece_to_id(" ") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"device: {device}") + + logging.info("Creating model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + checkpoint = torch.load(args.checkpoint, map_location="cpu") + model.load_state_dict(checkpoint["model"], strict=False) + model.to(device) + model.eval() + model.device = device + + logging.info("Constructing Fbank computer") + opts = kaldifeat.FbankOptions() + opts.device = device + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = params.sample_rate + opts.mel_opts.num_bins = params.feature_dim + + fbank = kaldifeat.Fbank(opts) + + logging.info(f"Reading sound files: {params.sound_files}") + waves = read_sound_files( + filenames=params.sound_files, expected_sample_rate=params.sample_rate + ) + waves = [w.to(device) for w in waves] + + logging.info("Decoding started") + features = fbank(waves) + feature_lengths = [f.size(0) for f in features] + + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) + + feature_lengths = torch.tensor(feature_lengths, device=device) + + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) + + num_waves = encoder_out.size(0) + hyps = [] + msg = f"Using {params.method}" + if params.method == "beam_search": + msg += f" with beam size {params.beam_size}" + logging.info(msg) + + if params.method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + hyp_tokens = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "modified_beam_search": + hyp_tokens = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + ) + + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + elif params.method == "greedy_search" and params.max_sym_per_frame == 1: + hyp_tokens = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + ) + for hyp in sp.decode(hyp_tokens): + hyps.append(hyp.split()) + else: + for i in range(num_waves): + # fmt: off + encoder_out_i = encoder_out[i:i+1, :encoder_out_lens[i]] + # fmt: on + if params.method == "greedy_search": + hyp = greedy_search( + model=model, + encoder_out=encoder_out_i, + max_sym_per_frame=params.max_sym_per_frame, + ) + elif params.method == "beam_search": + hyp = beam_search( + model=model, + encoder_out=encoder_out_i, + beam=params.beam_size, + ) + else: + raise ValueError(f"Unsupported method: {params.method}") + + hyps.append(sp.decode(hyp).split()) + + s = "\n" + for filename, hyp in zip(params.sound_files, hyps): + words = " ".join(hyp) + s += f"{filename}:\n{words}\n\n" + logging.info(s) + + logging.info("Decoding Done") + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py index e69de29bb..97d890c82 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless/stream.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/stream.py @@ -0,0 +1,148 @@ +# Copyright 2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import List, Optional, Tuple + +import k2 +import torch +from beam_search import Hypothesis, HypothesisList + +from icefall.utils import AttributeDict + + +class Stream(object): + def __init__( + self, + params: AttributeDict, + cut_id: str, + decoding_graph: Optional[k2.Fsa] = None, + device: torch.device = torch.device("cpu"), + LOG_EPS: float = math.log(1e-10), + ) -> None: + """ + Args: + params: + It's the return value of :func:`get_params`. + cut_id: + The cut id of the current stream. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or HLG, Used + only when --decoding_method is fast_beam_search. + device: + The device to run this stream. + LOG_EPS: + A float value used for padding. + """ + self.LOG_EPS = LOG_EPS + self.cut_id = cut_id + + # Containing attention caches and convolution caches + self.states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + # It uses different attributes for different decoding methods. + self.context_size = params.context_size + self.decoding_method = params.decoding_method + if params.decoding_method == "greedy_search": + self.hyp = [params.blank_id] * params.context_size + elif params.decoding_method == "modified_beam_search": + self.hyps = HypothesisList() + self.hyps.add( + Hypothesis( + ys=[params.blank_id] * params.context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + ) + ) + elif params.decoding_method == "fast_beam_search": + # feature_len is needed to get partial results. + # The rnnt_decoding_stream for fast_beam_search. + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) + ) + self.hyp: Optional[List[int]] = None + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + self.ground_truth: str = "" + + self.feature: Optional[torch.Tensor] = None + # Make sure all feature frames can be used. + # We aim to obtain 1 frame after subsampling. + self.chunk_length = params.subsampling_factor + self.pad_length = 5 + self.num_frames = 0 + self.num_processed_frames = 0 + + # After all feature frames are processed, we set this flag to True + self._done = False + + def set_feature(self, feature: torch.Tensor) -> None: + assert feature.dim() == 2, feature.dim() + # tail padding here to alleviate the tail deletion problem + num_tail_padded_frames = 35 + self.num_frames = feature.size(0) + num_tail_padded_frames + self.feature = torch.nn.functional.pad( + feature, + (0, 0, 0, self.pad_length + num_tail_padded_frames), + mode="constant", + value=self.LOG_EPS, + ) + + def get_feature_chunk(self) -> torch.Tensor: + """Get a chunk of feature frames. + + Returns: + A tensor of shape (ret_length, feature_dim). + """ + update_length = min( + self.num_frames - self.num_processed_frames, self.chunk_length + ) + ret_length = update_length + self.pad_length + + ret_feature = self.feature[ + self.num_processed_frames : self.num_processed_frames + ret_length + ] + # Cut off used frames. + # self.feature = self.feature[update_length:] + + self.num_processed_frames += update_length + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_feature + + @property + def id(self) -> str: + return self.cut_id + + @property + def done(self) -> bool: + """Return True if all feature frames are processed.""" + return self._done + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + if self.decoding_method == "greedy_search": + return self.hyp[self.context_size :] + elif self.decoding_method == "modified_beam_search": + best_hyp = self.hyps.get_most_probable(length_norm=True) + return best_hyp.ys[self.context_size :] + else: + assert self.decoding_method == "fast_beam_search" + return self.hyp diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py old mode 100644 new mode 100755 index e69de29bb..d6376bdc0 --- a/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/streaming_decode.py @@ -0,0 +1,968 @@ +#!/usr/bin/env python3 +# +# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +(1) greedy search +./lstm_transducer_stateless/streaming_decode.py \ + --epoch 35 \ + --avg 10 \ + --exp-dir lstm_transducer_stateless/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method greedy_search \ + --use-averaged-model True + +(2) modified beam search +./lstm_transducer_stateless/streaming_decode.py \ + --epoch 35 \ + --avg 10 \ + --exp-dir lstm_transducer_stateless/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method modified_beam_search \ + --use-averaged-model True \ + --beam-size 4 + +(3) fast beam search +./lstm_transducer_stateless/streaming_decode.py \ + --epoch 35 \ + --avg 10 \ + --exp-dir lstm_transducer_stateless/exp \ + --num-decode-streams 2000 \ + --num-encoder-layers 12 \ + --rnn-hidden-size 1024 \ + --decoding-method fast_beam_search \ + --use-averaged-model True \ + --beam 4 \ + --max-contexts 4 \ + --max-states 8 +""" +import argparse +import logging +import warnings +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import k2 +import numpy as np +import sentencepiece as spm +import torch +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from beam_search import Hypothesis, HypothesisList, get_hyps_shape +from kaldifeat import Fbank, FbankOptions +from lhotse import CutSet +from lstm import LOG_EPSILON, stack_states, unstack_states +from stream import Stream +from torch.nn.utils.rnn import pad_sequence +from train import add_model_arguments, get_params, get_transducer_model + +from icefall.checkpoint import ( + average_checkpoints, + average_checkpoints_with_averaged_model, + find_checkpoints, + load_checkpoint, +) +from icefall.decode import one_best_decoding +from icefall.utils import ( + AttributeDict, + get_texts, + setup_logger, + store_transcripts, + str2bool, + write_error_stats, +) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--epoch", + type=int, + default=28, + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", + ) + + parser.add_argument( + "--iter", + type=int, + default=0, + help="""If positive, --epoch is ignored and it + will use the checkpoint exp_dir/checkpoint-iter.pt. + You can specify --avg to use more checkpoints for model averaging. + """, + ) + + parser.add_argument( + "--avg", + type=int, + default=15, + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", + ) + + parser.add_argument( + "--use-averaged-model", + type=str2bool, + default=False, + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="transducer_emformer/exp", + help="The experiment dir", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + parser.add_argument( + "--beam-size", + type=int, + default=4, + help="""An interger indicating how many candidates we will keep for each + frame. Used only when --decoding-method is beam_search or + modified_beam_search.""", + ) + + parser.add_argument( + "--beam", + type=float, + default=20.0, + help="""A floating point value to calculate the cutoff score during beam + search (i.e., `cutoff = max-score - beam`), which is the same as the + `beam` in Kaldi. + Used only when --decoding-method is fast_beam_search""", + ) + + parser.add_argument( + "--max-contexts", + type=int, + default=8, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--max-states", + type=int, + default=64, + help="""Used only when --decoding-method is + fast_beam_search""", + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + parser.add_argument( + "--max-sym-per-frame", + type=int, + default=1, + help="""Maximum number of symbols per frame. + Used only when --decoding_method is greedy_search""", + ) + + parser.add_argument( + "--sampling-rate", + type=float, + default=16000, + help="Sample rate of the audio", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded in parallel", + ) + + add_model_arguments(parser) + + return parser + + +def greedy_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], +) -> None: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. + """ + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + T = encoder_out.size(1) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (batch_size, 1, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + assert logits.ndim == 2, logits.shape + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) + + +def modified_beam_search( + model: nn.Module, + encoder_out: torch.Tensor, + streams: List[Stream], + beam: int = 4, +): + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The RNN-T model. + encoder_out: + A 3-D tensor of shape (N, T, encoder_out_dim) containing the output of + the encoder model. + streams: + A list of stream objects. + beam: + Number of active paths during the beam search. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert len(streams) == encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + batch_size = len(streams) + T = encoder_out.size(1) + + B = [stream.hyps for stream in streams] + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + current_encoder_out = encoder_out[:, t].unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.stack( + [hyp.log_prob.reshape(1) for hyps in A for hyp in hyps], dim=0 + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, decoder_output_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) + # logits is of shape (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) + + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + if new_token != blank_id: + new_ys.append(new_token) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis(ys=new_ys, log_prob=new_log_prob) + B[i].add(new_hyp) + + for i in range(batch_size): + streams[i].hyps = B[i] + + +def fast_beam_search_one_best( + model: nn.Module, + streams: List[Stream], + encoder_out: torch.Tensor, + processed_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, +) -> None: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using modified beam search, and then + the shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + streams: + A list of stream objects. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + processed_lens: + A tensor of shape (N,) containing the number of processed frames + in `encoder_out` before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + assert B == len(streams) + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(streams[i].rnnt_decoding_stream) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = logits.log_softmax(dim=-1) + decoding_streams.advance(log_probs) + + decoding_streams.terminate_and_flush_to_streams() + + lattice = decoding_streams.format_output(processed_lens.tolist()) + + best_path = one_best_decoding(lattice) + hyps = get_texts(best_path) + + for i in range(B): + streams[i].hyp = hyps[i] + + +def decode_one_chunk( + model: nn.Module, + streams: List[Stream], + params: AttributeDict, + decoding_graph: Optional[k2.Fsa] = None, +) -> List[int]: + """ + Args: + model: + The Transducer model. + streams: + A list of Stream objects. + params: + It is returned by :func:`get_params`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search. + + Returns: + A list of indexes indicating the finished streams. + """ + device = next(model.parameters()).device + + feature_list = [] + feature_len_list = [] + state_list = [] + num_processed_frames_list = [] + + for stream in streams: + # We should first get `stream.num_processed_frames` + # before calling `stream.get_feature_chunk()` + # since `stream.num_processed_frames` would be updated + num_processed_frames_list.append(stream.num_processed_frames) + feature = stream.get_feature_chunk() + feature_len = feature.size(0) + feature_list.append(feature) + feature_len_list.append(feature_len) + state_list.append(stream.states) + + features = pad_sequence( + feature_list, batch_first=True, padding_value=LOG_EPSILON + ).to(device) + feature_lens = torch.tensor(feature_len_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) + + # Make sure it has at least 1 frame after subsampling + tail_length = params.subsampling_factor + 5 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPSILON, + ) + + # Stack states of all streams + states = stack_states(state_list) + + encoder_out, encoder_out_lens, states = model.encoder( + x=features, + x_lens=feature_lens, + states=states, + ) + + if params.decoding_method == "greedy_search": + greedy_search( + model=model, + streams=streams, + encoder_out=encoder_out, + ) + elif params.decoding_method == "modified_beam_search": + modified_beam_search( + model=model, + streams=streams, + encoder_out=encoder_out, + beam=params.beam_size, + ) + elif params.decoding_method == "fast_beam_search": + # feature_len is needed to get partial results. + # The rnnt_decoding_stream for fast_beam_search. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + processed_lens = ( + num_processed_frames // params.subsampling_factor + + encoder_out_lens + ) + fast_beam_search_one_best( + model=model, + streams=streams, + encoder_out=encoder_out, + processed_lens=processed_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + ) + else: + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) + + # Update cached states of each stream + state_list = unstack_states(states) + for i, s in enumerate(state_list): + streams[i].states = s + + finished_streams = [i for i, stream in enumerate(streams) if stream.done] + return finished_streams + + +def create_streaming_feature_extractor() -> Fbank: + """Create a CPU streaming feature extractor. + + At present, we assume it returns a fbank feature extractor with + fixed options. In the future, we will support passing in the options + from outside. + + Returns: + Return a CPU streaming feature extractor. + """ + opts = FbankOptions() + opts.device = "cpu" + opts.frame_opts.dither = 0 + opts.frame_opts.snip_edges = False + opts.frame_opts.samp_freq = 16000 + opts.mel_opts.num_bins = 80 + return Fbank(opts) + + +def decode_dataset( + cuts: CutSet, + model: nn.Module, + params: AttributeDict, + sp: spm.SentencePieceProcessor, + decoding_graph: Optional[k2.Fsa] = None, +): + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The Transducer model. + sp: + The BPE model. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search. + + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + device = next(model.parameters()).device + + log_interval = 300 + + fbank = create_streaming_feature_extractor() + + decode_results = [] + streams = [] + for num, cut in enumerate(cuts): + # Each utterance has a Stream. + stream = Stream( + params=params, + cut_id=cut.id, + decoding_graph=decoding_graph, + device=device, + LOG_EPS=LOG_EPSILON, + ) + + stream.states = model.encoder.get_init_states(device=device) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + # The trained model is using normalized samples + assert audio.max() <= 1, "Should be normalized to [-1, 1])" + + samples = torch.from_numpy(audio).squeeze(0) + feature = fbank(samples) + stream.set_feature(feature) + stream.ground_truth = cut.supervisions[0].text + + streams.append(stream) + + while len(streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + model=model, + streams=streams, + params=params, + decoding_graph=decoding_graph, + ) + + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + streams[i].id, + streams[i].ground_truth.split(), + sp.decode(streams[i].decoding_result()).split(), + ) + ) + del streams[i] + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + while len(streams) > 0: + finished_streams = decode_one_chunk( + model=model, + streams=streams, + params=params, + decoding_graph=decoding_graph, + ) + + for i in sorted(finished_streams, reverse=True): + decode_results.append( + ( + streams[i].id, + streams[i].ground_truth.split(), + sp.decode(streams[i].decoding_result()).split(), + ) + ) + del streams[i] + + if params.decoding_method == "greedy_search": + key = "greedy_search" + elif params.decoding_method == "fast_beam_search": + key = ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ) + else: + key = f"beam_size_{params.beam_size}" + + return {key: decode_results} + + +def save_results( + params: AttributeDict, + test_set_name: str, + results_dict: Dict[str, List[Tuple[List[str], List[str]]]], +): + test_set_wers = dict() + for key, results in results_dict.items(): + recog_path = ( + params.res_dir / f"recogs-{test_set_name}-{key}-{params.suffix}.txt" + ) + store_transcripts(filename=recog_path, texts=sorted(results)) + logging.info(f"The transcripts are stored in {recog_path}") + + # The following prints out WERs, per-word error statistics and aligned + # ref/hyp pairs. + errs_filename = ( + params.res_dir / f"errs-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_filename, "w") as f: + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=True + ) + test_set_wers[key] = wer + + logging.info("Wrote detailed error stats to {}".format(errs_filename)) + + test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) + errs_info = ( + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + ) + with open(errs_info, "w") as f: + print("settings\tWER", file=f) + for key, val in test_set_wers: + print("{}\t{}".format(key, val), file=f) + + s = "\nFor {}, WER of different settings are:\n".format(test_set_name) + note = "\tbest for {}".format(test_set_name) + for key, val in test_set_wers: + s += "{}\t{}{}\n".format(key, val, note) + note = "" + logging.info(s) + + +@torch.no_grad() +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + params = get_params() + params.update(vars(args)) + + assert params.decoding_method in ( + "greedy_search", + "fast_beam_search", + "modified_beam_search", + ) + params.res_dir = params.exp_dir / "streaming" / params.decoding_method + + if params.iter > 0: + params.suffix = f"iter-{params.iter}-avg-{params.avg}" + else: + params.suffix = f"epoch-{params.epoch}-avg-{params.avg}" + + if "fast_beam_search" in params.decoding_method: + params.suffix += f"-beam-{params.beam}" + params.suffix += f"-max-contexts-{params.max_contexts}" + params.suffix += f"-max-states-{params.max_states}" + elif "beam_search" in params.decoding_method: + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) + else: + params.suffix += f"-context-{params.context_size}" + params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" + + if params.use_averaged_model: + params.suffix += "-use-averaged-model" + + setup_logger(f"{params.res_dir}/log-streaming-decode") + logging.info("Decoding started") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # and are defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id(" ") + params.unk_id = sp.piece_to_id(" ") + params.vocab_size = sp.get_piece_size() + + params.device = device + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + if not params.use_averaged_model: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + elif params.avg == 1: + load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model) + else: + start = params.epoch - params.avg + 1 + filenames = [] + for i in range(start, params.epoch + 1): + if i >= 1: + filenames.append(f"{params.exp_dir}/epoch-{i}.pt") + logging.info(f"averaging {filenames}") + model.to(device) + model.load_state_dict(average_checkpoints(filenames, device=device)) + else: + if params.iter > 0: + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] + if len(filenames) == 0: + raise ValueError( + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + elif len(filenames) < params.avg + 1: + raise ValueError( + f"Not enough checkpoints ({len(filenames)}) found for" + f" --iter {params.iter}, --avg {params.avg}" + ) + filename_start = filenames[-1] + filename_end = filenames[0] + logging.info( + "Calculating the averaged model over iteration checkpoints" + f" from {filename_start} (excluded) to {filename_end}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + else: + assert params.avg > 0, params.avg + start = params.epoch - params.avg + assert start >= 1, start + filename_start = f"{params.exp_dir}/epoch-{start}.pt" + filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" + logging.info( + f"Calculating the averaged model over epoch range from " + f"{start} (excluded) to {params.epoch}" + ) + model.to(device) + model.load_state_dict( + average_checkpoints_with_averaged_model( + filename_start=filename_start, + filename_end=filename_end, + device=device, + ) + ) + + model.eval() + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + librispeech = LibriSpeechAsrDataModule(args) + + test_clean_cuts = librispeech.test_clean_cuts() + test_other_cuts = librispeech.test_other_cuts() + + test_sets = ["test-clean", "test-other"] + test_cuts = [test_clean_cuts, test_other_cuts] + + for test_set, test_cut in zip(test_sets, test_cuts): + results_dict = decode_dataset( + cuts=test_cut, + model=model, + params=params, + sp=sp, + decoding_graph=decoding_graph, + ) + + save_results( + params=params, + test_set_name=test_set, + results_dict=results_dict, + ) + + logging.info("Done!") + + +if __name__ == "__main__": + torch.manual_seed(20220810) + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py old mode 100644 new mode 100755 index e69de29bb..d30fc260a --- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py @@ -0,0 +1,1157 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo,) +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +./lstm_transducer_stateless/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --exp-dir lstm_transducer_stateless/exp \ + --full-libri 1 \ + --max-duration 300 + +# For mix precision training: + +./lstm_transducer_stateless/train.py \ + --world-size 4 \ + --num-epochs 30 \ + --start-epoch 1 \ + --use-fp16 1 \ + --exp-dir lstm_transducer_stateless/exp \ + --full-libri 1 \ + --max-duration 550 +""" + +import argparse +import copy +import logging +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import k2 +import optim +import sentencepiece as spm +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from asr_datamodule import LibriSpeechAsrDataModule +from decoder import Decoder +from joiner import Joiner +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from lstm import RNN +from model import Transducer +from optim import Eden, Eve +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from icefall import diagnostics +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.utils import ( + AttributeDict, + MetricsTracker, + display_and_save_batch, + setup_logger, + str2bool, +) + +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--num-encoder-layers", + type=int, + default=12, + help="Number of RNN encoder layers..", + ) + + parser.add_argument( + "--encoder-dim", + type=int, + default=512, + help="Encoder output dimesion.", + ) + + parser.add_argument( + "--rnn-hidden-size", + type=int, + default=1024, + help="Hidden dim for LSTM layers.", + ) + + parser.add_argument( + "--aux-layer-period", + type=int, + default=0, + help="""Peroid of auxiliary layers used for randomly combined during training. + If set to 0, will not use the random combiner (Default). + You can set a positive integer to use the random combiner, e.g., 3. + """, + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=35, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="lstm_transducer_stateless/exp", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--initial-lr", + type=float, + default=0.003, + help="""The initial learning rate. This value should not need to be + changed.""", + ) + + parser.add_argument( + "--lr-batches", + type=float, + default=5000, + help="""Number of steps that affects how rapidly the learning rate decreases. + We suggest not to change this.""", + ) + + parser.add_argument( + "--lr-epochs", + type=float, + default=10, + help="""Number of epochs that affects how rapidly the learning rate decreases. + """, + ) + + parser.add_argument( + "--context-size", + type=int, + default=2, + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", + ) + + parser.add_argument( + "--prune-range", + type=int, + default=5, + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", + ) + + parser.add_argument( + "--lm-scale", + type=float, + default=0.25, + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", + ) + + parser.add_argument( + "--am-scale", + type=float, + default=0.0, + help="The scale to smooth the loss with am (output of encoder network)" + "part.", + ) + + parser.add_argument( + "--simple-loss-scale", + type=float, + default=0.5, + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--print-diagnostics", + type=str2bool, + default=False, + help="Accumulate stats on activations, print them and exit.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=4000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train % save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=100, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--use-fp16", + type=str2bool, + default=False, + help="Whether to use half precision training.", + ) + + parser.add_argument( + "--delay-penalty", + type=float, + default=0.0, + help="""A constant value used to penalize symbol delay, + to encourage streaming models to emit symbols earlier. + See https://github.com/k2-fsa/k2/issues/955 and + https://arxiv.org/pdf/2211.00490.pdf for more details.""", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + + - feature_dim: The model input dim. It has to match the one used + in computing features. + + - subsampling_factor: The subsampling factor for the model. + + - num_decoder_layers: Number of decoder layer of transformer decoder. + + - warm_step: The warm_step for Noam optimizer. + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 50, + "reset_interval": 200, + "valid_interval": 3000, # For the 100h subset, use 800 + # parameters for conformer + "feature_dim": 80, + "subsampling_factor": 4, + "dim_feedforward": 2048, + # parameters for decoder + "decoder_dim": 512, + # parameters for joiner + "joiner_dim": 512, + # parameters for Noam + "model_warm_step": 3000, # arg given to model, not for lrate + "env_info": get_env_info(), + } + ) + + return params + + +def get_encoder_model(params: AttributeDict) -> nn.Module: + encoder = RNN( + num_features=params.feature_dim, + subsampling_factor=params.subsampling_factor, + d_model=params.encoder_dim, + rnn_hidden_size=params.rnn_hidden_size, + dim_feedforward=params.dim_feedforward, + num_encoder_layers=params.num_encoder_layers, + aux_layer_period=params.aux_layer_period, + ) + return encoder + + +def get_decoder_model(params: AttributeDict) -> nn.Module: + decoder = Decoder( + vocab_size=params.vocab_size, + decoder_dim=params.decoder_dim, + blank_id=params.blank_id, + context_size=params.context_size, + ) + return decoder + + +def get_joiner_model(params: AttributeDict) -> nn.Module: + joiner = Joiner( + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return joiner + + +def get_transducer_model(params: AttributeDict) -> nn.Module: + encoder = get_encoder_model(params) + decoder = get_decoder_model(params) + joiner = get_joiner_model(params) + + model = Transducer( + encoder=encoder, + decoder=decoder, + joiner=joiner, + encoder_dim=params.encoder_dim, + decoder_dim=params.decoder_dim, + joiner_dim=params.joiner_dim, + vocab_size=params.vocab_size, + ) + return model + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + batch: dict, + is_training: bool, + warmup: float = 1.0, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute RNN-T loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Conformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + feature = batch["inputs"] + # at entry, feature is (N, T, C) + assert feature.ndim == 3 + feature = feature.to(device) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + texts = batch["supervisions"]["text"] + y = sp.encode(texts, out_type=int) + y = k2.RaggedTensor(y).to(device) + + with torch.set_grad_enabled(is_training): + simple_loss, pruned_loss = model( + x=feature, + x_lens=feature_lens, + y=y, + prune_range=params.prune_range, + am_scale=params.am_scale, + lm_scale=params.lm_scale, + warmup=warmup, + reduction="none", + delay_penalty=params.delay_penalty if warmup >= 2.0 else 0, + ) + simple_loss_is_finite = torch.isfinite(simple_loss) + pruned_loss_is_finite = torch.isfinite(pruned_loss) + is_finite = simple_loss_is_finite & pruned_loss_is_finite + if not torch.all(is_finite): + logging.info( + "Not all losses are finite!\n" + f"simple_loss: {simple_loss}\n" + f"pruned_loss: {pruned_loss}" + ) + display_and_save_batch(batch, params=params, sp=sp) + simple_loss = simple_loss[simple_loss_is_finite] + pruned_loss = pruned_loss[pruned_loss_is_finite] + + # If either all simple_loss or pruned_loss is inf or nan, + # we stop the training process by raising an exception + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): + raise ValueError( + "There are too many utterances in this batch " + "leading to inf or nan losses." + ) + + simple_loss = simple_loss.sum() + pruned_loss = pruned_loss.sum() + # after the main warmup step, we keep pruned_loss_scale small + # for the same amount of time (model_warm_step), to avoid + # overwhelming the simple_loss and causing it to diverge, + # in case it had not fully learned the alignment yet. + pruned_loss_scale = ( + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + # info["frames"] is an approximate number for two reasons: + # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 + # (2) If some utterances in the batch lead to inf/nan loss, they + # are filtered out. + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) + + # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa + info["utterances"] = feature.size(0) + # averaged input duration in frames over utterances + info["utt_duration"] = feature_lens.sum().item() + # averaged padding proportion over utterances + info["utt_pad_proportion"] = ( + ((feature.size(1) - feature_lens) / feature.size(1)).sum().item() + ) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + info["simple_loss"] = simple_loss.detach().cpu().item() + info["pruned_loss"] = pruned_loss.detach().cpu().item() + + return loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + sp: spm.SentencePieceProcessor, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + model.eval() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + + if world_size > 1: + tot_loss.reduce(loss.device) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + sp: spm.SentencePieceProcessor, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(train_dl): + params.batch_idx_train += 1 + batch_size = len(batch["supervisions"]["text"]) + + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, loss_info = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=(params.batch_idx_train / params.model_warm_step), + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + scaler.scale(loss).backward() + scheduler.step_batch(params.batch_idx_train) + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + except: # noqa + display_and_save_batch(batch, params=params, sp=sp) + raise + + if params.print_diagnostics and batch_idx == 30: + return + + if ( + rank == 0 + and params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, loss[{loss_info}], " + f"tot_loss[{tot_loss}], batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + + loss_info.write_summary( + tb_writer, "train/current_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + + if batch_idx > 0 and batch_idx % params.valid_interval == 0: + logging.info("Computing validation loss") + valid_info = compute_validation_loss( + params=params, + model=model, + sp=sp, + valid_dl=valid_dl, + world_size=world_size, + ) + model.train() + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + if params.full_libri is False: + params.valid_interval = 800 + + fix_random_seed(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"Device: {device}") + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id(" ") + params.vocab_size = sp.get_piece_size() + + logging.info(params) + + logging.info("About to create model") + model = get_transducer_model(params) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank]) + + optimizer = Eve(model.parameters(), lr=params.initial_lr) + + scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + # # overwrite it + # scheduler.base_lrs = [params.initial_lr for _ in scheduler.base_lrs] + # print(scheduler.base_lrs) + + if params.print_diagnostics: + diagnostic = diagnostics.attach_diagnostics(model) + + librispeech = LibriSpeechAsrDataModule(args) + + train_cuts = librispeech.train_clean_100_cuts() + if params.full_libri: + train_cuts += librispeech.train_clean_360_cuts() + train_cuts += librispeech.train_other_500_cuts() + + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 1 second and 20 seconds + # + # Caution: There is a reason to select 20.0 here. Please see + # ../local/display_manifest_statistics.py + # + # You should use ../local/display_manifest_statistics.py to get + # an utterance duration distribution for your dataset to select + # the threshold + if c.duration < 1.0 or c.duration > 20.0: + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" + ) + return False + + # In pruned RNN-T, we require that T >= S + # where T is the number of feature frames after subsampling + # and S is the number of tokens in the utterance + + # In ./lstm.py, the conv module uses the following expression + # for subsampling + T = ((c.num_frames - 3) // 2 - 1) // 2 + tokens = sp.encode(c.supervisions[0].text, out_type=str) + + if T < len(tokens): + logging.warning( + f"Exclude cut with ID {c.id} from training. " + f"Number of frames (before subsampling): {c.num_frames}. " + f"Number of frames (after subsampling): {T}. " + f"Text: {c.supervisions[0].text}. " + f"Tokens: {tokens}. " + f"Number of tokens: {len(tokens)}" + ) + return False + + return True + + train_cuts = train_cuts.filter(remove_short_and_long_utt) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + # We only load the sampler's state dict when it loads a checkpoint + # saved in the middle of an epoch + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + train_dl = librispeech.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + + valid_cuts = librispeech.dev_clean_cuts() + valid_cuts += librispeech.dev_other_cuts() + valid_dl = librispeech.valid_dataloaders(valid_cuts) + + if not params.print_diagnostics: + scan_pessimistic_batches_for_oom( + model=model, + train_dl=train_dl, + optimizer=optimizer, + sp=sp, + params=params, + warmup=0.0 if params.start_epoch == 1 else 1.0, + ) + + scaler = GradScaler(enabled=params.use_fp16) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + scheduler.step_epoch(epoch - 1) + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sp=sp, + train_dl=train_dl, + valid_dl=valid_dl, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + if params.print_diagnostics: + diagnostic.print_diagnostics() + break + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + sp: spm.SentencePieceProcessor, + params: AttributeDict, + warmup: float, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(enabled=params.use_fp16): + loss, _ = compute_loss( + params=params, + model=model, + sp=sp, + batch=batch, + is_training=True, + warmup=warmup, + ) + loss.backward() + optimizer.step() + optimizer.zero_grad() + except RuntimeError as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + raise + + +def main(): + parser = get_parser() + LibriSpeechAsrDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py index f7e1b5a54..bad4e243e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/decode.py @@ -185,24 +185,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -299,7 +295,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -477,7 +474,9 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -536,7 +535,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -698,7 +700,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -731,7 +735,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -784,7 +789,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -819,12 +826,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -852,12 +860,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -886,7 +895,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -952,7 +961,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py index 0ad00cda3..190673638 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/export.py @@ -146,24 +146,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -229,7 +225,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -345,7 +342,9 @@ def export_encoder_model_onnx( x = torch.zeros(N, 9, 80, dtype=torch.float32) x_lens = torch.tensor([9], dtype=torch.int64) h = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.d_model) - c = torch.rand(encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size) + c = torch.rand( + encoder_model.num_encoder_layers, N, encoder_model.rnn_hidden_size + ) warmup = 1.0 torch.onnx.export( @@ -446,9 +445,13 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -547,12 +550,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -581,12 +585,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -615,7 +620,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -689,7 +694,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py index 5a8efd718..da184b76f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/jit_pretrained.py @@ -86,12 +86,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -126,9 +124,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -316,7 +315,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py index 4957d14b1..fadeb4ac2 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py @@ -84,7 +84,9 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -188,7 +190,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py index 3b471fa85..410de8d3d 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/ncnn-decode.py @@ -156,7 +156,9 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -198,9 +200,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -283,7 +286,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py index 7d931a286..bef0ad760 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/pretrained.py @@ -92,11 +92,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -121,12 +119,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -173,7 +169,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -204,9 +201,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -269,11 +267,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -345,7 +347,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py index baff15ea6..e47a05a9e 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-ncnn-decode.py @@ -144,7 +144,9 @@ class Model: assert ret == 0, ret encoder_out = torch.from_numpy(ncnn_out0.numpy()).clone() - encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to(torch.int32) + encoder_out_lens = torch.from_numpy(ncnn_out1.numpy()).to( + torch.int32 + ) hx = torch.from_numpy(ncnn_out2.numpy()).clone() cx = torch.from_numpy(ncnn_out3.numpy()).clone() return encoder_out, encoder_out_lens, hx, cx @@ -186,9 +188,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -226,7 +229,9 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor(hyp, dtype=torch.int32) # (1, context_size) + decoder_input = torch.tensor( + hyp, dtype=torch.int32 + ) # (1, context_size) decoder_out = model.run_decoder(decoder_input).squeeze(0) else: assert decoder_out.ndim == 1 @@ -305,7 +310,9 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder( + frames, states + ) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -321,7 +328,9 @@ def main(): frames.append(online_fbank.get_frame(num_processed_frames + i)) num_processed_frames += offset frames = torch.cat(frames, dim=0) - encoder_out, encoder_out_lens, hx, cx = model.run_encoder(frames, states) + encoder_out, encoder_out_lens, hx, cx = model.run_encoder( + frames, states + ) states = (hx, cx) hyp, decoder_out = greedy_search( model, encoder_out.squeeze(0), decoder_out, hyp @@ -334,7 +343,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py index b31fefa0a..232d3dd18 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/streaming-onnx-decode.py @@ -109,12 +109,10 @@ def get_args(): parser.add_argument( "sound_filename", type=str, - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -149,9 +147,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -200,7 +199,9 @@ class Model: sess_options=self.session_opts, ) - def run_encoder(self, x, h0, c0) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def run_encoder( + self, x, h0, c0 + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Args: x: @@ -257,7 +258,9 @@ class Model: }, )[0] - return self.run_joiner_decoder_proj(torch.from_numpy(decoder_out).squeeze(1)) + return self.run_joiner_decoder_proj( + torch.from_numpy(decoder_out).squeeze(1) + ) def run_joiner( self, @@ -300,7 +303,11 @@ class Model: projected_encoder_out = self.joiner_encoder_proj.run( [self.joiner_encoder_proj.get_outputs()[0].name], - {self.joiner_encoder_proj.get_inputs()[0].name: encoder_out.numpy()}, + { + self.joiner_encoder_proj.get_inputs()[ + 0 + ].name: encoder_out.numpy() + }, )[0] return torch.from_numpy(projected_encoder_out) @@ -319,7 +326,11 @@ class Model: projected_decoder_out = self.joiner_decoder_proj.run( [self.joiner_decoder_proj.get_outputs()[0].name], - {self.joiner_decoder_proj.get_inputs()[0].name: decoder_out.numpy()}, + { + self.joiner_decoder_proj.get_inputs()[ + 0 + ].name: decoder_out.numpy() + }, )[0] return torch.from_numpy(projected_decoder_out) @@ -358,7 +369,9 @@ def greedy_search( if decoder_out is None: assert hyp is None, hyp hyp = [blank_id] * context_size - decoder_input = torch.tensor([hyp], dtype=torch.int64) # (1, context_size) + decoder_input = torch.tensor( + [hyp], dtype=torch.int64 + ) # (1, context_size) decoder_out = model.run_decoder(decoder_input) else: assert decoder_out.shape[0] == 1 @@ -461,7 +474,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py index 08a895a75..5eaaf321f 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py @@ -95,7 +95,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -161,7 +163,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -235,45 +238,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -645,7 +645,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -688,7 +692,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -701,9 +707,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -714,7 +725,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -945,7 +958,9 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -991,7 +1006,8 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False @@ -1139,7 +1155,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py index a8d5605fb..9eee19379 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/decode.py @@ -182,24 +182,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -294,7 +290,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -389,7 +386,9 @@ def decode_one_batch( ) feature_lens += num_tail_padded_frames - encoder_out, encoder_out_lens, _ = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens, _ = model.encoder( + x=feature, x_lens=feature_lens + ) if params.decoding_method == "fast_beam_search": res = fast_beam_search_one_best( @@ -442,7 +441,10 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -520,7 +522,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: +) -> Dict[ + str, List[Tuple[str, List[str], List[str], List[float], List[float]]] +]: """Decode dataset. Args: @@ -595,7 +599,9 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) + this_batch.append( + (cut_id, ref_words, hyp_words, time_ref, time_hyp) + ) results[name].extend(this_batch) @@ -604,7 +610,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -642,7 +650,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -669,7 +678,9 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + s = "\nFor {}, symbol-delay of different settings are:\n".format( + test_set_name + ) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -713,7 +724,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -745,12 +758,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -773,12 +787,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -806,7 +821,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -833,7 +848,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py index 51238f768..212c7bad6 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/export.py @@ -122,24 +122,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -176,7 +172,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -284,12 +281,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -312,12 +310,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -345,7 +344,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -381,7 +380,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py index 180ba8c72..a3443cf0a 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/jit_pretrained.py @@ -85,12 +85,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -125,9 +123,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -315,7 +314,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py index 6e51b85e4..90bc351f4 100644 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/lstm.py @@ -661,7 +661,9 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) .log() .item() ) @@ -758,14 +760,16 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. # noqa - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( - dtype=dtype - ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) return ans def _get_random_mixed_weights( diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py index 4f8049245..0e48fef04 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/pretrained.py @@ -89,11 +89,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -118,12 +116,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -170,7 +166,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -201,9 +198,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -266,11 +264,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens, _ = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens, _ = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -342,7 +344,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py index 4e9063a40..cfa918ed5 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/streaming_decode.py @@ -101,9 +101,8 @@ def get_parser(): "--epoch", type=int, default=40, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( @@ -120,24 +119,20 @@ def get_parser(): "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -204,7 +199,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -363,7 +359,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -380,7 +378,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -539,7 +539,9 @@ def decode_one_chunk( feature_list, batch_first=True, padding_value=LOG_EPSILON ).to(device) feature_lens = torch.tensor(feature_len_list, device=device) - num_processed_frames = torch.tensor(num_processed_frames_list, device=device) + num_processed_frames = torch.tensor( + num_processed_frames_list, device=device + ) # Make sure it has at least 1 frame after subsampling tail_length = params.subsampling_factor + 5 @@ -581,7 +583,8 @@ def decode_one_chunk( with warnings.catch_warnings(): warnings.simplefilter("ignore") processed_lens = ( - num_processed_frames // params.subsampling_factor + encoder_out_lens + num_processed_frames // params.subsampling_factor + + encoder_out_lens ) fast_beam_search_one_best( model=model, @@ -593,7 +596,9 @@ def decode_one_chunk( max_states=params.max_states, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) # Update cached states of each stream state_list = unstack_states(states) @@ -768,7 +773,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -810,7 +816,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -844,12 +852,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -872,12 +881,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -905,7 +915,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py index a1d19fb73..60a5a2be7 100755 --- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py @@ -87,7 +87,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -230,45 +232,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -607,7 +606,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -647,7 +650,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -660,9 +665,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -673,7 +683,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -840,7 +852,10 @@ def train_one_epoch( rank=rank, ) - if batch_idx % params.log_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.log_interval == 0 + and not params.print_diagnostics + ): cur_lr = scheduler.get_last_lr()[0] logging.info( f"Epoch {params.cur_epoch}, " @@ -857,7 +872,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if ( batch_idx > 0 @@ -992,7 +1009,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py index fd2a5354a..8dd1459ca 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned2_knowledge/asr_datamodule.py @@ -74,18 +74,17 @@ class LibriSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -97,91 +96,75 @@ class LibriSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -195,22 +178,18 @@ class LibriSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) def train_dataloaders( @@ -229,16 +208,20 @@ class LibriSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.json.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.json.gz" + ) transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -253,7 +236,9 @@ class LibriSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -296,7 +281,9 @@ class LibriSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -353,7 +340,9 @@ class LibriSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -400,17 +389,23 @@ class LibriSpeechAsrDataModule: @lru_cache() def train_clean_100_cuts(self) -> CutSet: logging.info("About to get train-clean-100 cuts") - return load_manifest(self.args.manifest_dir / "cuts_train-clean-100.json.gz") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-100.json.gz" + ) @lru_cache() def train_clean_360_cuts(self) -> CutSet: logging.info("About to get train-clean-360 cuts") - return load_manifest(self.args.manifest_dir / "cuts_train-clean-360.json.gz") + return load_manifest( + self.args.manifest_dir / "cuts_train-clean-360.json.gz" + ) @lru_cache() def train_other_500_cuts(self) -> CutSet: logging.info("About to get train-other-500 cuts") - return load_manifest(self.args.manifest_dir / "cuts_train-other-500.json.gz") + return load_manifest( + self.args.manifest_dir / "cuts_train-other-500.json.gz" + ) @lru_cache() def dev_clean_cuts(self) -> CutSet: diff --git a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py index 785a8f097..2e9bf3e0b 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/beam_search.py +++ b/egs/librispeech/ASR/pruned2_knowledge/beam_search.py @@ -172,9 +172,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -302,7 +302,9 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -318,7 +320,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -492,7 +496,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) diff --git a/egs/librispeech/ASR/pruned2_knowledge/conformer.py b/egs/librispeech/ASR/pruned2_knowledge/conformer.py index 3b6d0549d..295a35204 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/conformer.py +++ b/egs/librispeech/ASR/pruned2_knowledge/conformer.py @@ -18,10 +18,10 @@ import math import warnings from typing import Optional, Tuple +from sampling import create_knowledge_base, KnowledgeBaseLookup import torch from encoder_interface import EncoderInterface -from sampling import KnowledgeBaseLookup, create_knowledge_base from scaling import ( ActivationBalancer, BasicNorm, @@ -73,9 +73,9 @@ class Conformer(EncoderInterface): if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - self.knowledge_base = create_knowledge_base( - knowledge_M, knowledge_N, knowledge_D - ) + + self.knowledge_base = create_knowledge_base(knowledge_M, knowledge_N, + knowledge_D) # self.encoder_embed converts the input of shape (N, T, num_features) # to the shape (N, T//subsampling_factor, d_model). @@ -89,7 +89,7 @@ class Conformer(EncoderInterface): # Pass in a lambda that creates a new ConformerEncoderLayer with these # args. Don't use deepcopy because we need the knowledge_base # to be shared. - encoder_layer_fn = lambda: ConformerEncoderLayer( # noqa: E731 + encoder_layer_fn = lambda: ConformerEncoderLayer( self.knowledge_base, d_model, nhead, @@ -100,7 +100,7 @@ class Conformer(EncoderInterface): knowledge_M, knowledge_N, knowledge_D, - knowledge_K, + knowledge_K ) self.encoder = ConformerEncoder(encoder_layer_fn, num_encoder_layers) @@ -187,7 +187,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -207,14 +209,10 @@ class ConformerEncoderLayer(nn.Module): self.conv_module = ConvolutionModule(d_model, cnn_module_kernel) - self.lookup = KnowledgeBaseLookup( - knowledge_M, - knowledge_N, - knowledge_D, - knowledge_K, - d_model, - knowledge_base, - ) + self.lookup = KnowledgeBaseLookup(knowledge_M, knowledge_N, + knowledge_D, knowledge_K, + d_model, + knowledge_base) self.norm_final = BasicNorm(d_model) @@ -313,7 +311,9 @@ class ConformerEncoder(nn.Module): def __init__(self, encoder_layer_fn, num_layers: int) -> None: super().__init__() - self.layers = nn.ModuleList([encoder_layer_fn() for i in range(num_layers)]) + self.layers = nn.ModuleList( + [encoder_layer_fn() for i in range(num_layers)] + ) self.num_layers = num_layers def forward( @@ -367,7 +367,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -382,7 +384,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -657,9 +661,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -728,25 +732,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -783,7 +795,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -791,9 +805,13 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -827,9 +845,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -852,7 +874,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned2_knowledge/decode.py b/egs/librispeech/ASR/pruned2_knowledge/decode.py index 65da19f27..b4a9af55a 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/decode.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decode.py @@ -76,7 +76,11 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -94,19 +98,16 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -185,7 +186,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -243,7 +245,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -258,7 +262,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -302,7 +309,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -374,7 +385,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -406,7 +419,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder.py b/egs/librispeech/ASR/pruned2_knowledge/decoder.py index 0b9c886c7..b6d94aaf1 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder.py @@ -90,7 +90,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py index 2ca76a30c..db51fb1cd 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/decoder2.py +++ b/egs/librispeech/ASR/pruned2_knowledge/decoder2.py @@ -14,13 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional - import torch import torch.nn as nn import torch.nn.functional as F -from subsampling import ScaledConv1d from torch import Tensor +from typing import Optional +from subsampling import ScaledConv1d class Decoder(nn.Module): @@ -91,7 +90,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output @@ -101,6 +102,7 @@ class Decoder(nn.Module): return embedding_out + class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -169,13 +171,8 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = [ - "num_embeddings", - "embedding_dim", - "padding_idx", - "scale_grad_by_freq", - "sparse", - ] + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] num_embeddings: int embedding_dim: int @@ -184,41 +181,34 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - scale_speed: float = 5.0, - ) -> None: + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False, + scale_speed: float = 5.0) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq self.scale_speed = scale_speed - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() + + def reset_parameters(self) -> None: nn.init.normal_(self.weight, std=0.05) - nn.init.constant_(self.scale, torch.tensor(1.0 / 0.05).log() / self.scale_speed) + nn.init.constant_(self.scale, torch.tensor(1.0/0.05).log() / self.scale_speed) if self.padding_idx is not None: with torch.no_grad(): @@ -227,38 +217,22 @@ class ScaledEmbedding(nn.Module): def forward(self, input: Tensor) -> Tensor: scale = (self.scale * self.scale_speed).exp() if input.numel() < self.num_embeddings: - return ( - F.embedding( - input, - self.weight, - self.padding_idx, - None, - 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, - self.sparse, - ) - * scale - ) + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale else: return F.embedding( - input, - self.weight * scale, - self.padding_idx, - None, - 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, - self.sparse, - ) + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) def extra_repr(self) -> str: - s = ( - "{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}," - " scale={scale}" - ) + s = '{num_embeddings}, {embedding_dim}, scale_speed={scale_speed}, scale={scale}' if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" + s += ', padding_idx={padding_idx}' if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" + s += ', scale_grad_by_freq={scale_grad_by_freq}' if self.sparse is not False: - s += ", sparse=True" + s += ', sparse=True' return s.format(**self.__dict__) diff --git a/egs/librispeech/ASR/pruned2_knowledge/export.py b/egs/librispeech/ASR/pruned2_knowledge/export.py index 1af05d9c8..96d1a30fb 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/export.py +++ b/egs/librispeech/ASR/pruned2_knowledge/export.py @@ -64,20 +64,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -108,7 +105,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -176,7 +174,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned2_knowledge/joiner.py b/egs/librispeech/ASR/pruned2_knowledge/joiner.py index 68c663b66..35f75ed2a 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/joiner.py +++ b/egs/librispeech/ASR/pruned2_knowledge/joiner.py @@ -56,7 +56,9 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned2_knowledge/model.py b/egs/librispeech/ASR/pruned2_knowledge/model.py index ca8c28af1..599bf2506 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/model.py +++ b/egs/librispeech/ASR/pruned2_knowledge/model.py @@ -63,7 +63,9 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -134,7 +136,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned2_knowledge/optim.py b/egs/librispeech/ASR/pruned2_knowledge/optim.py index 76cd4e11e..432bf8220 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/optim.py +++ b/egs/librispeech/ASR/pruned2_knowledge/optim.py @@ -72,11 +72,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -112,7 +118,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] @@ -139,7 +147,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) @@ -150,7 +158,9 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -166,14 +176,18 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] + self.base_lrs = [ + group["initial_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 @@ -281,9 +295,10 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned2_knowledge/sampling.py b/egs/librispeech/ASR/pruned2_knowledge/sampling.py index 8cc930927..7b05e2f00 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/sampling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/sampling.py @@ -3,29 +3,32 @@ # This was copied from /ceph-dan/torch-sampling/torch_sampling/sampling_ref.py, # its git history is there. -import random import timeit -from typing import Optional, Tuple - import torch +from torch import Tensor +from torch import nn +from torch.cuda.amp import GradScaler, custom_fwd, custom_bwd +from typing import Tuple, Optional from scaling import ScaledLinear -from torch import Tensor, nn -from torch.cuda.amp import GradScaler, custom_bwd, custom_fwd +import random from torch_scheduled_sampling import sample_combined # The main exports of this file are the module KnowledgeBaseLookup and the # function create_knowledge_base. + + + + def create_knowledge_base(M: int, N: int, D: int) -> nn.Parameter: std = 0.1 - a = (3**0.5) * std # this sqrt(3) thing is intended to get variance of - # 0.1 from uniform distribution - ans = nn.Parameter(torch.ones(M**N, D)) + a = (3 ** 0.5) * std # this sqrt(3) thing is intended to get variance of + # 0.1 from uniform distribution + ans = nn.Parameter(torch.ones(M ** N, D)) nn.init.uniform_(ans, -a, a) return ans - def join_indexes(indexes: Tensor, M: int) -> Tensor: """ Combines N-tuples of indexes into single indexes that can be used for @@ -44,9 +47,9 @@ def join_indexes(indexes: Tensor, M: int) -> Tensor: # Note, we don't use this, we -def weighted_matrix_lookup( - weights: Tensor, indexes: Tensor, knowledge_base: Tensor -) -> Tensor: +def weighted_matrix_lookup(weights: Tensor, + indexes: Tensor, + knowledge_base: Tensor) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -62,9 +65,9 @@ def weighted_matrix_lookup( # simpler but less memory-efficient implementation lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) ans = ans.squeeze(-2) assert list(ans.shape) == list(weights.shape[:-2]) + [D] return ans @@ -73,9 +76,7 @@ def weighted_matrix_lookup( class WeightedMatrixLookupFunction(torch.autograd.Function): @staticmethod @custom_fwd - def forward( - ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor - ) -> Tensor: + def forward(ctx, weights: Tensor, indexes: Tensor, knowledge_base: Tensor) -> Tensor: """ Weighted combination of specified rows of a matrix. weights: Tensor of shape (*, K), can contain any value but probably in [0..1]. @@ -87,16 +88,15 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): """ if random.random() < 0.001: print("dtype[1] = ", weights.dtype) - ctx.save_for_backward( - weights.detach(), indexes.detach(), knowledge_base.detach() - ) + ctx.save_for_backward(weights.detach(), indexes.detach(), + knowledge_base.detach()) with torch.no_grad(): lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) D = knowledge_base.shape[-1] - weights = weights.unsqueeze(-2) # (*, 1, K) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - ans = torch.matmul(weights, lookup) # ans: (*, 1, D) - ans = ans.squeeze(-2) # (*, D) + weights = weights.unsqueeze(-2) # (*, 1, K) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + ans = torch.matmul(weights, lookup) # ans: (*, 1, D) + ans = ans.squeeze(-2) #(*, D) return ans @staticmethod @@ -107,7 +107,7 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): knowledge_base.requires_grad = True dtype = ans_grad.dtype ans_grad = ans_grad.to(weights.dtype) - assert weights.requires_grad is False + assert weights.requires_grad == False D = knowledge_base.shape[-1] with torch.enable_grad(): # we'll use torch's autograd to differentiate this operation, which @@ -115,19 +115,16 @@ class WeightedMatrixLookupFunction(torch.autograd.Function): # We don't save `lookup` because it's large, that is the reason # we override Torch autograd. lookup = torch.index_select(knowledge_base, dim=0, index=indexes.flatten()) - lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) - weights = weights.unsqueeze(-1) # (*, K, 1) + lookup = lookup.reshape(*indexes.shape, D) # (*, K, D) + weights = weights.unsqueeze(-1) # (*, K, 1) # forward pass: was: ## ans = torch.matmul(weights, lookup) ## ans: (*, 1, D) ## ans = ans.squeeze(-2) # ans, ans_grad: (*, D) - weights_grad = torch.matmul( - lookup, ans_grad.unsqueeze(-1) # (*, K, D) - ) # (*, D, 1) - weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) - lookup_grad = weights * ans_grad.unsqueeze( - -2 - ) # (*, K, 1) * (*, 1, D) = (*, K, D) + weights_grad = torch.matmul(lookup, # (*, K, D) + ans_grad.unsqueeze(-1)) # (*, D, 1) + weights_grad = weights_grad.squeeze(-1) # (*, K, 1) -> (*, K) + lookup_grad = weights * ans_grad.unsqueeze(-2) # (*, K, 1) * (*, 1, D) = (*, K, D) lookup.backward(gradient=lookup_grad) return weights_grad.to(dtype), None, knowledge_base.grad.to(dtype) @@ -149,7 +146,6 @@ class PenalizeNegentropyFunction(torch.autograd.Function): Returns: logprobs """ - @staticmethod def forward(ctx, logprobs: Tensor, alpha: float): ctx.save_for_backward(logprobs.detach()) @@ -158,23 +154,18 @@ class PenalizeNegentropyFunction(torch.autograd.Function): @staticmethod def backward(ctx, logprobs_grad: Tensor) -> Tuple[Tensor, None]: - (logprobs,) = ctx.saved_tensors + logprobs, = ctx.saved_tensors with torch.enable_grad(): logprobs.requires_grad = True # `negentropy` is the negative entropy of the average distribution. # distributions. It will be <= 0. - l = logprobs.reshape(-1, logprobs.shape[-1]) # noqa: E741 + l = logprobs.reshape(-1, logprobs.shape[-1]) scale = ctx.alpha * l.shape[0] avg_dist = l.exp().mean(dim=0) negentropy = (avg_dist * (avg_dist + 1.0e-20).log()).sum() if random.random() < 0.0005: negentropy_individual = (l * l.exp()).sum(dim=-1).mean() - print( - "Negentropy[individual,combined] = ", - negentropy_individual.item(), - ", ", - negentropy.item(), - ) + print("Negentropy[individual,combined] = ", negentropy_individual.item(), ", ", negentropy.item()) loss = negentropy * scale loss.backward() return logprobs_grad + logprobs.grad, None @@ -192,23 +183,18 @@ class KnowledgeBaseLookup(nn.Module): embedding_dim: the dimension to project from and to, e.g. the d_model of the conformer. """ - - def __init__( - self, - M: int, - N: int, - D: int, - K: int, - embedding_dim: int, - knowledge_base: nn.Parameter, - negentropy_penalty: float = 0.001, - ): + def __init__(self, M: int, N: int, D: int, + K: int, embedding_dim: int, + knowledge_base: nn.Parameter, + negentropy_penalty: float = 0.001): super(KnowledgeBaseLookup, self).__init__() self.knowledge_base = knowledge_base # shared! - self.in_proj = ScaledLinear(embedding_dim, M * N, initial_scale=1.0) + self.in_proj = ScaledLinear(embedding_dim, M * N, + initial_scale=1.0) # initial_scale = 4.0 because the knowlege_base activations are # quite small -- if we use our optimizer they'll have stddev <= 0.1. - self.out_proj = ScaledLinear(D, embedding_dim, initial_scale=4.0) + self.out_proj = ScaledLinear(D, embedding_dim, + initial_scale = 4.0) self.M = M self.N = N self.K = K @@ -224,14 +210,14 @@ class KnowledgeBaseLookup(nn.Module): # TODO: later we can try multiplying by a projection of x or something like that. """ - x = self.in_proj(x) # now (*, M*N) - x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) - x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) + x = self.in_proj(x) # now (*, M*N) + x = x.reshape(*x.shape[:-1], self.N, self.M) # now (*, N, M) + x = x.log_softmax(dim=-1) # now normalized logprobs, dim= (*, N, M) x = PenalizeNegentropyFunction.apply(x, self.negentropy_penalty) _, indexes, weights = sample_combined(x, self.K, input_is_log=True) - x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) - x = self.out_proj(x) # now (*, self.embedding_dim) + x = weighted_matrix_lookup(weights, indexes, self.knowledge_base) # now (*, D) + x = self.out_proj(x) # now (*, self.embedding_dim) return x @@ -251,44 +237,38 @@ def _test_knowledge_base_lookup(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) dtype = torch.float32 - device = torch.device("cuda") - train_pairs = [ - ( - torch.randn(B, T, E, device=device, dtype=dtype), - torch.randn(B, T, E, device=device, dtype=dtype), - ) - for _ in range(10) - ] + device = torch.device('cuda') + train_pairs = [ (torch.randn(B, T, E, device=device, dtype=dtype), torch.randn(B, T, E, device=device, dtype=dtype)) for _ in range(10) ] from optim import Eve - optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device).to(dtype) + start = timeit.default_timer() - # Epoch 0, batch 0, loss 1.0109944343566895 - # Epoch 10, batch 0, loss 1.0146660804748535 - # Epoch 20, batch 0, loss 1.0119813680648804 - # Epoch 30, batch 0, loss 1.0105408430099487 - # Epoch 40, batch 0, loss 1.0077732801437378 - # Epoch 50, batch 0, loss 1.0050103664398193 - # Epoch 60, batch 0, loss 1.0033129453659058 - # Epoch 70, batch 0, loss 1.0014232397079468 - # Epoch 80, batch 0, loss 0.9977912306785583 - # Epoch 90, batch 0, loss 0.8274348974227905 - # Epoch 100, batch 0, loss 0.3368612825870514 - # Epoch 110, batch 0, loss 0.11323091387748718 - # Time taken: 17.591704960912466 +# Epoch 0, batch 0, loss 1.0109944343566895 +# Epoch 10, batch 0, loss 1.0146660804748535 +# Epoch 20, batch 0, loss 1.0119813680648804 +# Epoch 30, batch 0, loss 1.0105408430099487 +# Epoch 40, batch 0, loss 1.0077732801437378 +# Epoch 50, batch 0, loss 1.0050103664398193 +# Epoch 60, batch 0, loss 1.0033129453659058 +# Epoch 70, batch 0, loss 1.0014232397079468 +# Epoch 80, batch 0, loss 0.9977912306785583 +# Epoch 90, batch 0, loss 0.8274348974227905 +# Epoch 100, batch 0, loss 0.3368612825870514 +# Epoch 110, batch 0, loss 0.11323091387748718 +# Time taken: 17.591704960912466 for epoch in range(150): - for n, (x, y) in enumerate(train_pairs): + for n, (x,y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 + loss = ((y_out - y)**2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") loss.backward() @@ -296,8 +276,7 @@ def _test_knowledge_base_lookup(): optimizer.zero_grad() stop = timeit.default_timer() - print("Time taken: ", stop - start) - + print('Time taken: ', stop - start) def _test_knowledge_base_lookup_autocast(): K = 16 @@ -315,21 +294,14 @@ def _test_knowledge_base_lookup_autocast(): x.requires_grad = True y = m(x) assert y.shape == x.shape - y.sum().backward() # make sure backward doesn't crash.. + y.sum().backward() # make sure backward doesn't crash.. print("y = ", y) print("x.grad = ", x.grad) print("knowlege_base.grad norm = ", knowledge_base.grad.norm()) - device = torch.device("cuda") - train_pairs = [ - ( - torch.randn(B, T, E, device=device), - torch.randn(B, T, E, device=device), - ) - for _ in range(10) - ] + device = torch.device('cuda') + train_pairs = [ (torch.randn(B, T, E, device=device), torch.randn(B, T, E, device=device)) for _ in range(10) ] from optim import Eve - optimizer = Eve(m.parameters(), lr=0.005, eps=1.0e-04) m = m.to(device) @@ -337,11 +309,12 @@ def _test_knowledge_base_lookup_autocast(): start = timeit.default_timer() + for epoch in range(150): - for n, (x, y) in enumerate(train_pairs): + for n, (x,y) in enumerate(train_pairs): y_out = m(x) with torch.cuda.amp.autocast(enabled=True): - loss = ((y_out - y) ** 2).mean() * 100.0 + loss = ((y_out - y)**2).mean() * 100.0 if n % 10 == 0 and epoch % 10 == 0: print(f"Epoch {epoch}, batch {n}, loss {loss.item()}") scaler.scale(loss).backward() @@ -350,9 +323,10 @@ def _test_knowledge_base_lookup_autocast(): optimizer.zero_grad() stop = timeit.default_timer() - print("Time taken: ", stop - start) + print('Time taken: ', stop - start) -if __name__ == "__main__": + +if __name__ == '__main__': _test_knowledge_base_lookup() _test_knowledge_base_lookup_autocast() diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling.py b/egs/librispeech/ASR/pruned2_knowledge/scaling.py index 527c735eb..f726c2583 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling.py @@ -18,11 +18,11 @@ import collections from itertools import repeat from typing import Optional, Tuple +from torch.cuda.amp import custom_fwd, custom_bwd import torch import torch.nn as nn from torch import Tensor -from torch.cuda.amp import custom_bwd, custom_fwd def _ntuple(n): @@ -79,7 +79,9 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) + ctx.save_for_backward( + factor, xgt0, below_threshold, above_threshold + ) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -147,7 +149,8 @@ class BasicNorm(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + + self.eps.exp() ) ** -0.5 return x * scales @@ -179,7 +182,11 @@ class ScaledLinear(nn.Linear): """ def __init__( - self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs ): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -195,12 +202,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -211,13 +218,19 @@ class ScaledLinear(nn.Linear): return None if self.bias is None else self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + return torch.nn.functional.linear( + input, self.get_weight(), self.get_bias() + ) class ScaledConv1d(nn.Conv1d): # See docs for ScaledLinear def __init__( - self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs ): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -232,12 +245,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -277,7 +290,11 @@ class ScaledConv1d(nn.Conv1d): class ScaledConv2d(nn.Conv2d): # See docs for ScaledLinear def __init__( - self, *args, initial_scale: float = 1.0, initial_speed: float = 1.0, **kwargs + self, + *args, + initial_scale: float = 1.0, + initial_speed: float = 1.0, + **kwargs ): super(ScaledConv2d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() @@ -292,12 +309,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -636,7 +653,9 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -666,8 +685,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms diff --git a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py index 3f21133a0..6293e081a 100644 --- a/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py +++ b/egs/librispeech/ASR/pruned2_knowledge/scaling_tmp.py @@ -15,23 +15,21 @@ # limitations under the License. -from typing import Optional, Tuple - import torch import torch.nn as nn from torch import Tensor +from typing import Tuple, Optional -def _activation_balancer_loss( - mean_pos: Tensor, - mean_neg: Tensor, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 - eps: float = 1.0e-10, -): + +def _activation_balancer_loss(mean_pos: Tensor, + mean_neg: Tensor, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 + eps: float = 1.0e-10): """ Returns a loss-function for the ActivationBalancer module. This loss function is not exposed to the user but is used internally, and eventually @@ -52,32 +50,28 @@ def _activation_balancer_loss( """ loss_parts = [] - x_mean = mean_pos - mean_neg - x_mean_abs = (mean_pos + mean_neg + eps).detach() - x_rel_mean = x_mean / x_mean_abs + x_mean = mean_positive - mean_negative + x_mean_abs = (mean_positive + mean_negative + eps).detach() + x_rel_mean= x_mean / x_mean_abs if min_positive != 0.0: # e.g. x_mean_floor = -0.95 + 0.05 = -0.9 - x_rel_mean_floor = -(1 - min_positive) + min_positive - min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * ( - 1.0 / (2 * min_positive) - ) + x_rel_mean_floor = (-(1-min_positive) + min_positive) + min_positive_loss = (x_rel_mean_floor - x_rel_mean).relu().sum() * (1.0 / (2*min_positive)) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_positive_loss) if max_positive != 1.0: # e.g. x_mean_floor = -0.05 + 0.95 = 0.8 - x_rel_mean_ceil = -(1.0 - max_positive) + max_positive - max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * ( - 1.0 / (1 - x_rel_mean_ceil) - ) + x_rel_mean_ceil = - (1.0-max_positive) + max_positive + max_positive_loss = (x_rel_mean - x_rel_mean_ceil).relu().sum() * (1.0 / (1 - x_rel_mean_ceil)) # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(max_positive_loss) if min_abs != 0.0: - min_abs_loss = (min_abs - x_mean_abs).relu().sum() / min_abs + min_abs_loss = min_abs - x_mean_abs).relu().sum() / min_abs # this part of the loss would be 1.0 * num_channels if all these constraints were # 100% violated. loss_parts.append(min_abs_loss) @@ -88,53 +82,43 @@ def _activation_balancer_loss( # 100% violated. loss_parts.append(max_abs_loss) + # the min_positive and 1 - max_positive are "ballast" added to the denom = mean_pos + mean_neg + (min_positive + (1 - max_positive)) - # num + num if min_positive != 0.0: - pass + + class ActivationBalancerFunction(torch.autograd.Function): @staticmethod - def forward( - ctx, - x: Tensor, - channel_dim: int, - min_positive: float, # e.g. 0.05 - max_positive: float, # e.g. 0.95 - max_factor: float, # e.g. 0.01 - min_abs: float, # e.g. 0.2 - max_abs: float, # e.g. 100.0 + def forward(ctx, x: Tensor, + channel_dim: int, + min_positive: float, # e.g. 0.05 + max_positive: float, # e.g. 0.95 + max_factor: float, # e.g. 0.01 + min_abs: float, # e.g. 0.2 + max_abs: float, # e.g. 100.0 ) -> Tensor: if x.requires_grad: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] xgt0 = x > 0 - proportion_positive = torch.mean( - xgt0.to(x.dtype), dim=sum_dims, keepdim=True - ) - factor1 = ( - (min_positive - proportion_positive).relu() - * (max_factor / min_positive) - if min_positive != 0.0 - else 0.0 - ) - factor2 = ( - (proportion_positive - max_positive).relu() - * (max_factor / (max_positive - 1.0)) - if max_positive != 1.0 - else 0.0 - ) + proportion_positive = torch.mean(xgt0.to(x.dtype), dim=sum_dims, keepdim=True) + factor1 = ((min_positive - proportion_positive).relu() * (max_factor / min_positive) + if min_positive != 0.0 else 0.0) + factor2 = ((proportion_positive - max_positive).relu() * (max_factor / (max_positive - 1.0)) + if max_positive != 1.0 else 0.0) factor = factor1 + factor2 if isinstance(factor, float): factor = torch.zeros_like(proportion_positive) mean_abs = torch.mean(x.abs(), dim=sum_dims, keepdim=True) - below_threshold = mean_abs < min_abs - above_threshold = mean_abs > max_abs + below_threshold = (mean_abs < min_abs) + above_threshold = (mean_abs > max_abs) ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) ctx.max_factor = max_factor @@ -142,16 +126,11 @@ class ActivationBalancerFunction(torch.autograd.Function): return x @staticmethod - def backward( - ctx, x_grad: Tensor - ) -> Tuple[Tensor, None, None, None, None, None, None]: + def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None, None, None, None]: factor, xgt0, below_threshold, above_threshold = ctx.saved_tensors dtype = x_grad.dtype - scale_factor = ( - (below_threshold.to(dtype) - above_threshold.to(dtype)) - * (xgt0.to(dtype) - 0.5) - * (ctx.max_factor * 2.0) - ) + scale_factor = ((below_threshold.to(dtype) - above_threshold.to(dtype)) * + (xgt0.to(dtype) - 0.5) * (ctx.max_factor * 2.0)) neg_delta_grad = x_grad.abs() * (factor + scale_factor) return x_grad - neg_delta_grad, None, None, None, None, None, None @@ -184,30 +163,29 @@ class BasicNorm(torch.nn.Module): learn_eps: if true, we learn epsilon; if false, we keep it at the initial value. """ - - def __init__( - self, - num_channels: int, - channel_dim: int = -1, # CAUTION: see documentation. - eps: float = 0.25, - learn_eps: bool = True, - ) -> None: + def __init__(self, + num_channels: int, + channel_dim: int = -1, # CAUTION: see documentation. + eps: float = 0.25, + learn_eps: bool = True) -> None: super(BasicNorm, self).__init__() self.num_channels = num_channels self.channel_dim = channel_dim if learn_eps: self.eps = nn.Parameter(torch.tensor(eps).log().detach()) else: - self.register_buffer("eps", torch.tensor(eps).log().detach()) + self.register_buffer('eps', torch.tensor(eps).log().detach()) + def forward(self, x: Tensor) -> Tensor: assert x.shape[self.channel_dim] == self.num_channels - scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() - ) ** -0.5 + scales = (torch.mean(x**2, dim=self.channel_dim, keepdim=True) + + self.eps.exp()) ** -0.5 return x * scales + + class ScaledLinear(nn.Linear): """ A modified version of nn.Linear where the parameters are scaled before @@ -229,26 +207,27 @@ class ScaledLinear(nn.Linear): inherited from nn.Linear. For modules with small fan-in, this may be larger than optimal. """ - - def __init__(self, *args, initial_scale: float = 1.0, **kwargs): + def __init__(self, *args, + initial_scale: float = 1.0, + **kwargs): super(ScaledLinear, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter("bias_scale", None) + self.register_parameter('bias_scale', None) self._reset_parameters() # Overrides the reset_parameters in nn.Linear def _reset_parameters(self): std = 0.01 - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: @@ -258,67 +237,56 @@ class ScaledLinear(nn.Linear): return self.weight * self.weight_scale.exp() def get_bias(self): - return None if self.bias is None else self.bias * self.bias_scale.exp() + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + return torch.nn.functional.linear(input, self.get_weight(), + self.get_bias()) class ScaledConv1d(nn.Conv1d): - def __init__(self, *args, initial_scale=1.0, **kwargs): + def __init__(self, *args, + initial_scale=1.0, **kwargs): super(ScaledConv1d, self).__init__(*args, **kwargs) initial_scale = torch.tensor(initial_scale).log() self.weight_scale = nn.Parameter(initial_scale.clone().detach()) if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter("bias_scale", None) + self.register_parameter('bias_scale', None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() + def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return None if self.bias is None else self.bias * self.bias_scale.exp() + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) def forward(self, input: Tensor) -> Tensor: F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv1d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - self.get_weight(), - self.get_bias(), - self.stride, - _single(0), # noqa: F821 - self.dilation, - self.groups, - ) - return F.conv1d( - input, - self.get_weight(), - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) + if self.padding_mode != 'zeros': + return F.conv1d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + self.get_weight(), self.get_bias(), self.stride, + _single(0), self.dilation, self.groups) + return F.conv1d(input, self.get_weight(), self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) + class ScaledConv2d(nn.Conv2d): @@ -329,58 +297,45 @@ class ScaledConv2d(nn.Conv2d): if self.bias is not None: self.bias_scale = nn.Parameter(initial_scale.clone().detach()) else: - self.register_parameter("bias_scale", None) + self.register_parameter('bias_scale', None) self._reset_parameters() # Overrides the reset_parameters in base class def _reset_parameters(self): std = 0.01 - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() if self.bias is not None: self.bias_scale += torch.tensor(scale / std).log() + def get_weight(self): return self.weight * self.weight_scale.exp() def get_bias(self): - return None if self.bias is None else self.bias * self.bias_scale.exp() + return (None if self.bias is None else + self.bias * self.bias_scale.exp()) def _conv_forward(self, input, weight): F = torch.nn.functional - if self.padding_mode != "zeros": - return F.conv2d( - F.pad( - input, - self._reversed_padding_repeated_twice, - mode=self.padding_mode, - ), - weight, - self.get_bias(), - self.stride, - _pair(0), # noqa: F821 - self.dilation, - self.groups, - ) - return F.conv2d( - input, - weight, - self.get_bias(), - self.stride, - self.padding, - self.dilation, - self.groups, - ) + if self.padding_mode != 'zeros': + return F.conv2d(F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode), + weight, self.get_bias(), self.stride, + _pair(0), self.dilation, self.groups) + return F.conv2d(input, weight, self.get_bias(), self.stride, + self.padding, self.dilation, self.groups) def forward(self, input: Tensor) -> Tensor: return self._conv_forward(input, self.get_weight()) + + class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -409,16 +364,12 @@ class ActivationBalancer(torch.nn.Module): we allow, before we start to modify the derivatives to prevent this. """ - - def __init__( - self, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.01, - min_abs: float = 0.2, - max_abs: float = 100.0, - ): + def __init__(self, channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.01, + min_abs: float = 0.2, + max_abs: float = 100.0): super(ActivationBalancer, self).__init__() self.channel_dim = channel_dim self.min_positive = min_positive @@ -428,15 +379,10 @@ class ActivationBalancer(torch.nn.Module): self.max_abs = max_abs def forward(self, x: Tensor) -> Tensor: - return ActivationBalancerFunction.apply( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - self.max_factor, - self.min_abs, - self.max_abs, - ) + return ActivationBalancerFunction.apply(x, self.channel_dim, + self.min_positive, self.max_positive, + self.max_factor, self.min_abs, + self.max_abs) class DoubleSwishFunction(torch.autograd.Function): @@ -454,7 +400,6 @@ class DoubleSwishFunction(torch.autograd.Function): = double_swish(x) * (1-s(x)) + s(x) ... so we just need to remember s(x) but not x itself. """ - @staticmethod def forward(ctx, x: Tensor) -> Tensor: x = x.detach() @@ -466,17 +411,18 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: s, y = ctx.saved_tensors - return (y * (1 - s) + s) * y_grad - + return (y * (1-s) + s) * y_grad class DoubleSwish(torch.nn.Module): def forward(self, x: Tensor) -> Tensor: """Return double-swish activation function which is an approximation to Swish(Swish(x)), - that we approximate closely with x * sigmoid(x-1). + that we approximate closely with x * sigmoid(x-1). """ return DoubleSwishFunction.apply(x) + + class ScaledEmbedding(nn.Module): r"""A simple lookup table that stores embeddings of a fixed dictionary and size. @@ -545,13 +491,8 @@ class ScaledEmbedding(nn.Module): [ 0.0000, 0.0000, 0.0000], [-0.1655, 0.9897, 0.0635]]]) """ - __constants__ = [ - "num_embeddings", - "embedding_dim", - "padding_idx", - "scale_grad_by_freq", - "sparse", - ] + __constants__ = ['num_embeddings', 'embedding_dim', 'padding_idx', + 'scale_grad_by_freq', 'sparse'] num_embeddings: int embedding_dim: int @@ -560,40 +501,33 @@ class ScaledEmbedding(nn.Module): weight: Tensor sparse: bool - def __init__( - self, - num_embeddings: int, - embedding_dim: int, - padding_idx: Optional[int] = None, - scale_grad_by_freq: bool = False, - sparse: bool = False, - ) -> None: + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None, + scale_grad_by_freq: bool = False, + sparse: bool = False) -> None: super(ScaledEmbedding, self).__init__() self.num_embeddings = num_embeddings self.embedding_dim = embedding_dim if padding_idx is not None: if padding_idx > 0: - assert ( - padding_idx < self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx < self.num_embeddings, 'Padding_idx must be within num_embeddings' elif padding_idx < 0: - assert ( - padding_idx >= -self.num_embeddings - ), "Padding_idx must be within num_embeddings" + assert padding_idx >= -self.num_embeddings, 'Padding_idx must be within num_embeddings' padding_idx = self.num_embeddings + padding_idx self.padding_idx = padding_idx self.scale_grad_by_freq = scale_grad_by_freq - self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() + self.scale = nn.Parameter(torch.zeros(())) # see reset_parameters() self.sparse = sparse self.weight = nn.Parameter(torch.Tensor(num_embeddings, embedding_dim)) self.reset_parameters() + + def reset_parameters(self) -> None: std = 0.01 nn.init.normal_(self.weight, std=std) - nn.init.constant_(self.scale, torch.tensor(1.0 / std).log()) + nn.init.constant_(self.scale, torch.tensor(1.0/std).log()) if self.padding_idx is not None: with torch.no_grad(): @@ -603,37 +537,24 @@ class ScaledEmbedding(nn.Module): F = torch.nn.functional scale = self.scale.exp() if input.numel() < self.num_embeddings: - return ( - F.embedding( - input, - self.weight, - self.padding_idx, - None, - 2.0, # None, 2.0 relate to normalization - self.scale_grad_by_freq, - self.sparse, - ) - * scale - ) + return F.embedding( + input, self.weight, self.padding_idx, + None, 2.0, # None, 2.0 relate to normalization + self.scale_grad_by_freq, self.sparse) * scale else: return F.embedding( - input, - self.weight * scale, - self.padding_idx, - None, - 2.0, # None, 2.0 relates to normalization - self.scale_grad_by_freq, - self.sparse, - ) + input, self.weight * scale, self.padding_idx, + None, 2.0, # None, 2.0 relates to normalization + self.scale_grad_by_freq, self.sparse) def extra_repr(self) -> str: - s = "{num_embeddings}, {embedding_dim}, scale={scale}" + s = '{num_embeddings}, {embedding_dim}, scale={scale}' if self.padding_idx is not None: - s += ", padding_idx={padding_idx}" + s += ', padding_idx={padding_idx}' if self.scale_grad_by_freq is not False: - s += ", scale_grad_by_freq={scale_grad_by_freq}" + s += ', scale_grad_by_freq={scale_grad_by_freq}' if self.sparse is not False: - s += ", sparse=True" + s += ', sparse=True' return s.format(**self.__dict__) @@ -644,13 +565,8 @@ def _test_activation_balancer_sign(): x = 1.0 * (torch.rand(probs.numel(), N) < probs.unsqueeze(-1)) x = x.detach() x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.05, - max_positive=0.95, - max_factor=0.2, - min_abs=0.0, - ) + m = ActivationBalancer(channel_dim=0, min_positive=0.05, max_positive=0.95, + max_factor=0.2, min_abs=0.0) y_grad = torch.sign(torch.randn(probs.numel(), N)) @@ -660,22 +576,17 @@ def _test_activation_balancer_sign(): print("_test_activation_balancer_sign: y grad = ", y_grad) print("_test_activation_balancer_sign: x grad = ", x.grad) - def _test_activation_balancer_magnitude(): channel_dim = 0 magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) x = x.detach() x.requires_grad = True - m = ActivationBalancer( - channel_dim=0, - min_positive=0.0, - max_positive=1.0, - max_factor=0.2, - min_abs=0.2, - max_abs=0.8, - ) + m = ActivationBalancer(channel_dim=0, + min_positive=0.0, max_positive=1.0, + max_factor=0.2, + min_abs=0.2, max_abs=0.8) y_grad = torch.sign(torch.randn(magnitudes.numel(), N)) @@ -710,7 +621,7 @@ def _test_double_swish_deriv(): torch.autograd.gradcheck(m, x) -if __name__ == "__main__": +if __name__ == '__main__': _test_activation_balancer_sign() _test_activation_balancer_magnitude() _test_basic_norm() diff --git a/egs/librispeech/ASR/pruned2_knowledge/train.py b/egs/librispeech/ASR/pruned2_knowledge/train.py index a60d15c3b..2f6840166 100755 --- a/egs/librispeech/ASR/pruned2_knowledge/train.py +++ b/egs/librispeech/ASR/pruned2_knowledge/train.py @@ -78,7 +78,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -177,45 +179,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -555,16 +554,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -727,7 +733,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -827,7 +835,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py index 1df1650f3..2d5724d30 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/decode.py @@ -123,24 +123,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -208,7 +204,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -275,7 +272,9 @@ def decode_one_batch( value=LOG_EPS, ) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -290,7 +289,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -336,7 +338,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -409,7 +415,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -442,7 +450,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -485,7 +494,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -517,12 +528,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -545,12 +557,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -578,7 +591,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py index 008f40fb1..318cd5094 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/emformer.py @@ -272,9 +272,13 @@ class Emformer(EncoderInterface): # Caution: We assume the subsampling factor is 4! x_lens = (((x_lens - 1) >> 1) - 1) >> 1 - emformer_out, emformer_out_lens, states = self.model.infer(x, x_lens, states) + emformer_out, emformer_out_lens, states = self.model.infer( + x, x_lens, states + ) - if x.size(1) != (self.model.segment_length + self.model.right_context_length): + if x.size(1) != ( + self.model.segment_length + self.model.right_context_length + ): raise ValueError( "Incorrect input shape." f"{x.size(1)} vs {self.model.segment_length} + " diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py index 81afb523d..2375f5001 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -173,12 +170,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -201,12 +199,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -234,7 +233,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -274,7 +273,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py index ed6848879..2f019bcdb 100644 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/model.py @@ -122,7 +122,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py index 6b30d3be8..fed814f19 100755 --- a/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py +++ b/egs/librispeech/ASR/pruned_stateless_emformer_rnnt2/train.py @@ -209,45 +209,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -569,7 +566,11 @@ def compute_loss( function enables autograd during computation; when it is False, it disables autograd. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -598,7 +599,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -779,7 +782,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -903,7 +908,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py index 830b37cfb..7af9cc3d7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/beam_search.py @@ -509,9 +509,9 @@ def greedy_search( y = logits.argmax().item() if y not in (blank_id, unk_id): hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -670,7 +670,9 @@ class HypothesisList(object): if use_max: old_hyp.log_prob = max(old_hyp.log_prob, hyp.log_prob) else: - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -686,7 +688,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -888,7 +892,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1082,7 +1088,9 @@ def beam_search( t = 0 B = HypothesisList() - B.add(Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max) + B.add( + Hypothesis(ys=[blank_id] * context_size, log_prob=0.0), use_max=use_max + ) max_sym_per_utt = 20000 @@ -1122,7 +1130,9 @@ def beam_search( cached_key += f"-t-{t}" if cached_key not in joint_cache: - logits = model.joiner(current_encoder_out, decoder_out.unsqueeze(1)) + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1) + ) # TODO(fangjun): Scale the blank posterior diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py index 03ad45f49..7b6338948 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode.py @@ -128,7 +128,11 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -167,11 +171,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -267,7 +269,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -380,7 +383,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if ( @@ -445,7 +450,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -576,7 +584,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -609,7 +619,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -667,7 +678,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -705,7 +718,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -743,7 +757,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py index e522943c0..386248554 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decode_stream.py @@ -75,7 +75,9 @@ class DecodeStream(object): # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 + self.pad_length = ( + params.right_context + 2 + ) * params.subsampling_factor + 3 if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -89,11 +91,13 @@ class DecodeStream(object): ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) @property def done(self) -> bool: @@ -122,10 +126,13 @@ class DecodeStream(object): """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + ret_length = min( + self.num_frames - self.num_processed_frames, chunk_length + ) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa + self.num_processed_frames : self.num_processed_frames # noqa + + ret_length ] self.num_processed_frames += chunk_size diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py index 72593173c..f4355e8a0 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/decoder.py @@ -92,7 +92,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/export.py b/egs/librispeech/ASR/pruned_transducer_stateless/export.py index 64708e524..b5a151878 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/export.py @@ -64,20 +64,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -108,7 +105,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -194,7 +192,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/model.py b/egs/librispeech/ASR/pruned_transducer_stateless/model.py index 2cca7fa27..73b651b3f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/model.py @@ -130,7 +130,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py index a42b63b9c..eb95827af 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/pretrained.py @@ -91,11 +91,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -120,12 +118,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -172,7 +168,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -224,9 +221,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -294,7 +292,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -381,7 +381,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py index 9e09200a1..dcf6dc42f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_beam_search.py @@ -166,10 +166,14 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py index a50b4d4f0..d2cae4f9f 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/streaming_decode.py @@ -51,7 +51,11 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -90,11 +94,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -160,7 +162,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -266,7 +269,9 @@ def decode_one_chunk( ) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -286,7 +291,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -342,7 +349,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -413,7 +422,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -449,7 +460,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -521,7 +533,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/train.py b/egs/librispeech/ASR/pruned_transducer_stateless/train.py index dd0331a60..399b11a29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless/train.py @@ -203,45 +203,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -565,7 +562,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -585,7 +584,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -776,7 +777,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -894,7 +897,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False @@ -952,7 +956,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 5e9428b60..b7c2010f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -580,9 +580,9 @@ def greedy_search( if y not in (blank_id, unk_id): hyp.append(y) timestamp.append(t) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) decoder_out = model.joiner.decoder_proj(decoder_out) @@ -775,7 +775,9 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -791,7 +793,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -986,7 +990,9 @@ def modified_beam_search( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) @@ -998,7 +1004,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1668,7 +1676,9 @@ def fast_beam_search_with_nbest_rnn_rescoring( for rnn_scale in rnn_lm_scale_list: key = f"ngram_lm_scale_{n_scale}_rnn_lm_scale_{rnn_scale}" tot_scores = ( - am_scores.values + n_scale * ngram_lm_scores + rnn_scale * rnn_lm_scores + am_scores.values + + n_scale * ngram_lm_scores + + rnn_scale * rnn_lm_scores ) ragged_tot_scores = k2.RaggedTensor(nbest.shape, tot_scores) max_indexes = ragged_tot_scores.argmax() @@ -1794,7 +1804,9 @@ def modified_beam_search_ngram_rescoring( logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) - log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs = (logits / temperature).log_softmax( + dim=-1 + ) # (num_hyps, vocab_size) log_probs.add_(ys_log_probs) vocab_size = log_probs.size(-1) @@ -1804,7 +1816,9 @@ def modified_beam_search_ngram_rescoring( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -1827,7 +1841,9 @@ def modified_beam_search_ngram_rescoring( state_cost = hyp.state_cost # We only keep AM scores in new_hyp.log_prob - new_log_prob = topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + new_log_prob = ( + topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale + ) new_hyp = Hypothesis( ys=new_ys, log_prob=new_log_prob, state_cost=state_cost @@ -1979,7 +1995,9 @@ def modified_beam_search_rnnlm_shallow_fusion( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) """ for all hyps with a non-blank new token, score this token. It is a little confusing here because this for-loop @@ -2014,7 +2032,10 @@ def modified_beam_search_rnnlm_shallow_fusion( # forward RNNLM to get new states and scores if len(token_list) != 0: tokens_to_score = ( - torch.tensor(token_list).to(torch.int64).to(device).reshape(-1, 1) + torch.tensor(token_list) + .to(torch.int64) + .to(device) + .reshape(-1, 1) ) hs = torch.cat(hs, dim=1).to(device) @@ -2046,7 +2067,9 @@ def modified_beam_search_rnnlm_shallow_fusion( ys.append(new_token) new_timestamp.append(t) - hyp_log_prob += lm_score[new_token] * lm_scale # add the lm score + hyp_log_prob += ( + lm_score[new_token] * lm_scale + ) # add the lm score lm_score = scores[count] state = ( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py index 34ff0d7e2..bc273d33b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py @@ -214,7 +214,10 @@ class Conformer(EncoderInterface): NOTE: the returned tensors are on the given device. """ - if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -436,7 +439,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -454,7 +459,9 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) self.norm_final = BasicNorm(d_model) @@ -520,7 +527,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + conv, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = src + self.dropout(conv) # feed forward module @@ -776,7 +785,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() if is_jit_tracing(): @@ -800,7 +811,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1114,9 +1127,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -1185,25 +1198,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1243,15 +1264,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) if not is_jit_tracing(): assert list(attn_output_weights.size()) == [ @@ -1293,17 +1322,21 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + combined_mask = attn_mask | key_padding_mask.unsqueeze( 1 ).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze( + 0 + ) | key_padding_mask.unsqueeze(1).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.masked_fill( + combined_mask, 0.0 + ) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1322,9 +1355,13 @@ class RelPositionMultiheadAttention(nn.Module): ] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1461,12 +1498,16 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert not self.training, "Cache should be None in training time" + assert ( + not self.training + ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : (-right_context), # noqa + -(self.lorder + right_context) : ( # noqa + -right_context + ), ..., ] else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py index 32cd53be3..979a0e02e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decode.py @@ -132,7 +132,11 @@ from beam_search import ( ) from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -173,11 +177,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -273,7 +275,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -394,7 +397,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -460,7 +465,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -506,7 +514,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } elif "fast_beam_search" in params.decoding_method: key = f"beam_{params.beam}_" @@ -596,7 +608,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -629,7 +643,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -685,7 +700,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -723,7 +740,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -761,7 +779,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py index b59928103..ba91302ce 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/decoder.py @@ -107,11 +107,15 @@ class Decoder(nn.Module): # This is for exporting to PNNX via ONNX embedding_out = self.embedding(y) else: - embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze(-1) + embedding_out = self.embedding(y.clamp(min=0)) * (y >= 0).unsqueeze( + -1 + ) if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py index 90367bd03..f1a8ea589 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/export.py @@ -51,7 +51,11 @@ import sentencepiece as spm import torch from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -170,7 +173,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -218,7 +222,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py index 1954f4724..6a9d08033 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/joiner.py @@ -60,7 +60,9 @@ class Joiner(nn.Module): assert encoder_out.shape == decoder_out.shape if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py index 272d06c37..417c391d9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py @@ -66,7 +66,9 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) def forward( @@ -150,7 +152,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py index 2d7f557ad..041a81f45 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/optim.py @@ -72,11 +72,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -112,7 +118,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] @@ -139,7 +147,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) @@ -150,7 +158,9 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) @@ -170,14 +180,18 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("initial_lr", group["lr"]) - self.base_lrs = [group["initial_lr"] for group in optimizer.param_groups] + self.base_lrs = [ + group["initial_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 @@ -285,9 +299,10 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 ) return [x * factor for x in self.base_lrs] diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py index 58de6875f..f52cb22ab 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -91,11 +91,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -120,12 +118,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -172,7 +168,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -225,9 +222,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -295,7 +293,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -382,7 +382,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py index f671e97b1..8c572a9ef 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py @@ -89,7 +89,9 @@ class ActivationBalancerFunction(torch.autograd.Function): below_threshold = mean_abs < min_abs above_threshold = mean_abs > max_abs - ctx.save_for_backward(factor, xgt0, below_threshold, above_threshold) + ctx.save_for_backward( + factor, xgt0, below_threshold, above_threshold + ) ctx.max_factor = max_factor ctx.sum_dims = sum_dims return x @@ -135,7 +137,7 @@ class GradientFilterFunction(torch.autograd.Function): eps = 1.0e-20 dim = ctx.batch_dim norm_dims = [d for d in range(x_grad.ndim) if d != dim] - norm_of_batch = (x_grad**2).mean(dim=norm_dims, keepdim=True).sqrt() + norm_of_batch = (x_grad ** 2).mean(dim=norm_dims, keepdim=True).sqrt() median_norm = norm_of_batch.median() cutoff = median_norm * ctx.threshold @@ -227,7 +229,8 @@ class BasicNorm(torch.nn.Module): if not is_jit_tracing(): assert x.shape[self.channel_dim] == self.num_channels scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + self.eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + + self.eps.exp() ) ** -0.5 return x * scales @@ -279,12 +282,12 @@ class ScaledLinear(nn.Linear): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -298,7 +301,9 @@ class ScaledLinear(nn.Linear): return self.bias * self.bias_scale.exp() def forward(self, input: Tensor) -> Tensor: - return torch.nn.functional.linear(input, self.get_weight(), self.get_bias()) + return torch.nn.functional.linear( + input, self.get_weight(), self.get_bias() + ) class ScaledConv1d(nn.Conv1d): @@ -326,12 +331,12 @@ class ScaledConv1d(nn.Conv1d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -395,12 +400,12 @@ class ScaledConv2d(nn.Conv2d): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std + a = (3 ** 0.5) * std nn.init.uniform_(self.weight, -a, a) if self.bias is not None: nn.init.constant_(self.bias, 0.0) fan_in = self.weight.shape[1] * self.weight[0][0].numel() - scale = fan_in**-0.5 # 1/sqrt(fan_in) + scale = fan_in ** -0.5 # 1/sqrt(fan_in) with torch.no_grad(): self.weight_scale += torch.tensor(scale / std).log() @@ -471,7 +476,9 @@ class ScaledLSTM(nn.LSTM): setattr(self, scale_name, param) self._scales.append(param) - self.grad_filter = GradientFilter(batch_dim=1, threshold=grad_norm_threshold) + self.grad_filter = GradientFilter( + batch_dim=1, threshold=grad_norm_threshold + ) self._reset_parameters( initial_speed @@ -479,8 +486,8 @@ class ScaledLSTM(nn.LSTM): def _reset_parameters(self, initial_speed: float): std = 0.1 / initial_speed - a = (3**0.5) * std - scale = self.hidden_size**-0.5 + a = (3 ** 0.5) * std + scale = self.hidden_size ** -0.5 v = scale / std for idx, name in enumerate(self._flat_weights_names): if "weight" in name: @@ -552,11 +559,15 @@ class ScaledLSTM(nn.LSTM): """Get scaled weights, and resets their data pointer.""" flat_weights = [] for idx in range(len(self._flat_weights_names)): - flat_weights.append(self._flat_weights[idx] * self._scales[idx].exp()) + flat_weights.append( + self._flat_weights[idx] * self._scales[idx].exp() + ) self._flatten_parameters(flat_weights) return flat_weights - def forward(self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None): + def forward( + self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None + ): # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa # The change for calling `_VF.lstm()` is: # self._flat_weights -> self._get_flat_weights() @@ -904,7 +915,9 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -934,8 +947,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -988,18 +1001,17 @@ def _test_grad_filter(): ) print( - "_test_grad_filter: for gradient norms, the first element > median *" - " threshold ", # noqa + "_test_grad_filter: for gradient norms, the first element > median * threshold ", # noqa i % 2 == 1, ) print( "_test_grad_filter: x_out_grad norm = ", - (x_out_grad**2).mean(dim=(0, 2)).sqrt(), + (x_out_grad ** 2).mean(dim=(0, 2)).sqrt(), ) print( "_test_grad_filter: x.grad norm = ", - (x.grad**2).mean(dim=(0, 2)).sqrt(), + (x.grad ** 2).mean(dim=(0, 2)).sqrt(), ) print("_test_grad_filter: w_out_grad = ", w_out_grad) print("_test_grad_filter: w.grad = ", w.grad) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py index e6e0fb1c8..9bcd2f9f9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_beam_search.py @@ -153,7 +153,9 @@ def modified_beam_search( index=hyps_shape.row_ids(1).to(torch.int64), ) # (num_hyps, encoder_out_dim) - logits = model.joiner(current_encoder_out, decoder_out, project_input=False) + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # logits is of shape (num_hyps, 1, 1, vocab_size) logits = logits.squeeze(1).squeeze(1) @@ -170,10 +172,14 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py index 0139863a1..d76a03946 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/streaming_decode.py @@ -51,7 +51,11 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -90,11 +94,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -160,7 +162,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -268,7 +271,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -288,7 +293,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -344,7 +351,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -416,7 +425,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -451,7 +462,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -524,7 +536,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py index 623bdd51a..1947834bf 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py @@ -96,7 +96,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -208,7 +210,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need to " + "be changed.", ) parser.add_argument( @@ -231,45 +234,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -634,7 +634,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -647,9 +649,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -660,7 +667,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -828,7 +837,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -952,7 +963,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py index 5e81aef07..1df7f9ee5 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/asr_datamodule.py @@ -27,7 +27,10 @@ from lhotse.dataset import ( K2SpeechRecognitionDataset, SpecAugment, ) -from lhotse.dataset.input_strategies import OnTheFlyFeatures, PrecomputedFeatures +from lhotse.dataset.input_strategies import ( + OnTheFlyFeatures, + PrecomputedFeatures, +) from torch.utils.data import DataLoader from icefall.utils import str2bool @@ -41,69 +44,59 @@ class AsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler. " - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler. " + "(you might want to increase it for larger datasets).", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -124,22 +117,18 @@ class AsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) group.add_argument( @@ -153,11 +142,9 @@ class AsrDataModule: "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available. Used only in dev/test CutSet" - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available. Used only in dev/test CutSet", ) def train_dataloaders( @@ -180,7 +167,9 @@ class AsrDataModule: if cuts_musan is not None: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") @@ -189,7 +178,9 @@ class AsrDataModule: if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -259,7 +250,9 @@ class AsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py index 66c8e30ba..5784a78ba 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode-giga.py @@ -79,7 +79,11 @@ from gigaspeech import GigaSpeech from gigaspeech_scoring import asr_text_post_processing from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -116,11 +120,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -190,7 +192,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -277,7 +280,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -307,7 +312,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -351,11 +359,21 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } elif params.decoding_method == "fast_beam_search_nbest_oracle": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}_num_paths_{params.num_paths}_nbest_scale_{params.nbest_scale}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}_" + f"num_paths_{params.num_paths}_" + f"nbest_scale_{params.nbest_scale}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -428,7 +446,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -461,7 +481,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -511,7 +532,9 @@ def main(): params.suffix += f"-num-paths-{params.num_paths}" params.suffix += f"-nbest-scale-{params.nbest_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -544,7 +567,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py index d90497e26..8025d6be1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/decode.py @@ -120,7 +120,11 @@ from beam_search import ( from librispeech import LibriSpeech from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.rnn_lm.model import RnnLmModel from icefall.utils import ( @@ -163,11 +167,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -263,7 +265,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -475,7 +478,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -545,7 +550,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -638,11 +646,21 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + f"temperature_{params.temperature}" + ): hyps } elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}temperature_{params.temperature}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + f"temperature_{params.temperature}" + ): hyps } elif params.decoding_method in [ "fast_beam_search_with_nbest_rescoring", @@ -672,7 +690,12 @@ def decode_one_batch( key += f"_ngram_lm_scale_{params.ngram_lm_scale}" return {key: hyps} else: - return {f"beam_size_{params.beam_size}_temperature_{params.temperature}": hyps} + return { + ( + f"beam_size_{params.beam_size}_" + f"temperature_{params.temperature}" + ): hyps + } def decode_dataset( @@ -756,7 +779,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -789,7 +814,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -913,7 +939,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) params.suffix += f"-temperature-{params.temperature}" else: params.suffix += f"-context-{params.context_size}" @@ -953,7 +981,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -1003,10 +1032,15 @@ def main(): word_table=word_table, device=device, ) - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) logging.info(f"G properties_str: {G.properties_str}") rnn_lm_model = None - if params.decoding_method == "fast_beam_search_with_nbest_rnn_rescoring": + if ( + params.decoding_method + == "fast_beam_search_with_nbest_rnn_rescoring" + ): rnn_lm_model = RnnLmModel( vocab_size=params.vocab_size, embedding_dim=params.rnn_lm_embedding_dim, @@ -1031,7 +1065,9 @@ def main(): rnn_lm_model.eval() else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) rnn_lm_model = None else: decoding_graph = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py index dcf65e937..47217ba05 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/export.py @@ -128,7 +128,11 @@ import torch.nn as nn from scaling_converter import convert_scaled_to_non_scaled from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -160,11 +164,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -233,7 +235,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -506,9 +509,13 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -609,7 +616,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -707,7 +715,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py index 598434f54..36f32c6b3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/gigaspeech.py @@ -52,14 +52,18 @@ class GigaSpeech: ) pattern = re.compile(r"gigaspeech_cuts_XL.([0-9]+).jsonl.gz") - idx_filenames = [(int(pattern.search(f).group(1)), f) for f in filenames] + idx_filenames = [ + (int(pattern.search(f).group(1)), f) for f in filenames + ] idx_filenames = sorted(idx_filenames, key=lambda x: x[0]) sorted_filenames = [f[1] for f in idx_filenames] logging.info(f"Loading {len(sorted_filenames)} splits") - return lhotse.combine(lhotse.load_manifest_lazy(p) for p in sorted_filenames) + return lhotse.combine( + lhotse.load_manifest_lazy(p) for p in sorted_filenames + ) def train_L_cuts(self) -> CutSet: f = self.manifest_dir / "gigaspeech_cuts_L_raw.jsonl.gz" diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py index 108915389..162f8c7db 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/jit_pretrained.py @@ -104,12 +104,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -144,9 +142,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -331,7 +330,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py index d45f6dadc..7852f84e9 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py @@ -84,7 +84,9 @@ class Transducer(nn.Module): self.decoder_giga = decoder_giga self.joiner_giga = joiner_giga - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) if decoder_giga is not None: @@ -188,7 +190,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = encoder_out_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py index 163d737e3..d03d1d7ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_check.py @@ -203,7 +203,9 @@ def test_joiner( ) # Now test encoder_proj - joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} + joiner_encoder_proj_inputs = { + encoder_proj_input_name: encoder_out.numpy() + } joiner_encoder_proj_out = joiner_encoder_proj_session.run( [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] @@ -212,10 +214,16 @@ def test_joiner( torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) + ), ( + (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) + .abs() + .max() + ) # Now test decoder_proj - joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} + joiner_decoder_proj_inputs = { + decoder_proj_input_name: decoder_out.numpy() + } joiner_decoder_proj_out = joiner_decoder_proj_session.run( [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] @@ -224,7 +232,11 @@ def test_joiner( torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) + ), ( + (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) + .abs() + .max() + ) @torch.no_grad() @@ -276,7 +288,9 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py index 11597aa49..ea5d4e674 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/onnx_pretrained.py @@ -102,12 +102,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -142,9 +140,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -192,7 +191,11 @@ def greedy_search( projected_encoder_out = joiner_encoder_proj.run( [joiner_encoder_proj.get_outputs()[0].name], - {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, + { + joiner_encoder_proj.get_inputs()[ + 0 + ].name: packed_encoder_out.data.numpy() + }, )[0] blank_id = 0 # hard-code to 0 @@ -379,7 +382,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py index 849d6cf4e..19b636a23 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/pretrained.py @@ -100,11 +100,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -129,12 +127,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -181,7 +177,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -234,9 +231,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -304,7 +302,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -391,7 +391,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py index 85d87f8f2..1e6022b57 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/scaling_converter.py @@ -234,7 +234,9 @@ def scaled_lstm_to_lstm(scaled_lstm: ScaledLSTM) -> nn.LSTM: assert lstm._flat_weights_names == scaled_lstm._flat_weights_names for idx in range(len(scaled_lstm._flat_weights_names)): - scaled_weight = scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() + scaled_weight = ( + scaled_lstm._flat_weights[idx] * scaled_lstm._scales[idx].exp() + ) lstm._flat_weights[idx].data.copy_(scaled_weight) return lstm @@ -249,10 +251,12 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not an nn.Module") + raise AttributeError("`" + item + "` is not " "an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py index 41a712498..10bb44e00 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/streaming_decode.py @@ -52,7 +52,11 @@ from streaming_beam_search import ( from torch.nn.utils.rnn import pad_sequence from train import add_model_arguments, get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -91,11 +95,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -161,7 +163,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -269,7 +272,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -289,7 +294,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -345,7 +352,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -417,7 +426,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -450,7 +461,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -523,7 +535,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py index 598fcf344..66ffbd3ec 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/test_onnx.py @@ -90,7 +90,9 @@ def test_conv2d_subsampling(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) os.remove(filename) @@ -145,7 +147,9 @@ def test_rel_pos(): onnx_pos_emb = torch.from_numpy(onnx_pos_emb) torch_y, torch_pos_emb = jit_model(x) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) assert torch.allclose(onnx_pos_emb, torch_pos_emb, atol=1e-05), ( (onnx_pos_emb - torch_pos_emb).abs().max() @@ -193,7 +197,9 @@ def test_conformer_encoder_layer(): encoder_layer.eval() encoder_layer = convert_scaled_to_non_scaled(encoder_layer, inplace=True) - jit_model = torch.jit.trace(encoder_layer, (x, pos_emb, src_key_padding_mask)) + jit_model = torch.jit.trace( + encoder_layer, (x, pos_emb, src_key_padding_mask) + ) torch.onnx.export( encoder_layer, @@ -230,7 +236,9 @@ def test_conformer_encoder_layer(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -314,7 +322,9 @@ def test_conformer_encoder(): onnx_y = torch.from_numpy(onnx_y) torch_y = jit_model(x, pos_emb, src_key_padding_mask) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) print(onnx_y.abs().sum(), torch_y.abs().sum(), onnx_y.shape, torch_y.shape) @@ -369,7 +379,9 @@ def test_conformer(): onnx_y_lens = torch.from_numpy(onnx_y_lens) torch_y, torch_y_lens = jit_model(x, x_lens) - assert torch.allclose(onnx_y, torch_y, atol=1e-05), (onnx_y - torch_y).abs().max() + assert torch.allclose(onnx_y, torch_y, atol=1e-05), ( + (onnx_y - torch_y).abs().max() + ) assert torch.allclose(onnx_y_lens, torch_y_lens, atol=1e-05), ( (onnx_y_lens - torch_y_lens).abs().max() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py index 6724343dd..44e96644a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py @@ -92,7 +92,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -161,7 +163,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -211,7 +214,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -234,45 +238,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -671,7 +672,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -684,9 +687,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -697,7 +705,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -909,7 +919,9 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -955,7 +967,8 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False @@ -1096,7 +1109,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py index 69cfcd298..4f043e5a6 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/decode.py @@ -197,24 +197,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -310,7 +306,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -430,7 +427,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) if ( params.decoding_method == "fast_beam_search" @@ -486,7 +485,10 @@ def decode_one_batch( nbest_scale=params.nbest_scale, return_timestamps=True, ) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): res = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -564,7 +566,9 @@ def decode_dataset( sp: spm.SentencePieceProcessor, word_table: Optional[k2.SymbolTable] = None, decoding_graph: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: +) -> Dict[ + str, List[Tuple[str, List[str], List[str], List[float], List[float]]] +]: """Decode dataset. Args: @@ -639,7 +643,9 @@ def decode_dataset( cut_ids, hyps, texts, timestamps_hyp, timestamps_ref ): ref_words = ref_text.split() - this_batch.append((cut_id, ref_words, hyp_words, time_ref, time_hyp)) + this_batch.append( + (cut_id, ref_words, hyp_words, time_ref, time_hyp) + ) results[name].extend(this_batch) @@ -648,7 +654,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -686,7 +694,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -713,7 +722,9 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + s = "\nFor {}, symbol-delay of different settings are:\n".format( + test_set_name + ) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) @@ -762,7 +773,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -799,12 +812,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -827,12 +841,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -860,7 +875,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -887,7 +902,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py index bd5801a78..ce7518ceb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -186,12 +183,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -214,12 +212,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -247,7 +246,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -283,7 +282,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py index a28e52c78..7af9ea9b8 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/streaming_decode.py @@ -96,24 +96,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -179,7 +175,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -307,7 +306,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -363,7 +364,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -435,7 +438,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -468,7 +473,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -541,12 +547,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,12 +576,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -602,7 +610,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py index 76785a845..cf32e565b 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py @@ -101,7 +101,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -237,45 +239,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -622,7 +621,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -662,7 +665,9 @@ def compute_loss( # If either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -675,9 +680,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -688,7 +698,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -867,7 +879,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -999,7 +1013,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 8499651d7..427b06294 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -214,7 +214,10 @@ class Conformer(EncoderInterface): (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. """ - if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -436,7 +439,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -454,7 +459,9 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) self.norm_final = BasicNorm(d_model) @@ -520,7 +527,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + conv, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = src + self.dropout(conv) # feed forward module @@ -793,7 +802,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -809,7 +820,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -835,7 +848,9 @@ class RelPositionalEncoding(torch.nn.Module): pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: + def forward( + self, x: torch.Tensor, left_context: int = 0 + ) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -1103,9 +1118,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -1174,25 +1189,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1230,15 +1253,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1279,17 +1310,21 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + combined_mask = attn_mask | key_padding_mask.unsqueeze( 1 ).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze( + 0 + ) | key_padding_mask.unsqueeze(1).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.masked_fill( + combined_mask, 0.0 + ) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1301,9 +1336,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1442,12 +1481,16 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert not self.training, "Cache should be None in training time" + assert ( + not self.training + ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : (-right_context), # noqa + -(self.lorder + right_context) : ( # noqa + -right_context + ), ..., ] else: @@ -1623,7 +1666,9 @@ class RandomCombine(nn.Module): self.stddev = stddev self.final_log_weight = ( - torch.tensor((final_weight / (1 - final_weight)) * (self.num_inputs - 1)) + torch.tensor( + (final_weight / (1 - final_weight)) * (self.num_inputs - 1) + ) .log() .item() ) @@ -1720,14 +1765,16 @@ class RandomCombine(nn.Module): # final contains self.num_inputs - 1 in all elements final = torch.full((num_frames,), self.num_inputs - 1, device=device) # nonfinal contains random integers in [0..num_inputs - 2], these are for non-final weights. - nonfinal = torch.randint(self.num_inputs - 1, (num_frames,), device=device) + nonfinal = torch.randint( + self.num_inputs - 1, (num_frames,), device=device + ) indexes = torch.where( torch.rand(num_frames, device=device) < final_prob, final, nonfinal ) - ans = torch.nn.functional.one_hot(indexes, num_classes=self.num_inputs).to( - dtype=dtype - ) + ans = torch.nn.functional.one_hot( + indexes, num_classes=self.num_inputs + ).to(dtype=dtype) return ans def _get_random_mixed_weights( @@ -1757,8 +1804,7 @@ class RandomCombine(nn.Module): def _test_random_combine(final_weight: float, pure_prob: float, stddev: float): print( - f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}," - f" stddev={stddev}" + f"_test_random_combine: final_weight={final_weight}, pure_prob={pure_prob}, stddev={stddev}" ) num_inputs = 3 num_channels = 50 diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py index f462cc42f..22bcdd88e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/decode.py @@ -179,24 +179,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -307,7 +303,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -480,7 +477,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -546,7 +545,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -694,7 +696,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -727,7 +731,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -782,7 +787,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -821,12 +828,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -849,12 +857,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -882,7 +891,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -928,7 +937,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py index a739c17bc..b2e5b430e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/export.py @@ -89,24 +89,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -137,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -184,12 +181,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -212,12 +210,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -245,7 +244,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -281,7 +280,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py index e2da0da4c..1e100fcbd 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -89,11 +89,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -118,12 +116,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -170,7 +166,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -201,9 +198,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -266,11 +264,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -342,7 +344,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 59a0e8fa2..6fee9483e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -96,24 +96,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -179,7 +175,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -287,7 +284,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -307,7 +306,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -363,7 +364,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -435,7 +438,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -468,7 +473,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -541,12 +547,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -569,12 +576,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -602,7 +610,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py index 75696d61b..179d9372e 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py @@ -89,7 +89,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -246,7 +248,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -269,45 +272,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -645,7 +645,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -686,7 +690,9 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -699,9 +705,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training @@ -712,7 +723,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -895,7 +908,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1008,7 +1023,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1030,7 +1045,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py index 40ad61fd4..53788b3f7 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/conformer.py @@ -90,7 +90,10 @@ class Conformer(EncoderInterface): output_layers = [] if middle_output_layer is not None: - assert middle_output_layer >= 0 and middle_output_layer < num_encoder_layers + assert ( + middle_output_layer >= 0 + and middle_output_layer < num_encoder_layers + ) output_layers.append(middle_output_layer) # The last layer is always needed. @@ -175,7 +178,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -357,7 +362,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -372,7 +379,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(1) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -647,9 +656,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -718,25 +727,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -773,7 +790,9 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d matrix_bd = torch.matmul( @@ -781,9 +800,13 @@ class RelPositionMultiheadAttention(nn.Module): ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -817,9 +840,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -842,7 +869,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py index 600aa9b39..74df04006 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/decode.py @@ -120,24 +120,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -212,7 +208,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -270,7 +267,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - layer_results, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + layer_results, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) encoder_out = layer_results[-1] hyps = [] @@ -286,7 +285,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -332,7 +334,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -405,7 +411,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -438,7 +446,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -481,7 +490,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -513,12 +524,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -541,12 +553,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -574,7 +587,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py index 17f8614dc..cff9c7377 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/export.py @@ -51,7 +51,11 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -83,11 +87,9 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -118,7 +120,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -157,7 +160,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -205,7 +209,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py index 86cf34877..21409287c 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/extract_codebook_index.py @@ -21,10 +21,9 @@ import os from pathlib import Path import torch +from vq_utils import CodebookIndexExtractor from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned -from vq_utils import CodebookIndexExtractor - from icefall.utils import AttributeDict, str2bool diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py index b8440f90a..49b557814 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_decode.py @@ -23,6 +23,7 @@ from pathlib import Path from typing import Dict, List, Tuple import torch + from asr_datamodule import LibriSpeechAsrDataModule from hubert_xlarge import HubertXlargeFineTuned @@ -98,7 +99,9 @@ def decode_dataset( if batch_idx % 20 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -121,7 +124,9 @@ def save_results( ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.res_dir / f"wer-summary-{test_set_name}.txt" @@ -150,7 +155,9 @@ def main(): # reset some parameters needed by hubert. params.update(HubertXlargeFineTuned.get_params()) - params.res_dir = params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" + params.res_dir = ( + params.exp_dir / f"ctc_greedy_search-{params.teacher_model_id}" + ) setup_logger(f"{params.res_dir}/log/log-ctc_greedy_search") logging.info("Decoding started") @@ -183,7 +190,9 @@ def main(): params=params, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py index 4f9417c9f..55ce7b00d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/hubert_xlarge.py @@ -22,7 +22,11 @@ from pathlib import Path from typing import Dict, List, Tuple import torch -from fairseq import checkpoint_utils, tasks, utils +from fairseq import ( + checkpoint_utils, + tasks, + utils, +) from fairseq.data.data_utils import post_process from omegaconf import OmegaConf @@ -47,7 +51,9 @@ def _load_hubert_model(params: AttributeDict): "data": str(params.hubert_model_dir), } ) - model_path = Path(params.hubert_model_dir) / (params.teacher_model_id + ".pt") + model_path = Path(params.hubert_model_dir) / ( + params.teacher_model_id + ".pt" + ) task = tasks.setup_task(cfg_task) processor = task.target_dictionary models, saved_cfg = checkpoint_utils.load_model_ensemble( @@ -145,7 +151,9 @@ class HubertXlargeFineTuned: supervisions = batch["supervisions"] num_samples = supervisions["num_samples"] B, T = features.shape - padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape([-1, 1]) + padding_mask = torch.arange(0, T).expand(B, T) > num_samples.reshape( + [-1, 1] + ) padding_mask = padding_mask.to(self.params.device) features = features.to(self.params.device) @@ -155,7 +163,9 @@ class HubertXlargeFineTuned: features = features.transpose(1, 2) features = self.w2v_model.layer_norm(features) - padding_mask = self.w2v_model.forward_padding_mask(features, padding_mask) + padding_mask = self.w2v_model.forward_padding_mask( + features, padding_mask + ) if self.w2v_model.post_extract_proj is not None: features = self.w2v_model.post_extract_proj(features) @@ -202,7 +212,9 @@ class HubertXlargeFineTuned: toks = encoder_out.argmax(dim=-1) blank = 0 toks = [tok.unique_consecutive() for tok in toks] - hyps = [self.processor.string(tok[tok != blank].int().cpu()) for tok in toks] + hyps = [ + self.processor.string(tok[tok != blank].int().cpu()) for tok in toks + ] hyps = [post_process(hyp, "letter") for hyp in hyps] return hyps diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py index daadb70c9..7716d19cf 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/model.py @@ -69,7 +69,9 @@ class Transducer(nn.Module): self.decoder = decoder self.joiner = joiner - self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5) + self.simple_am_proj = ScaledLinear( + encoder_dim, vocab_size, initial_speed=0.5 + ) self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size) from icefall import is_module_available @@ -178,7 +180,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens @@ -233,7 +237,9 @@ class Transducer(nn.Module): return (simple_loss, pruned_loss, codebook_loss) @staticmethod - def concat_successive_codebook_indexes(middle_layer_output, codebook_indexes): + def concat_successive_codebook_indexes( + middle_layer_output, codebook_indexes + ): # Output rate of hubert is 50 frames per second, # while that of current encoder is 25. # Following code handling two issues: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py index be54ff0ce..f717d85fb 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/train.py @@ -101,7 +101,9 @@ from icefall.utils import ( str2bool, ) -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -201,45 +203,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -570,7 +569,9 @@ def save_checkpoint( def extract_codebook_indexes(batch): cuts = batch["supervisions"]["cut"] # -100 is identical to ignore_value in CE loss computation. - cuts_pre_mixed = [c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts] + cuts_pre_mixed = [ + c if isinstance(c, MonoCut) else c.tracks[0].cut for c in cuts + ] codebook_indexes, codebook_indexes_lens = collate_custom_field( cuts_pre_mixed, "codebook_indexes", pad_value=-100 ) @@ -603,7 +604,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -650,7 +655,9 @@ def compute_loss( # If the batch contains more than 10 utterances AND # if either all simple_loss or pruned_loss is inf or nan, # we stop the training process by raising an exception - if torch.all(~simple_loss_is_finite) or torch.all(~pruned_loss_is_finite): + if torch.all(~simple_loss_is_finite) or torch.all( + ~pruned_loss_is_finite + ): raise ValueError( "There are too many utterances in this batch " "leading to inf or nan losses." @@ -663,9 +670,14 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss if is_training and params.enable_distillation: assert codebook_loss is not None loss += params.codebook_loss_scale * codebook_loss @@ -678,7 +690,9 @@ def compute_loss( # (1) The acutal subsampling factor is ((lens - 1) // 2 - 1) // 2 # (2) If some utterances in the batch lead to inf/nan loss, they # are filtered out. - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # `utt_duration` and `utt_pad_proportion` would be normalized by `utterances` # noqa info["utterances"] = feature.size(0) @@ -859,7 +873,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -991,7 +1007,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False diff --git a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py index 40f97f662..47cf2b14b 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless6/vq_utils.py @@ -68,7 +68,9 @@ class CodebookIndexExtractor: def init_dirs(self): # vq_dir is the root dir for quantization, containing: # training data, trained quantizer, and extracted codebook indexes - self.vq_dir = self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" + self.vq_dir = ( + self.params.exp_dir / f"vq/{self.params.teacher_model_id}/" + ) self.vq_dir.mkdir(parents=True, exist_ok=True) # manifest_dir contains: @@ -206,7 +208,9 @@ class CodebookIndexExtractor: start = cur_offset % (data.shape[0] + 1 - B) end = start + B cur_offset += B - yield data[start:end, :].to(self.params.device).to(dtype=torch.float) + yield data[start:end, :].to(self.params.device).to( + dtype=torch.float + ) for x in minibatch_generator(train, repeat=True): trainer.step(x) @@ -223,11 +227,10 @@ class CodebookIndexExtractor: """ for subset in self.params.subsets: logging.info(f"About to split {subset}.") - ori_manifest = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" - split_cmd = ( - "lhotse split" - f" {self.params.world_size} {ori_manifest} {self.manifest_dir}" + ori_manifest = ( + f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" ) + split_cmd = f"lhotse split {self.params.world_size} {ori_manifest} {self.manifest_dir}" os.system(f"{split_cmd}") def join_manifests(self): @@ -237,13 +240,16 @@ class CodebookIndexExtractor: logging.info("Start to join manifest files.") for subset in self.params.subsets: vq_manifest_path = ( - self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) ori_manifest_path = ( - self.ori_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.ori_manifest_dir + / f"librispeech_cuts_train-{subset}.jsonl.gz" ) dst_vq_manifest_path = ( - self.dst_manifest_dir / f"librispeech_cuts_train-{subset}.jsonl.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}.jsonl.gz" ) cuts_vq = load_manifest(vq_manifest_path) cuts_ori = load_manifest(ori_manifest_path) @@ -263,7 +269,8 @@ class CodebookIndexExtractor: for subset in self.params.subsets: vq_manifests = f"{self.manifest_dir}/with_codebook_indexes-librispeech-cuts_train-{subset}*.jsonl.gz" dst_vq_manifest = ( - self.dst_manifest_dir / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" + self.dst_manifest_dir + / f"librispeech_cuts_train-{subset}-vq.jsonl.gz" ) if 1 == self.params.world_size: merge_cmd = f"cp {vq_manifests} {dst_vq_manifest}" @@ -323,7 +330,9 @@ class CodebookIndexExtractor: def load_ori_dl(self, subset): if self.params.world_size == 1: - ori_manifest_path = f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ori_manifest_path = ( + f"./data/fbank/librispeech_cuts_train-{subset}.jsonl.gz" + ) else: ori_manifest_path = ( self.manifest_dir diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py index fa8144935..06c5863f1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode.py @@ -164,24 +164,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -276,7 +272,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -396,7 +393,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -455,7 +454,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -586,7 +588,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -619,7 +623,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -674,7 +679,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -711,12 +718,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -739,12 +747,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -772,7 +781,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -799,7 +808,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py index 5f90e6375..712dc8ce1 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decoder.py @@ -69,7 +69,7 @@ class Decoder(nn.Module): out_channels=decoder_dim, kernel_size=context_size, padding=0, - groups=decoder_dim // 4, # group size == 4 + groups=decoder_dim//4, # group size == 4 bias=False, ) @@ -91,7 +91,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py index 43ac658e5..5744ea3ea 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/export.py @@ -129,24 +129,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -180,7 +176,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -218,12 +215,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -246,12 +244,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -279,7 +278,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -317,7 +316,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py index c94a34d58..e2405d5ef 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/jit_pretrained.py @@ -69,12 +69,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -95,9 +93,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -268,7 +267,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py index 3ddac2cf2..7d8de5afe 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/joiner.py @@ -56,7 +56,9 @@ class Joiner(nn.Module): assert encoder_out.shape[:-1] == decoder_out.shape[:-1] if project_input: - logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out) + logit = self.encoder_proj(encoder_out) + self.decoder_proj( + decoder_out + ) else: logit = encoder_out + decoder_out diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py index 0e59b0f2f..53cde6c6f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/model.py @@ -15,15 +15,14 @@ # limitations under the License. -import random - import k2 import torch import torch.nn as nn +import random from encoder_interface import EncoderInterface -from scaling import penalize_abs_values_gt from icefall.utils import add_sos +from scaling import penalize_abs_values_gt class Transducer(nn.Module): @@ -66,8 +65,7 @@ class Transducer(nn.Module): self.joiner = joiner self.simple_am_proj = nn.Linear( - encoder_dim, - vocab_size, + encoder_dim, vocab_size, ) self.simple_lm_proj = nn.Linear(decoder_dim, vocab_size) @@ -135,16 +133,18 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens lm = self.simple_lm_proj(decoder_out) am = self.simple_am_proj(encoder_out) - # if self.training and random.random() < 0.25: + #if self.training and random.random() < 0.25: # lm = penalize_abs_values_gt(lm, 100.0, 1.0e-04) - # if self.training and random.random() < 0.25: + #if self.training and random.random() < 0.25: # am = penalize_abs_values_gt(am, 30.0, 1.0e-04) with torch.cuda.amp.autocast(enabled=False): diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py index 460ac2c3e..bb8b0a0e3 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/optim.py @@ -14,17 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import logging -import random from collections import defaultdict -from typing import List, Optional, Tuple, Union - -import torch +from typing import List, Optional, Union, Tuple, List from lhotse.utils import fix_random_seed +import torch from scaling import ActivationBalancer +import random from torch import Tensor from torch.optim import Optimizer +import logging +import contextlib + class BatchedOptimizer(Optimizer): @@ -37,10 +37,11 @@ class BatchedOptimizer(Optimizer): Args: params: """ - def __init__(self, params, defaults): super(BatchedOptimizer, self).__init__(params, defaults) + + @contextlib.contextmanager def batched_params(self, param_group): """ @@ -72,9 +73,7 @@ class BatchedOptimizer(Optimizer): group: a parameter group, which is a list of parameters; should be one of self.groups. """ - batches = defaultdict( - list - ) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter + batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter for p in param_group: key = (str(p.dtype), *p.shape) @@ -83,7 +82,7 @@ class BatchedOptimizer(Optimizer): stacked_params_dict = dict() # turn batches into a list, in deterministic order. - batches = [batches[key] for key in sorted(batches.keys())] + batches = [ batches[key] for key in sorted(batches.keys()) ] # pairs will contain pairs of (stacked_param, state), one for each batch # in `batches`. pairs = [] @@ -95,78 +94,77 @@ class BatchedOptimizer(Optimizer): # group. class Optimizer will take care of saving/loading state. state = self.state[p] p_stacked = torch.stack(batch) - grad = torch.stack( - [torch.zeros_like(p) if p.grad is None else p.grad for p in batch] - ) + grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch ]) p_stacked.grad = grad stacked_params_dict[key] = p_stacked pairs.append((p_stacked, state)) - yield pairs # <-- calling code will do the actual optimization here! + yield pairs # <-- calling code will do the actual optimization here! for ((stacked_params, _state), batch) in zip(pairs, batches): for i, p in enumerate(batch): # batch is list of Parameter p.copy_(stacked_params[i]) + class ScaledAdam(BatchedOptimizer): """ - Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update - proportional to the norm of that parameter; and also learn the scale of the parameter, - in log space, subject to upper and lower limits (as if we had factored each parameter as - param = underlying_param * log_scale.exp()) + Implements 'Scaled Adam', a variant of Adam where we scale each parameter's update + proportional to the norm of that parameter; and also learn the scale of the parameter, + in log space, subject to upper and lower limits (as if we had factored each parameter as + param = underlying_param * log_scale.exp()) - Args: - params: The parameters or param_groups to optimize (like other Optimizer subclasses) - lr: The learning rate. We will typically use a learning rate schedule that starts - at 0.03 and decreases over time, i.e. much higher than other common - optimizers. - clipping_scale: (e.g. 2.0) - A scale for gradient-clipping: if specified, the normalized gradients - over the whole model will be clipped to have 2-norm equal to - `clipping_scale` times the median 2-norm over the most recent period - of `clipping_update_period` minibatches. By "normalized gradients", - we mean after multiplying by the rms parameter value for this tensor - [for non-scalars]; this is appropriate because our update is scaled - by this quantity. - betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. - Must satisfy 0 < beta <= beta2 < 1. - scalar_lr_scale: A scaling factor on the learning rate, that we use to update the - scale of each parameter tensor and scalar parameters of the mode.. - If each parameter were decomposed - as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale - would be a the scaling factor on the learning rate of p_scale. - eps: A general-purpose epsilon to prevent division by zero - param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be >= this value) - param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of - learning the scale on the parameters (we'll constrain the rms of each non-scalar - parameter tensor to be <= this value) - scalar_max: Maximum absolute value for scalar parameters (applicable if your - model has any parameters with numel() == 1). - size_update_period: The periodicity, in steps, with which we update the size (scale) - of the parameter tensor. This is provided to save a little time - in the update. - clipping_update_period: if clipping_scale is specified, this is the period + Args: + params: The parameters or param_groups to optimize (like other Optimizer subclasses) + lr: The learning rate. We will typically use a learning rate schedule that starts + at 0.03 and decreases over time, i.e. much higher than other common + optimizers. + clipping_scale: (e.g. 2.0) + A scale for gradient-clipping: if specified, the normalized gradients + over the whole model will be clipped to have 2-norm equal to + `clipping_scale` times the median 2-norm over the most recent period + of `clipping_update_period` minibatches. By "normalized gradients", + we mean after multiplying by the rms parameter value for this tensor + [for non-scalars]; this is appropriate because our update is scaled + by this quantity. + betas: beta1,beta2 are momentum constants for regular momentum, and moving sum-sq grad. + Must satisfy 0 < beta <= beta2 < 1. + scalar_lr_scale: A scaling factor on the learning rate, that we use to update the + scale of each parameter tensor and scalar parameters of the mode.. + If each parameter were decomposed + as p * p_scale.exp(), where (p**2).mean().sqrt() == 1.0, scalar_lr_scale + would be a the scaling factor on the learning rate of p_scale. + eps: A general-purpose epsilon to prevent division by zero + param_min_rms: Minimum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be >= this value) + param_max_rms: Maximum root-mean-square value of parameter tensor, for purposes of + learning the scale on the parameters (we'll constrain the rms of each non-scalar + parameter tensor to be <= this value) + scalar_max: Maximum absolute value for scalar parameters (applicable if your + model has any parameters with numel() == 1). + size_update_period: The periodicity, in steps, with which we update the size (scale) + of the parameter tensor. This is provided to save a little time + in the update. + clipping_update_period: if clipping_scale is specified, this is the period """ - def __init__( - self, - params, - lr=3e-02, - clipping_scale=None, - betas=(0.9, 0.98), - scalar_lr_scale=0.1, - eps=1.0e-08, - param_min_rms=1.0e-05, - param_max_rms=3.0, - scalar_max=10.0, - size_update_period=4, - clipping_update_period=100, + self, + params, + lr=3e-02, + clipping_scale=None, + betas=(0.9, 0.98), + scalar_lr_scale=0.1, + eps=1.0e-08, + param_min_rms=1.0e-05, + param_max_rms=3.0, + scalar_max=10.0, + size_update_period=4, + clipping_update_period=100, ): + defaults = dict( lr=lr, clipping_scale=clipping_scale, @@ -185,6 +183,7 @@ class ScaledAdam(BatchedOptimizer): def __setstate__(self, state): super(ScaledAdam, self).__setstate__(state) + @torch.no_grad() def step(self, closure=None): """Performs a single optimization step. @@ -207,9 +206,7 @@ class ScaledAdam(BatchedOptimizer): # a regular parameter, and will have a .grad, but the 1st dim corresponds to # a stacking dim, it is not a real dim. - if ( - len(batches[0][1]) == 0 - ): # if len(first state) == 0: not yet initialized + if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized clipping_scale = 1 else: clipping_scale = self._get_clipping_scale(group, batches) @@ -228,9 +225,13 @@ class ScaledAdam(BatchedOptimizer): self._step_one_batch(group, p, state, clipping_scale) + return loss - def _init_state(self, group: dict, p: Tensor, state: dict): + def _init_state(self, + group: dict, + p: Tensor, + state: dict): """ Initializes state dict for parameter 'p'. Assumes that dim 0 of tensor p is actually the batch dimension, corresponding to batched-together @@ -246,7 +247,7 @@ class ScaledAdam(BatchedOptimizer): state["step"] = 0 - kwargs = {"device": p.device, "dtype": p.dtype} + kwargs = {'device':p.device, 'dtype':p.dtype} # 'delta' implements conventional momentum. There are # several different kinds of update going on, so rather than @@ -254,30 +255,36 @@ class ScaledAdam(BatchedOptimizer): # parameter-change "delta", which combines all forms of # update. this is equivalent to how it's done in Adam, # except for the first few steps. - state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["delta"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) batch_size = p.shape[0] numel = p.numel() // batch_size numel = p.numel() + if numel > 1: # "param_rms" just periodically records the scalar root-mean-square value of # the parameter tensor. # it has a shape like (batch_size, 1, 1, 1, 1) - param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() + param_rms = (p**2).mean(dim=list(range(1, p.ndim)), + keepdim=True).sqrt() state["param_rms"] = param_rms state["scale_exp_avg_sq"] = torch.zeros_like(param_rms) - state["scale_grads"] = torch.zeros( - size_update_period, *param_rms.shape, **kwargs - ) + state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, + **kwargs) + # exp_avg_sq is the weighted sum of scaled gradients. as in Adam. - state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format) + state["exp_avg_sq"] = torch.zeros_like( + p, memory_format=torch.preserve_format + ) - def _get_clipping_scale( - self, group: dict, pairs: List[Tuple[Tensor, dict]] - ) -> float: + def _get_clipping_scale(self, + group: dict, + pairs: List[Tuple[Tensor, dict]]) -> float: """ Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients by this amount before applying the rest of the update. @@ -307,67 +314,57 @@ class ScaledAdam(BatchedOptimizer): if p.numel() == p.shape[0]: # a batch of scalars tot_sumsq += (grad**2).sum() # sum() to change shape [1] to [] else: - tot_sumsq += ((grad * state["param_rms"]) ** 2).sum() + tot_sumsq += ((grad * state["param_rms"])**2).sum() tot_norm = tot_sumsq.sqrt() - if "model_norms" not in first_state: - first_state["model_norms"] = torch.zeros( - clipping_update_period, device=p.device - ) + if not "model_norms" in first_state: + first_state["model_norms"] = torch.zeros(clipping_update_period, + device=p.device) first_state["model_norms"][step % clipping_update_period] = tot_norm if step % clipping_update_period == 0: # Print some stats. # We don't reach here if step == 0 because we would have returned # above. - sorted_norms = first_state["model_norms"].sort()[0].to("cpu") + sorted_norms = first_state["model_norms"].sort()[0].to('cpu') quartiles = [] for n in range(0, 5): - index = min( - clipping_update_period - 1, - (clipping_update_period // 4) * n, - ) + index = min(clipping_update_period - 1, + (clipping_update_period // 4) * n) quartiles.append(sorted_norms[index].item()) median = quartiles[2] threshold = clipping_scale * median first_state["model_norm_threshold"] = threshold - percent_clipped = ( - first_state["num_clipped"] * 100.0 / clipping_update_period - if "num_clipped" in first_state - else 0.0 - ) + percent_clipped = (first_state["num_clipped"] * 100.0 / clipping_update_period + if "num_clipped" in first_state else 0.0) first_state["num_clipped"] = 0 - quartiles = " ".join(["%.3e" % x for x in quartiles]) - logging.info( - f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " - f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}" - ) + quartiles = ' '.join([ '%.3e' % x for x in quartiles ]) + logging.info(f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, " + f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}") if step < clipping_update_period: return 1.0 # We have not yet estimated a norm to clip to. else: try: model_norm_threshold = first_state["model_norm_threshold"] - except KeyError: - logging.info( - "Warning: model_norm_threshold not in state: possibly " - "you changed config when restarting, adding clipping_scale option?" - ) + except: + logging.info("Warning: model_norm_threshold not in state: possibly " + "you changed config when restarting, adding clipping_scale option?") return 1.0 - ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item()) + ans = min(1.0,(model_norm_threshold / (tot_norm + 1.0e-20)).item()) if ans < 1.0: first_state["num_clipped"] += 1 if ans < 0.1: - logging.warn( - f"Scaling gradients by {ans}," - f" model_norm_threshold={model_norm_threshold}" - ) + logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}") return ans - def _step_one_batch( - self, group: dict, p: Tensor, state: dict, clipping_scale: float - ): + + def _step_one_batch(self, + group: dict, + p: Tensor, + state: dict, + clipping_scale: float): """ Do the step for one parameter, which is actually going to be a batch of `real` parameters, with dim 0 as the batch dim. @@ -394,18 +391,17 @@ class ScaledAdam(BatchedOptimizer): # Update the size/scale of p, and set param_rms scale_grads = state["scale_grads"] scale_grads[step % size_update_period] = (p * grad).sum( - dim=list(range(1, p.ndim)), keepdim=True - ) + dim=list(range(1, p.ndim)), keepdim=True) if step % size_update_period == size_update_period - 1: param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..) - param_rms.copy_( - (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt() - ) + param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), + keepdim=True).sqrt()) if step > 0: # self._size_update() learns the overall scale on the # parameter, by shrinking or expanding it. self._size_update(group, scale_grads, p, state) + if numel == 1: # For parameters with 1 element we just use regular Adam. # Updates delta. @@ -415,21 +411,24 @@ class ScaledAdam(BatchedOptimizer): state["step"] = step + 1 - def _size_update( - self, group: dict, scale_grads: Tensor, p: Tensor, state: dict - ) -> None: - """ - Called only where p.numel() > 1, this updates the scale of the parameter. - If we imagine: p = underlying_param * scale.exp(), and we are doing - gradient descent on underlying param and on scale, this function does the update - on `scale`. - Args: - group: dict to look up configuration values - scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing - grads w.r.t. the scales. - p: The parameter to update - state: The state-dict of p + def _size_update(self, + group: dict, + scale_grads: Tensor, + p: Tensor, + state: dict) -> None: + """ + Called only where p.numel() > 1, this updates the scale of the parameter. + If we imagine: p = underlying_param * scale.exp(), and we are doing + gradient descent on underlying param and on scale, this function does the update + on `scale`. + + Args: + group: dict to look up configuration values + scale_grads: a tensor of shape (size_update_period, batch_size, 1, 1,...) containing + grads w.r.t. the scales. + p: The parameter to update + state: The state-dict of p """ param_rms = state["param_rms"] @@ -444,28 +443,25 @@ class ScaledAdam(BatchedOptimizer): size_update_period = scale_grads.shape[0] # correct beta2 for the size update period: we will have # faster decay at this level. - beta2_corr = beta2**size_update_period + beta2_corr = beta2 ** size_update_period scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..) scale_exp_avg_sq.mul_(beta2_corr).add_( - (scale_grads**2).mean(dim=0), # mean over dim `size_update_period` - alpha=1 - beta2_corr, - ) # shape is (batch_size, 1, 1, ...) + (scale_grads ** 2).mean(dim=0), # mean over dim `size_update_period` + alpha=1-beta2_corr) # shape is (batch_size, 1, 1, ...) # The 1st time we reach here is when size_step == 1. size_step = (step + 1) // size_update_period - bias_correction2 = 1 - beta2_corr**size_step + bias_correction2 = 1 - beta2_corr ** size_step # we don't bother with bias_correction1; this will help prevent divergence # at the start of training. denom = scale_exp_avg_sq.sqrt() + eps - scale_step = ( - -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom - ) + scale_step = -size_lr * (bias_correction2 ** 0.5) * scale_grads.sum(dim=0) / denom - is_too_small = param_rms < param_min_rms - is_too_large = param_rms > param_max_rms + is_too_small = (param_rms < param_min_rms) + is_too_large = (param_rms > param_max_rms) # when the param gets too small, just don't shrink it any further. scale_step.masked_fill_(is_too_small, 0.0) @@ -473,9 +469,13 @@ class ScaledAdam(BatchedOptimizer): scale_step.masked_fill_(is_too_large, -size_lr * size_update_period) delta = state["delta"] # the factor of (1-beta1) relates to momentum. - delta.add_(p * scale_step, alpha=(1 - beta1)) + delta.add_(p * scale_step, alpha=(1-beta1)) - def _step(self, group: dict, p: Tensor, state: dict): + + def _step(self, + group: dict, + p: Tensor, + state: dict): """ This function does the core update of self.step(), in the case where the members of the batch have more than 1 element. @@ -496,7 +496,8 @@ class ScaledAdam(BatchedOptimizer): step = state["step"] exp_avg_sq = state["exp_avg_sq"] - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, + value=(1-beta2)) this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0) bias_correction2 = 1 - beta2 ** (this_step + 1) @@ -508,13 +509,17 @@ class ScaledAdam(BatchedOptimizer): denom += eps grad = grad / denom - alpha = -lr * (1 - beta1) * state["param_rms"].clamp(min=param_min_rms) + alpha = -lr * (1-beta1) * state["param_rms"].clamp(min=param_min_rms) delta = state["delta"] delta.add_(grad * alpha) p.add_(delta) - def _step_scalar(self, group: dict, p: Tensor, state: dict): + + def _step_scalar(self, + group: dict, + p: Tensor, + state: dict): """ A simplified form of the core update for scalar tensors, where we cannot get a good estimate of the parameter rms. @@ -526,7 +531,8 @@ class ScaledAdam(BatchedOptimizer): grad = p.grad exp_avg_sq = state["exp_avg_sq"] # shape: (batch_size,) - exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, + value=1-beta2) # bias_correction2 is like in Adam. Don't bother with bias_correction1; # slower update at the start will help stability anyway. @@ -534,11 +540,12 @@ class ScaledAdam(BatchedOptimizer): denom = (exp_avg_sq / bias_correction2).sqrt() + eps delta = state["delta"] - delta.add_(grad / denom, alpha=-lr * (1 - beta1)) + delta.add_(grad / denom, alpha=-lr*(1-beta1)) p.clamp_(min=-scalar_max, max=scalar_max) p.add_(delta) + class LRScheduler(object): """ Base-class for learning rate schedulers where the learning-rate depends on both the @@ -548,14 +555,18 @@ class LRScheduler(object): def __init__(self, optimizer: Optimizer, verbose: bool = False): # Attach optimizer if not isinstance(optimizer, Optimizer): - raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__)) + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) self.optimizer = optimizer self.verbose = verbose for group in optimizer.param_groups: group.setdefault("base_lr", group["lr"]) - self.base_lrs = [group["base_lr"] for group in optimizer.param_groups] + self.base_lrs = [ + group["base_lr"] for group in optimizer.param_groups + ] self.epoch = 0 self.batch = 0 @@ -669,15 +680,13 @@ class Eden(LRScheduler): def get_lr(self): factor = ( - (self.batch**2 + self.lr_batches**2) / self.lr_batches**2 + (self.batch ** 2 + self.lr_batches ** 2) / self.lr_batches ** 2 ) ** -0.25 * ( - ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25 - ) - warmup_factor = ( - 1.0 - if self.batch >= self.warmup_batches - else 0.5 + 0.5 * (self.batch / self.warmup_batches) + ((self.epoch ** 2 + self.lr_epochs ** 2) / self.lr_epochs ** 2) + ** -0.25 ) + warmup_factor = (1.0 if self.batch >= self.warmup_batches + else 0.5 + 0.5 * (self.batch / self.warmup_batches)) return [x * factor * warmup_factor for x in self.base_lrs] @@ -736,14 +745,13 @@ class Eve(Optimizer): parameters, if they fall below this we will stop applying weight decay. - .. _Adam: A Method for Stochastic Optimization: + .. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 .. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__( self, params, @@ -758,11 +766,17 @@ class Eve(Optimizer): if not 0.0 <= eps: raise ValueError("Invalid epsilon value: {}".format(eps)) if not 0.0 <= betas[0] < 1.0: - raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + raise ValueError( + "Invalid beta parameter at index 0: {}".format(betas[0]) + ) if not 0.0 <= betas[1] < 1.0: - raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + raise ValueError( + "Invalid beta parameter at index 1: {}".format(betas[1]) + ) if not 0 <= weight_decay <= 0.1: - raise ValueError("Invalid weight_decay value: {}".format(weight_decay)) + raise ValueError( + "Invalid weight_decay value: {}".format(weight_decay) + ) if not 0 < target_rms <= 10.0: raise ValueError("Invalid target_rms value: {}".format(target_rms)) defaults = dict( @@ -798,7 +812,9 @@ class Eve(Optimizer): # Perform optimization step grad = p.grad if grad.is_sparse: - raise RuntimeError("AdamW does not support sparse gradients") + raise RuntimeError( + "AdamW does not support sparse gradients" + ) state = self.state[p] @@ -825,7 +841,7 @@ class Eve(Optimizer): # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) - denom = (exp_avg_sq.sqrt() * (bias_correction2**-0.5)).add_( + denom = (exp_avg_sq.sqrt() * (bias_correction2 ** -0.5)).add_( group["eps"] ) @@ -836,31 +852,30 @@ class Eve(Optimizer): if p.numel() > 1: # avoid applying this weight-decay on "scaling factors" # (which are scalar). - is_above_target_rms = p.norm() > (target_rms * (p.numel() ** 0.5)) + is_above_target_rms = p.norm() > ( + target_rms * (p.numel() ** 0.5) + ) p.mul_(1 - (weight_decay * is_above_target_rms)) p.addcdiv_(exp_avg, denom, value=-step_size) if random.random() < 0.0005: - step = (exp_avg / denom) * step_size - logging.info( - f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}" - ) + step = (exp_avg/denom) * step_size + logging.info(f"Delta rms = {(step**2).mean().item()}, shape = {step.shape}") + return loss def _test_scaled_adam(hidden_dim: int): import timeit - from scaling import ScaledLinear - E = 100 B = 4 T = 2 logging.info("in test_eve_cain") - # device = torch.device('cuda') - device = torch.device("cpu") + #device = torch.device('cuda') + device = torch.device('cpu') dtype = torch.float32 fix_random_seed(42) @@ -874,93 +889,79 @@ def _test_scaled_adam(hidden_dim: int): fix_random_seed(42) Linear = torch.nn.Linear if iter == 0 else ScaledLinear - m = torch.nn.Sequential( - Linear(E, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, hidden_dim), - torch.nn.PReLU(), - Linear(hidden_dim, E), - ).to(device) + m = torch.nn.Sequential(Linear(E, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, hidden_dim), + torch.nn.PReLU(), + Linear(hidden_dim, E), + ).to(device) - train_pairs = [ - ( - 100.0 - * torch.randn(B, T, E, device=device, dtype=dtype) - * input_magnitudes, - torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes, - ) - for _ in range(20) - ] + train_pairs = [ (100.0 * torch.randn(B, T, E, device=device, dtype=dtype) * input_magnitudes, + torch.randn(B, T, E, device=device, dtype=dtype) * output_magnitudes) for _ in range(20) ] - if iter == 0: - optim = Eve(m.parameters(), lr=0.003) - elif iter == 1: - optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) + if iter == 0: optim = Eve(m.parameters(), lr=0.003) + elif iter == 1: optim = ScaledAdam(m.parameters(), lr=0.03, clipping_scale=2.0) scheduler = Eden(optim, lr_batches=200, lr_epochs=5, verbose=False) + start = timeit.default_timer() avg_loss = 0.0 for epoch in range(180): scheduler.step_epoch() - # if epoch == 100 and iter in [2,3]: + #if epoch == 100 and iter in [2,3]: # optim.reset_speedup() # check it doesn't crash. - # if epoch == 130: + #if epoch == 130: # opts = diagnostics.TensorDiagnosticOptions( # 2 ** 22 # ) # allow 4 megabytes per sub-module # diagnostic = diagnostics.attach_diagnostics(m, opts) - for n, (x, y) in enumerate(train_pairs): + + for n, (x,y) in enumerate(train_pairs): y_out = m(x) - loss = ((y_out - y) ** 2).mean() * 100.0 + loss = ((y_out - y)**2).mean() * 100.0 if epoch == 0 and n == 0: avg_loss = loss.item() else: avg_loss = 0.98 * avg_loss + 0.02 * loss.item() if n == 0 and epoch % 5 == 0: - # norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() - # norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() - # norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() - # norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() - # scale1 = '%.2e' % (m[0].weight_scale.exp().item()) - # scale1b = '%.2e' % (m[0].bias_scale.exp().item()) - # scale2 = '%.2e' % (m[2].weight_scale.exp().item()) - # scale2b = '%.2e' % (m[2].bias_scale.exp().item()) + #norm1 = '%.2e' % (m[0].weight**2).mean().sqrt().item() + #norm1b = '%.2e' % (m[0].bias**2).mean().sqrt().item() + #norm2 = '%.2e' % (m[2].weight**2).mean().sqrt().item() + #norm2b = '%.2e' % (m[2].bias**2).mean().sqrt().item() + #scale1 = '%.2e' % (m[0].weight_scale.exp().item()) + #scale1b = '%.2e' % (m[0].bias_scale.exp().item()) + #scale2 = '%.2e' % (m[2].weight_scale.exp().item()) + #scale2b = '%.2e' % (m[2].bias_scale.exp().item()) lr = scheduler.get_last_lr()[0] - logging.info( - f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss" - f" {avg_loss:.4g}, lr={lr:.4e}" - ) # , norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} + logging.info(f"Iter {iter}, epoch {epoch}, batch {n}, avg_loss {avg_loss:.4g}, lr={lr:.4e}") #, norms={norm1,norm1b,norm2,norm2b}") # scales={scale1,scale1b,scale2,scale2b} loss.log().backward() optim.step() optim.zero_grad() scheduler.step_batch() - # diagnostic.print_diagnostics() + #diagnostic.print_diagnostics() stop = timeit.default_timer() logging.info(f"Iter={iter}, Time taken: {stop - start}") logging.info(f"last lr = {scheduler.get_last_lr()}") - # logging.info("state dict = ", scheduler.state_dict()) - # logging.info("optim state_dict = ", optim.state_dict()) + #logging.info("state dict = ", scheduler.state_dict()) + #logging.info("optim state_dict = ", optim.state_dict()) logging.info(f"input_magnitudes = {input_magnitudes}") logging.info(f"output_magnitudes = {output_magnitudes}") + if __name__ == "__main__": torch.set_num_threads(1) torch.set_num_interop_threads(1) logging.getLogger().setLevel(logging.INFO) import subprocess - - s = subprocess.check_output( - "git status -uno .; git log -1; git diff HEAD .", shell=True - ) + s = subprocess.check_output("git status -uno .; git log -1; git diff HEAD .", shell=True) logging.info(s) import sys - if len(sys.argv) > 1: hidden_dim = int(sys.argv[1]) else: diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py index 8b4d88871..7fe1e681a 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/pretrained.py @@ -100,11 +100,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -129,12 +127,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -181,7 +177,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -212,9 +209,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -277,11 +275,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -353,7 +355,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py index 4040065e1..50cedba56 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling.py @@ -16,12 +16,12 @@ import collections -import logging -import random -from functools import reduce from itertools import repeat from typing import Optional, Tuple, Union +from functools import reduce +import logging +import random import torch import torch.nn as nn import torch.nn.functional as F @@ -32,24 +32,27 @@ from torch.nn import Embedding as ScaledEmbedding class ActivationBalancerFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - scale_factor: Tensor, - sign_factor: Optional[Tensor], - channel_dim: int, + ctx, + x: Tensor, + scale_factor: Tensor, + sign_factor: Optional[Tensor], + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = x > 0 + xgt0 = (x > 0) if sign_factor is None: ctx.save_for_backward(xgt0, scale_factor) else: ctx.save_for_backward(xgt0, scale_factor, sign_factor) return x + @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None]: if len(ctx.saved_tensors) == 3: xgt0, scale_factor, sign_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): @@ -62,22 +65,14 @@ class ActivationBalancerFunction(torch.autograd.Function): scale_factor = scale_factor.unsqueeze(-1) factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) + return x_grad - neg_delta_grad, None, None, None, - -def _compute_scale_factor( - x: Tensor, - channel_dim: int, - min_abs: float, - max_abs: float, - gain_factor: float, - max_factor: float, -) -> Tensor: +def _compute_scale_factor(x: Tensor, + channel_dim: int, + min_abs: float, + max_abs: float, + gain_factor: float, + max_factor: float) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] @@ -88,76 +83,71 @@ def _compute_scale_factor( else: # below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if # x_abs)_mean , min_abs. - below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp( - min=0, max=max_factor - ) + below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor) - above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp( - min=0, max=max_factor - ) + above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor) return below_threshold - above_threshold - -def _compute_sign_factor( - x: Tensor, - channel_dim: int, - min_positive: float, - max_positive: float, - gain_factor: float, - max_factor: float, -) -> Tensor: +def _compute_sign_factor(x: Tensor, + channel_dim: int, + min_positive: float, + max_positive: float, + gain_factor: float, + max_factor: float) -> Tensor: if channel_dim < 0: channel_dim += x.ndim sum_dims = [d for d in range(x.ndim) if d != channel_dim] - proportion_positive = torch.mean((x > 0).to(torch.float32), dim=sum_dims) + proportion_positive = torch.mean((x > 0).to(torch.float32), + dim=sum_dims) if min_positive == 0.0: factor1 = 0.0 else: # 0 if proportion_positive >= min_positive, else can be # as large as max_factor. - factor1 = ( - (min_positive - proportion_positive) * (gain_factor / min_positive) - ).clamp_(min=0, max=max_factor) + factor1 = ((min_positive - proportion_positive) * + (gain_factor / min_positive)).clamp_(min=0, max=max_factor) if max_positive == 1.0: factor2 = 0.0 else: # 0 if self.proportion_positive <= max_positive, else can be # as large as -max_factor. - factor2 = ( - (proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive)) - ).clamp_(min=0, max=max_factor) + factor2 = ((proportion_positive - max_positive) * + (gain_factor / (1.0 - max_positive))).clamp_(min=0, max=max_factor) sign_factor = factor1 - factor2 # require min_positive != 0 or max_positive != 1: assert not isinstance(sign_factor, float) return sign_factor + class ActivationScaleBalancerFunction(torch.autograd.Function): """ This object is used in class ActivationBalancer when the user specified min_positive=0, max_positive=1, so there are no constraints on the signs of the activations and only the absolute value has a constraint. """ - @staticmethod def forward( - ctx, - x: Tensor, - sign_factor: Tensor, - scale_factor: Tensor, - channel_dim: int, + ctx, + x: Tensor, + sign_factor: Tensor, + scale_factor: Tensor, + channel_dim: int, ) -> Tensor: if channel_dim < 0: channel_dim += x.ndim ctx.channel_dim = channel_dim - xgt0 = x > 0 + xgt0 = (x > 0) ctx.save_for_backward(xgt0, sign_factor, scale_factor) return x + @staticmethod - def backward(ctx, x_grad: Tensor) -> Tuple[Tensor, None, None, None]: + def backward( + ctx, x_grad: Tensor + ) -> Tuple[Tensor, None, None, None]: xgt0, sign_factor, scale_factor = ctx.saved_tensors for _ in range(ctx.channel_dim, x_grad.ndim - 1): sign_factor = sign_factor.unsqueeze(-1) @@ -165,24 +155,18 @@ class ActivationScaleBalancerFunction(torch.autograd.Function): factor = sign_factor + scale_factor * (xgt0.to(x_grad.dtype) - 0.5) neg_delta_grad = x_grad.abs() * factor - return ( - x_grad - neg_delta_grad, - None, - None, - None, - ) + return x_grad - neg_delta_grad, None, None, None, class RandomClampFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - min: Optional[float], - max: Optional[float], - prob: float, - reflect: float, - ) -> Tensor: + ctx, + x: Tensor, + min: Optional[float], + max: Optional[float], + prob: float, + reflect: float) -> Tensor: x_clamped = torch.clamp(x, min=min, max=max) mask = torch.rand_like(x) < prob ans = torch.where(mask, x_clamped, x) @@ -195,32 +179,30 @@ class RandomClampFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None, None, None, None]: - (is_same,) = ctx.saved_tensors + is_same, = ctx.saved_tensors x_grad = ans_grad * is_same.to(ans_grad.dtype) reflect = ctx.reflect - if reflect != 0.0: + if reflect != 0.0: x_grad = x_grad * (1.0 + reflect) - (ans_grad * reflect) return x_grad, None, None, None, None - -def random_clamp( - x: Tensor, - min: Optional[float] = None, - max: Optional[float] = None, - prob: float = 0.5, - reflect: float = 0.0, -): +def random_clamp(x: Tensor, + min: Optional[float] = None, + max: Optional[float] = None, + prob: float = 0.5, + reflect: float = 0.0): return RandomClampFunction.apply(x, min, max, prob, reflect) -def random_cast_to_half(x: Tensor, min_abs: float = 5.0e-06) -> Tensor: +def random_cast_to_half(x: Tensor, + min_abs: float = 5.0e-06) -> Tensor: """ A randomized way of casting a floating point value to half precision. """ if x.dtype == torch.float16: return x x_abs = x.abs() - is_too_small = x_abs < min_abs + is_too_small = (x_abs < min_abs) # for elements where is_too_small is true, random_val will contain +-min_abs with # probability (x.abs() / min_abs), and 0.0 otherwise. [so this preserves expectations, # for those elements]. @@ -233,7 +215,6 @@ class RandomGradFunction(torch.autograd.Function): Does nothing in forward pass; in backward pass, gets rid of very small grads using randomized approach that preserves expectations (intended to reduce roundoff). """ - @staticmethod def forward(ctx, x: Tensor, min_abs: float) -> Tensor: ctx.min_abs = min_abs @@ -242,37 +223,35 @@ class RandomGradFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor) -> Tuple[Tensor, None]: if ans_grad.dtype == torch.float16: - return ( - random_cast_to_half(ans_grad.to(torch.float32), min_abs=ctx.min_abs), - None, - ) + return random_cast_to_half(ans_grad.to(torch.float32), + min_abs=ctx.min_abs), None else: return ans_grad, None - class RandomGrad(torch.nn.Module): """ Gets rid of very small gradients using an expectation-preserving method, intended to increase accuracy of training when using amp (automatic mixed precision) """ - - def __init__(self, min_abs: float = 5.0e-06): + def __init__(self, + min_abs: float = 5.0e-06): super(RandomGrad, self).__init__() self.min_abs = min_abs - def forward(self, x: Tensor): + def forward(self, + x: Tensor): if torch.jit.is_scripting() or not self.training: return x else: return RandomGradFunction.apply(x, self.min_abs) + class SoftmaxFunction(torch.autograd.Function): """ Tries to handle half-precision derivatives in a randomized way that should be more accurate for training than the default behavior. """ - @staticmethod def forward(ctx, x: Tensor, dim: int): ans = x.softmax(dim=dim) @@ -288,7 +267,7 @@ class SoftmaxFunction(torch.autograd.Function): @staticmethod def backward(ctx, ans_grad: Tensor): - (ans,) = ctx.saved_tensors + ans, = ctx.saved_tensors with torch.cuda.amp.autocast(enabled=False): ans_grad = ans_grad.to(torch.float32) ans = ans.to(torch.float32) @@ -297,7 +276,9 @@ class SoftmaxFunction(torch.autograd.Function): return x_grad, None -def softmax(x: Tensor, dim: int): + +def softmax(x: Tensor, + dim: int): if torch.jit.is_scripting(): return x.softmax(dim) @@ -307,18 +288,20 @@ def softmax(x: Tensor, dim: int): class MaxEigLimiterFunction(torch.autograd.Function): @staticmethod def forward( - ctx, - x: Tensor, - coeffs: Tensor, - direction: Tensor, - channel_dim: int, - grad_scale: float, - ) -> Tensor: + ctx, + x: Tensor, + coeffs: Tensor, + direction: Tensor, + channel_dim: int, + grad_scale: float) -> Tensor: ctx.channel_dim = channel_dim ctx.grad_scale = grad_scale - ctx.save_for_backward(x.detach(), coeffs.detach(), direction.detach()) + ctx.save_for_backward(x.detach(), + coeffs.detach(), + direction.detach()) return x + @staticmethod def backward(ctx, x_grad, *args): with torch.enable_grad(): @@ -328,20 +311,15 @@ class MaxEigLimiterFunction(torch.autograd.Function): x = x_orig.transpose(ctx.channel_dim, -1).reshape(-1, num_channels) new_direction.requires_grad = False x = x - x.mean(dim=0) - x_var = (x**2).mean() + x_var = (x ** 2).mean() x_residual = x - coeffs * new_direction - x_residual_var = (x_residual**2).mean() + x_residual_var = (x_residual ** 2).mean() # `variance_proportion` is the proportion of the variance accounted for # by the top eigen-direction. This is to be minimized. variance_proportion = (x_var - x_residual_var) / (x_var + 1.0e-20) variance_proportion.backward() x_orig_grad = x_orig.grad - x_extra_grad = ( - x_orig.grad - * ctx.grad_scale - * x_grad.norm() - / (x_orig_grad.norm() + 1.0e-20) - ) + x_extra_grad = x_orig.grad * ctx.grad_scale * x_grad.norm() / (x_orig_grad.norm() + 1.0e-20) return x_grad + x_extra_grad.detach(), None, None, None, None @@ -407,12 +385,15 @@ class BasicNorm(torch.nn.Module): # region if it happens to exit it. eps = eps.clamp(min=self.eps_min, max=self.eps_max) scales = ( - torch.mean(x**2, dim=self.channel_dim, keepdim=True) + eps.exp() + torch.mean(x ** 2, dim=self.channel_dim, keepdim=True) + eps.exp() ) ** -0.5 return x * scales -def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: + +def ScaledLinear(*args, + initial_scale: float = 1.0, + **kwargs ) -> nn.Linear: """ Behaves like a constructor of a modified version of nn.Linear that gives an easy way to set the default initial parameter scale. @@ -431,11 +412,16 @@ def ScaledLinear(*args, initial_scale: float = 1.0, **kwargs) -> nn.Linear: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) return ans -def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: + +def ScaledConv1d(*args, + initial_scale: float = 1.0, + **kwargs ) -> nn.Conv1d: """ Behaves like a constructor of a modified version of nn.Conv1d that gives an easy way to set the default initial parameter scale. @@ -454,10 +440,13 @@ def ScaledConv1d(*args, initial_scale: float = 1.0, **kwargs) -> nn.Conv1d: with torch.no_grad(): ans.weight[:] *= initial_scale if ans.bias is not None: - torch.nn.init.uniform_(ans.bias, -0.1 * initial_scale, 0.1 * initial_scale) + torch.nn.init.uniform_(ans.bias, + -0.1 * initial_scale, + 0.1 * initial_scale) return ans + class ActivationBalancer(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to encourage, for @@ -497,19 +486,18 @@ class ActivationBalancer(torch.nn.Module): from doing it at the same time. Early in training we may use higher probabilities than this; it will decay to this value. """ - def __init__( - self, - num_channels: int, - channel_dim: int, - min_positive: float = 0.05, - max_positive: float = 0.95, - max_factor: float = 0.04, - sign_gain_factor: float = 0.01, - scale_gain_factor: float = 0.02, - min_abs: float = 0.2, - max_abs: float = 100.0, - min_prob: float = 0.1, + self, + num_channels: int, + channel_dim: int, + min_positive: float = 0.05, + max_positive: float = 0.95, + max_factor: float = 0.04, + sign_gain_factor: float = 0.01, + scale_gain_factor: float = 0.02, + min_abs: float = 0.2, + max_abs: float = 100.0, + min_prob: float = 0.1, ): super(ActivationBalancer, self).__init__() self.num_channels = num_channels @@ -527,7 +515,9 @@ class ActivationBalancer(torch.nn.Module): # We occasionally sync this to a tensor called `count`, that exists to # make sure it is synced to disk when we load and save the model. self.cpu_count = 0 - self.register_buffer("count", torch.tensor(0, dtype=torch.int64)) + self.register_buffer('count', torch.tensor(0, dtype=torch.int64)) + + def forward(self, x: Tensor) -> Tensor: if torch.jit.is_scripting() or not x.requires_grad: @@ -545,35 +535,26 @@ class ActivationBalancer(torch.nn.Module): # the prob of doing some work exponentially decreases from 0.5 till it hits # a floor at min_prob (==0.1, by default) - prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0))) + prob = max(self.min_prob, 0.5 ** (1 + (count/4000.0))) if random.random() < prob: sign_gain_factor = 0.5 if self.min_positive != 0.0 or self.max_positive != 1.0: - sign_factor = _compute_sign_factor( - x, - self.channel_dim, - self.min_positive, - self.max_positive, - gain_factor=self.sign_gain_factor / prob, - max_factor=self.max_factor, - ) + sign_factor = _compute_sign_factor(x, self.channel_dim, + self.min_positive, self.max_positive, + gain_factor=self.sign_gain_factor / prob, + max_factor=self.max_factor) else: sign_factor = None - scale_factor = _compute_scale_factor( - x, - self.channel_dim, - min_abs=self.min_abs, - max_abs=self.max_abs, - gain_factor=self.scale_gain_factor / prob, - max_factor=self.max_factor, - ) + + scale_factor = _compute_scale_factor(x, self.channel_dim, + min_abs=self.min_abs, + max_abs=self.max_abs, + gain_factor=self.scale_gain_factor / prob, + max_factor=self.max_factor) return ActivationBalancerFunction.apply( - x, - scale_factor, - sign_factor, - self.channel_dim, + x, scale_factor, sign_factor, self.channel_dim, ) else: return _no_op(x) @@ -613,12 +594,13 @@ def _diag(x: Tensor): # like .diag(), but works for tensors with 3 dims. else: (batch, dim, dim) = x.shape x = x.reshape(batch, dim * dim) - x = x[:, :: dim + 1] + x = x[:, ::dim+1] assert x.shape == (batch, dim) return x -def _whitening_metric(x: Tensor, num_groups: int): +def _whitening_metric(x: Tensor, + num_groups: int): """ Computes the "whitening metric", a value which will be 1.0 if all the eigenvalues of of the centered feature covariance are the same within each group's covariance matrix @@ -648,21 +630,19 @@ def _whitening_metric(x: Tensor, num_groups: int): # the following expression is what we'd get if we took the matrix product # of each covariance and measured the mean of its trace, i.e. # the same as _diag(torch.matmul(x_covar, x_covar)).mean(). - x_covarsq_mean_diag = (x_covar**2).sum() / (num_groups * channels_per_group) + x_covarsq_mean_diag = (x_covar ** 2).sum() / (num_groups * channels_per_group) # this metric will be >= 1.0; the larger it is, the less 'white' the data was. - metric = x_covarsq_mean_diag / (x_covar_mean_diag**2 + 1.0e-20) + metric = x_covarsq_mean_diag / (x_covar_mean_diag ** 2 + 1.0e-20) return metric class WhiteningPenaltyFunction(torch.autograd.Function): @staticmethod - def forward( - ctx, - x: Tensor, - num_groups: int, - whitening_limit: float, - grad_scale: float, - ) -> Tensor: + def forward(ctx, + x: Tensor, + num_groups: int, + whitening_limit: float, + grad_scale: float) -> Tensor: ctx.save_for_backward(x) ctx.num_groups = num_groups ctx.whitening_limit = whitening_limit @@ -670,8 +650,9 @@ class WhiteningPenaltyFunction(torch.autograd.Function): return x @staticmethod - def backward(ctx, x_grad: Tensor): - (x_orig,) = ctx.saved_tensors + def backward(ctx, + x_grad: Tensor): + x_orig, = ctx.saved_tensors with torch.enable_grad(): with torch.cuda.amp.autocast(enabled=False): x_detached = x_orig.to(torch.float32).detach() @@ -680,29 +661,25 @@ class WhiteningPenaltyFunction(torch.autograd.Function): metric = _whitening_metric(x_detached, ctx.num_groups) if random.random() < 0.005 or __name__ == "__main__": - logging.info( - f"Whitening: num_groups={ctx.num_groups}," - f" num_channels={x_orig.shape[-1]}," - f" metric={metric.item():.2f} vs. limit={ctx.whitening_limit}" - ) + logging.info(f"Whitening: num_groups={ctx.num_groups}, num_channels={x_orig.shape[-1]}, " + f"metric={metric.item():.2f} vs. limit={ctx.whitening_limit}") (metric - ctx.whitening_limit).relu().backward() penalty_grad = x_detached.grad - scale = ctx.grad_scale * ( - x_grad.to(torch.float32).norm() / (penalty_grad.norm() + 1.0e-20) - ) + scale = ctx.grad_scale * (x_grad.to(torch.float32).norm() / + (penalty_grad.norm() + 1.0e-20)) penalty_grad = penalty_grad * scale return x_grad + penalty_grad.to(x_grad.dtype), None, None, None + class Whiten(nn.Module): def __init__( - self, - num_groups: int, - whitening_limit: float, - prob: Union[float, Tuple[float, float]], - grad_scale: float, - ): + self, + num_groups: int, + whitening_limit: float, + prob: Union[float, Tuple[float,float]], + grad_scale: float): """ Args: num_groups: the number of groups to divide the channel dim into before @@ -737,7 +714,8 @@ class Whiten(nn.Module): self.grad_scale = grad_scale - def forward(self, x: Tensor) -> Tensor: + def forward(self, + x: Tensor) -> Tensor: """ In the forward pass, this function just returns the input unmodified. In the backward pass, it will modify the gradients to ensure that the @@ -757,21 +735,19 @@ class Whiten(nn.Module): if not x.requires_grad or random.random() > self.prob or self.grad_scale == 0: return _no_op(x) else: - if hasattr(self, "min_prob") and random.random() < 0.25: + if hasattr(self, 'min_prob') and random.random() < 0.25: # occasionally switch between min_prob and max_prob, based on whether # we are above or below the threshold. - if ( - _whitening_metric(x.to(torch.float32), self.num_groups) - > self.whitening_limit - ): + if _whitening_metric(x.to(torch.float32), self.num_groups) > self.whitening_limit: # there would be a change to the grad. self.prob = self.max_prob else: self.prob = self.min_prob - return WhiteningPenaltyFunction.apply( - x, self.num_groups, self.whitening_limit, self.grad_scale - ) + return WhiteningPenaltyFunction.apply(x, + self.num_groups, + self.whitening_limit, + self.grad_scale) class WithLoss(torch.autograd.Function): @@ -779,14 +755,11 @@ class WithLoss(torch.autograd.Function): def forward(ctx, x: Tensor, y: Tensor): ctx.y_shape = y.shape return x - @staticmethod def backward(ctx, ans_grad: Tensor): - return ans_grad, torch.ones( - ctx.y_shape, dtype=ans_grad.dtype, device=ans_grad.device - ) - - + return ans_grad, torch.ones(ctx.y_shape, + dtype=ans_grad.dtype, + device=ans_grad.device) def with_loss(x, y): if torch.jit.is_scripting(): return x @@ -795,7 +768,7 @@ def with_loss(x, y): def _no_op(x: Tensor) -> Tensor: - if torch.jit.is_scripting(): + if (torch.jit.is_scripting()): return x else: # a no-op function that will have a node in the autograd graph, @@ -810,7 +783,6 @@ class Identity(torch.nn.Module): def forward(self, x): return _no_op(x) - class MaxEig(torch.nn.Module): """ Modifies the backpropped derivatives of a function to try to discourage @@ -831,14 +803,13 @@ class MaxEig(torch.nn.Module): scale: determines the scale with which we modify the gradients, relative to the existing / unmodified gradients """ - def __init__( - self, - num_channels: int, - channel_dim: int, - max_var_per_eig: float = 0.2, - min_prob: float = 0.01, - scale: float = 0.01, + self, + num_channels: int, + channel_dim: int, + max_var_per_eig: float = 0.2, + min_prob: float = 0.01, + scale: float = 0.01, ): super(MaxEig, self).__init__() self.num_channels = num_channels @@ -854,7 +825,7 @@ class MaxEig(torch.nn.Module): # random parameters unchanged for comparison direction = torch.arange(num_channels).to(torch.float) direction = direction / direction.norm() - self.register_buffer("max_eig_direction", direction) + self.register_buffer('max_eig_direction', direction) self.min_prob = min_prob # cur_prob is the current probability we'll use to apply the ActivationBalancer. @@ -862,12 +833,12 @@ class MaxEig(torch.nn.Module): # active. self.cur_prob = 1.0 + + def forward(self, x: Tensor) -> Tensor: - if ( - torch.jit.is_scripting() - or self.max_var_per_eig <= 0 - or random.random() > self.cur_prob - ): + if (torch.jit.is_scripting() or + self.max_var_per_eig <= 0 or + random.random() > self.cur_prob): return _no_op(x) with torch.cuda.amp.autocast(enabled=False): @@ -877,9 +848,7 @@ class MaxEig(torch.nn.Module): with torch.no_grad(): x = x.transpose(self.channel_dim, -1).reshape(-1, self.num_channels) x = x - x.mean(dim=0) - new_direction, coeffs = self._find_direction_coeffs( - x, self.max_eig_direction - ) + new_direction, coeffs = self._find_direction_coeffs(x, self.max_eig_direction) x_var = (x**2).mean() x_residual = x - coeffs * new_direction x_residual_var = (x_residual**2).mean() @@ -892,10 +861,7 @@ class MaxEig(torch.nn.Module): self._set_direction(0.1 * self.max_eig_direction + new_direction) if random.random() < 0.01 or __name__ == "__main__": - logging.info( - f"variance_proportion = {variance_proportion.item()}," - f" shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}" - ) + logging.info(f"variance_proportion = {variance_proportion.item()}, shape={tuple(orig_x.shape)}, cur_prob={self.cur_prob}") if variance_proportion >= self.max_var_per_eig: # The constraint is active. Note, we should quite rarely @@ -903,16 +869,17 @@ class MaxEig(torch.nn.Module): # starting to diverge, should this constraint be active. cur_prob = self.cur_prob self.cur_prob = 1.0 # next time, do the update with probability 1.0. - return MaxEigLimiterFunction.apply( - orig_x, coeffs, new_direction, self.channel_dim, self.scale - ) + return MaxEigLimiterFunction.apply(orig_x, coeffs, new_direction, + self.channel_dim, self.scale) else: # let self.cur_prob exponentially approach self.min_prob, as # long as the constraint is inactive. self.cur_prob = 0.75 * self.cur_prob + 0.25 * self.min_prob return orig_x - def _set_direction(self, direction: Tensor): + + def _set_direction(self, + direction: Tensor): """ Sets self.max_eig_direction to a normalized version of `direction` """ @@ -922,39 +889,40 @@ class MaxEig(torch.nn.Module): if direction_sum - direction_sum == 0: # no inf/nan self.max_eig_direction[:] = direction else: - logging.info( - f"Warning: sum of direction in MaxEig is {direction_sum}, " - "num_channels={self.num_channels}, channel_dim={self.channel_dim}" - ) + logging.info(f"Warning: sum of direction in MaxEig is {direction_sum}, " + "num_channels={self.num_channels}, channel_dim={self.channel_dim}") - def _find_direction_coeffs( - self, x: Tensor, prev_direction: Tensor - ) -> Tuple[Tensor, Tensor, Tensor]: - """ - Figure out (an approximation to) the proportion of the variance of a set of - feature vectors that can be attributed to the top eigen-direction. - Args: - x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. - prev_direction: a Tensor of shape (num_channels,), that is our previous estimate - of the top eigen-direction, or a random direction if this is the first - iteration. Does not have to be normalized, but should be nonzero. - Returns: (cur_direction, coeffs), where: - cur_direction: a Tensor of shape (num_channels,) that is the current - estimate of the top eigen-direction. - coeffs: a Tensor of shape (num_frames, 1) that minimizes, or - approximately minimizes, (x - coeffs * cur_direction).norm() + def _find_direction_coeffs(self, + x: Tensor, + prev_direction: Tensor) -> Tuple[Tensor, Tensor, Tensor]: """ + Figure out (an approximation to) the proportion of the variance of a set of + feature vectors that can be attributed to the top eigen-direction. + Args: + x: a Tensor of shape (num_frames, num_channels), with num_frames > 1. + prev_direction: a Tensor of shape (num_channels,), that is our previous estimate + of the top eigen-direction, or a random direction if this is the first + iteration. Does not have to be normalized, but should be nonzero. + + Returns: (cur_direction, coeffs), where: + cur_direction: a Tensor of shape (num_channels,) that is the current + estimate of the top eigen-direction. + coeffs: a Tensor of shape (num_frames, 1) that minimizes, or + approximately minimizes, (x - coeffs * cur_direction).norm() + """ (num_frames, num_channels) = x.shape assert num_channels > 1 and num_frames > 1 assert prev_direction.shape == (num_channels,) # `coeffs` are the coefficients of `prev_direction` in x. # actually represent the coeffs up to a constant positive factor. coeffs = (x * prev_direction).sum(dim=1, keepdim=True) + 1.0e-10 - cur_direction = (x * coeffs).sum(dim=0) / ((coeffs**2).sum() + 1.0e-20) + cur_direction = (x * coeffs).sum(dim=0) / ((coeffs ** 2).sum() + 1.0e-20) return cur_direction, coeffs + + class DoubleSwishFunction(torch.autograd.Function): """ double_swish(x) = x * torch.sigmoid(x-1) @@ -982,7 +950,7 @@ class DoubleSwishFunction(torch.autograd.Function): y = x * s if requires_grad: - deriv = y * (1 - s) + s + deriv = (y * (1 - s) + s) # notes on derivative of x * sigmoid(x - 1): # https://www.wolframalpha.com/input?i=d%2Fdx+%28x+*+sigmoid%28x-1%29%29 # min \simeq -0.043638. Take floor as -0.043637 so it's a lower bund @@ -991,9 +959,7 @@ class DoubleSwishFunction(torch.autograd.Function): # floors), should be expectation-preserving. floor = -0.043637 ceil = 1.2 - d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like( - deriv - ) + d_scaled = ((deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)) if __name__ == "__main__": # for self-testing only. assert d_scaled.min() >= 0.0 @@ -1006,12 +972,12 @@ class DoubleSwishFunction(torch.autograd.Function): @staticmethod def backward(ctx, y_grad: Tensor) -> Tensor: - (d,) = ctx.saved_tensors + d, = ctx.saved_tensors # the same constants as used in forward pass. floor = -0.043637 ceil = 1.2 - d = d * ((ceil - floor) / 255.0) + floor - return y_grad * d + d = (d * ((ceil - floor) / 255.0) + floor) + return (y_grad * d) class DoubleSwish(torch.nn.Module): @@ -1024,6 +990,7 @@ class DoubleSwish(torch.nn.Module): return DoubleSwishFunction.apply(x) + def _test_max_eig(): for proportion in [0.1, 0.5, 10.0]: logging.info(f"proportion = {proportion}") @@ -1035,9 +1002,11 @@ def _test_max_eig(): x.requires_grad = True num_channels = 128 - m = MaxEig( - num_channels, 1, 0.5, scale=0.1 # channel_dim # max_var_per_eig - ) # grad_scale + m = MaxEig(num_channels, + 1, # channel_dim + 0.5, # max_var_per_eig + scale=0.1) # grad_scale + for _ in range(4): y = m(x) @@ -1062,9 +1031,11 @@ def _test_whiten(): x.requires_grad = True num_channels = 128 - m = Whiten( - 1, 5.0, prob=1.0, grad_scale=0.1 # num_groups # whitening_limit, - ) # grad_scale + m = Whiten(1, # num_groups + 5.0, # whitening_limit, + prob=1.0, + grad_scale=0.1) # grad_scale + for _ in range(4): y = m(x) @@ -1078,6 +1049,7 @@ def _test_whiten(): assert not torch.allclose(x.grad, y_grad) + def _test_activation_balancer_sign(): probs = torch.arange(0, 1, 0.01) N = 1000 @@ -1105,7 +1077,9 @@ def _test_activation_balancer_sign(): def _test_activation_balancer_magnitude(): magnitudes = torch.arange(0, 1, 0.01) N = 1000 - x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze(-1) + x = torch.sign(torch.randn(magnitudes.numel(), N)) * magnitudes.unsqueeze( + -1 + ) x = x.detach() x.requires_grad = True m = ActivationBalancer( @@ -1137,8 +1111,8 @@ def _test_basic_norm(): y = m(x) assert y.shape == x.shape - x_rms = (x**2).mean().sqrt() - y_rms = (y**2).mean().sqrt() + x_rms = (x ** 2).mean().sqrt() + y_rms = (y ** 2).mean().sqrt() print("x rms = ", x_rms) print("y rms = ", y_rms) assert y_rms < x_rms @@ -1150,27 +1124,30 @@ def _test_double_swish_deriv(): x.requires_grad = True m = DoubleSwish() - tol = (1.2 - (-0.043637)) / 255.0 + tol = ((1.2-(-0.043637))/255.0) torch.autograd.gradcheck(m, x, atol=tol) + # for self-test. x = torch.randn(1000, 1000, dtype=torch.double) * 3.0 x.requires_grad = True y = m(x) + def _test_softmax(): a = torch.randn(2, 10, dtype=torch.float64) b = a.clone() a.requires_grad = True b.requires_grad = True - a.softmax(dim=1)[:, 0].sum().backward() + a.softmax(dim=1)[:,0].sum().backward() print("a grad = ", a.grad) - softmax(b, dim=1)[:, 0].sum().backward() + softmax(b, dim=1)[:,0].sum().backward() print("b grad = ", b.grad) assert torch.allclose(a.grad, b.grad) + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py index 46e775285..8d357b15f 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/scaling_converter.py @@ -26,7 +26,11 @@ from typing import List import torch import torch.nn as nn -from scaling import ActivationBalancer, BasicNorm, Whiten +from scaling import ( + ActivationBalancer, + BasicNorm, + Whiten, +) class NonScaledNorm(nn.Module): @@ -71,10 +75,12 @@ def get_submodule(model, target): mod: torch.nn.Module = model for item in atoms: if not hasattr(mod, item): - raise AttributeError(mod._get_name() + " has no attribute `" + item + "`") + raise AttributeError( + mod._get_name() + " has no " "attribute `" + item + "`" + ) mod = getattr(mod, item) if not isinstance(mod, torch.nn.Module): - raise AttributeError("`" + item + "` is not an nn.Module") + raise AttributeError("`" + item + "` is not " "an nn.Module") return mod diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py index 7f9526104..3f27736b3 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/train.py @@ -84,7 +84,9 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -122,10 +124,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help=( - "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" - " separated" - ), + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", ) parser.add_argument( @@ -140,11 +139,9 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help=( - "Unmasked dimensions in the encoders, relates to augmentation during" - " training. Must be <= each of encoder_dims. Empirically, less than 256" - " seems to make performance worse." - ), + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", ) parser.add_argument( @@ -272,45 +269,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -652,7 +646,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -699,7 +697,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -870,7 +870,9 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -888,7 +890,11 @@ def train_one_epoch( f"batch {batch_idx}, loss[{loss_info}], " f"tot_loss[{tot_loss}], batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) ) if tb_writer is not None: @@ -899,7 +905,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if params.use_fp16: tb_writer.add_scalar( "train/grad_scale", @@ -907,7 +915,10 @@ def train_one_epoch( params.batch_idx_train, ) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -919,8 +930,7 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -999,7 +1009,9 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1017,7 +1029,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1042,7 +1054,8 @@ def run(rank, world_size, args): # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False @@ -1216,8 +1229,7 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py index fcd9858cd..023dec97d 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/zipformer.py @@ -16,35 +16,32 @@ # limitations under the License. import copy -import itertools -import logging import math -import random import warnings +import itertools from typing import List, Optional, Tuple, Union - +import logging import torch +import random from encoder_interface import EncoderInterface -from scaling import ( - ScaledLinear, # not as in other dirs.. just scales down initial parameter values. -) from scaling import ( ActivationBalancer, BasicNorm, - DoubleSwish, - Identity, MaxEig, + DoubleSwish, ScaledConv1d, + ScaledLinear, # not as in other dirs.. just scales down initial parameter values. Whiten, + Identity, _diag, - penalize_abs_values_gt, random_clamp, + penalize_abs_values_gt, softmax, ) from torch import Tensor, nn -from icefall.dist import get_rank from icefall.utils import make_pad_mask +from icefall.dist import get_rank class Zipformer(EncoderInterface): @@ -92,7 +89,7 @@ class Zipformer(EncoderInterface): self.batch_count = 0 self.warmup_end = warmup_batches - for u, d in zip(encoder_unmasked_dims, encoder_dims): + for u,d in zip(encoder_unmasked_dims, encoder_dims): assert u <= d, (u, d) # self.encoder_embed converts the input of shape (N, T, num_features) @@ -100,9 +97,9 @@ class Zipformer(EncoderInterface): # That is, it does two things simultaneously: # (1) subsampling: T -> (T - 7)//2 # (2) embedding: num_features -> encoder_dims - self.encoder_embed = Conv2dSubsampling( - num_features, encoder_dims[0], dropout=dropout - ) + self.encoder_embed = Conv2dSubsampling(num_features, encoder_dims[0], + dropout=dropout) + # each one will be ZipformerEncoder or DownsampledZipformerEncoder encoders = [] @@ -126,13 +123,13 @@ class Zipformer(EncoderInterface): num_encoder_layers[i], dropout, warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1), - warmup_end=warmup_batches * (i + 2) / (num_encoders + 1), + warmup_end=warmup_batches * (i + 2) / (num_encoders + 1) ) if zipformer_downsampling_factors[i] != 1: encoder = DownsampledZipformerEncoder( encoder, - input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0], + input_dim=encoder_dims[i-1] if i > 0 else encoder_dims[0], output_dim=encoder_dims[i], downsample=zipformer_downsampling_factors[i], ) @@ -142,11 +139,10 @@ class Zipformer(EncoderInterface): # initializes self.skip_layers and self.skip_modules self._init_skip_modules() - self.downsample_output = AttentionDownsample( - encoder_dims[-1], - encoder_dims[-1], - downsample=output_downsampling_factor, - ) + self.downsample_output = AttentionDownsample(encoder_dims[-1], + encoder_dims[-1], + downsample=output_downsampling_factor) + def _get_layer_skip_dropout_prob(self): if not self.training: @@ -170,33 +166,27 @@ class Zipformer(EncoderInterface): skip_modules = [] z = self.zipformer_downsampling_factors for i in range(len(z)): - if i <= 1 or z[i - 1] <= z[i]: + if i <= 1 or z[i-1] <= z[i]: skip_layers.append(None) skip_modules.append(SimpleCombinerIdentity()) else: # TEMP - for j in range(i - 2, -1, -1): + for j in range(i-2, -1, -1): if z[j] <= z[i] or j == 0: # TEMP logging statement. - logging.info( - f"At encoder stack {i}, which has" - f" downsampling_factor={z[i]}, we will combine the outputs" - f" of layers {j} and {i-1}, with" - f" downsampling_factors={z[j]} and {z[i-1]}." - ) + logging.info(f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will " + f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}.") skip_layers.append(j) - skip_modules.append( - SimpleCombiner( - self.encoder_dims[j], - self.encoder_dims[i - 1], - min_weight=(0.0, 0.25), - ) - ) + skip_modules.append(SimpleCombiner(self.encoder_dims[j], + self.encoder_dims[i-1], + min_weight=(0.0,0.25))) break self.skip_layers = skip_layers self.skip_modules = nn.ModuleList(skip_modules) - def get_feature_masks(self, x: torch.Tensor) -> List[float]: + def get_feature_masks( + self, + x: torch.Tensor) -> List[float]: # Note: The actual return type is Union[List[float], List[Tensor]], # but to make torch.jit.script() work, we use List[float] """ @@ -216,56 +206,46 @@ class Zipformer(EncoderInterface): """ num_encoders = len(self.encoder_dims) if torch.jit.is_scripting() or not self.training: - return [1.0] * num_encoders + return [ 1.0 ] * num_encoders (num_frames0, batch_size, _encoder_dims0) = x.shape - assert self.encoder_dims[0] == _encoder_dims0, ( - self.encoder_dims, - _encoder_dims0, - ) + + assert self.encoder_dims[0] == _encoder_dims0, (self.encoder_dims, _encoder_dims0) max_downsampling_factor = max(self.zipformer_downsampling_factors) - num_frames_max = num_frames0 + max_downsampling_factor - 1 + num_frames_max = (num_frames0 + max_downsampling_factor - 1) + feature_mask_dropout_prob = 0.15 # frame_mask_max shape: (num_frames_max, batch_size, 1) - frame_mask_max = ( - torch.rand(num_frames_max, batch_size, 1, device=x.device) - > feature_mask_dropout_prob - ).to(x.dtype) + frame_mask_max = (torch.rand(num_frames_max, batch_size, 1, + device=x.device) > + feature_mask_dropout_prob).to(x.dtype) feature_masks = [] for i in range(num_encoders): ds = self.zipformer_downsampling_factors[i] - upsample_factor = max_downsampling_factor // ds + upsample_factor = (max_downsampling_factor // ds) - frame_mask = ( - frame_mask_max.unsqueeze(1) - .expand(num_frames_max, upsample_factor, batch_size, 1) - .reshape(num_frames_max * upsample_factor, batch_size, 1) - ) + frame_mask = (frame_mask_max.unsqueeze(1).expand(num_frames_max, upsample_factor, + batch_size, 1) + .reshape(num_frames_max * upsample_factor, batch_size, 1)) num_frames = (num_frames0 + ds - 1) // ds frame_mask = frame_mask[:num_frames] - feature_mask = torch.ones( - num_frames, - batch_size, - self.encoder_dims[i], - dtype=x.dtype, - device=x.device, - ) + feature_mask = torch.ones(num_frames, batch_size, self.encoder_dims[i], + dtype=x.dtype, device=x.device) u = self.encoder_unmasked_dims[i] feature_mask[:, :, u:] *= frame_mask feature_masks.append(feature_mask) return feature_masks + def forward( - self, - x: torch.Tensor, - x_lens: torch.Tensor, + self, x: torch.Tensor, x_lens: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Args: @@ -285,19 +265,13 @@ class Zipformer(EncoderInterface): x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) lengths = (x_lens - 7) >> 1 - assert x.size(0) == lengths.max().item(), ( - x.shape, - lengths, - lengths.max(), - ) + assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max()) mask = make_pad_mask(lengths) outputs = [] feature_masks = self.get_feature_masks(x) - for i, (module, skip_module) in enumerate( - zip(self.encoders, self.skip_modules) - ): + for i, (module, skip_module) in enumerate(zip(self.encoders, self.skip_modules)): ds = self.zipformer_downsampling_factors[i] k = self.skip_layers[i] if isinstance(k, int): @@ -306,11 +280,9 @@ class Zipformer(EncoderInterface): x = skip_module(outputs[k], x) elif (not self.training) or random.random() > layer_skip_dropout_prob: x = skip_module(outputs[k], x) - x = module( - x, - feature_mask=feature_masks[i], - src_key_padding_mask=None if mask is None else mask[..., ::ds], - ) + x = module(x, + feature_mask=feature_masks[i], + src_key_padding_mask=None if mask is None else mask[...,::ds]) outputs.append(x) x = self.downsample_output(x) @@ -340,16 +312,15 @@ class ZipformerEncoderLayer(nn.Module): >>> pos_emb = torch.rand(32, 19, 512) >>> out = encoder_layer(src, pos_emb) """ - def __init__( - self, - d_model: int, - attention_dim: int, - nhead: int, - feedforward_dim: int = 2048, - dropout: float = 0.1, - cnn_module_kernel: int = 31, - pos_dim: int = 4, + self, + d_model: int, + attention_dim: int, + nhead: int, + feedforward_dim: int = 2048, + dropout: float = 0.1, + cnn_module_kernel: int = 31, + pos_dim: int = 4, ) -> None: super(ZipformerEncoderLayer, self).__init__() @@ -359,24 +330,29 @@ class ZipformerEncoderLayer(nn.Module): self.batch_count = 0 self.self_attn = RelPositionMultiheadAttention( - d_model, - attention_dim, - nhead, - pos_dim, - dropout=0.0, + d_model, attention_dim, nhead, pos_dim, dropout=0.0, ) self.pooling = PoolingModule(d_model) - self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.feed_forward1 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.feed_forward2 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout) + self.feed_forward3 = FeedforwardModule(d_model, + feedforward_dim, + dropout) - self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel) - self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel) + self.conv_module1 = ConvolutionModule(d_model, + cnn_module_kernel) + + self.conv_module2 = ConvolutionModule(d_model, + cnn_module_kernel) self.norm_final = BasicNorm(d_model) @@ -384,18 +360,14 @@ class ZipformerEncoderLayer(nn.Module): # try to ensure the output is close to zero-mean (or at least, zero-median). self.balancer = ActivationBalancer( - d_model, - channel_dim=-1, - min_positive=0.45, - max_positive=0.55, + d_model, channel_dim=-1, + min_positive=0.45, max_positive=0.55, max_abs=6.0, ) - self.whiten = Whiten( - num_groups=1, - whitening_limit=5.0, - prob=(0.025, 0.25), - grad_scale=0.01, - ) + self.whiten = Whiten(num_groups=1, + whitening_limit=5.0, + prob=(0.025, 0.25), + grad_scale=0.01) def get_bypass_scale(self): if torch.jit.is_scripting() or not self.training: @@ -410,9 +382,8 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: clamp_min = final_clamp_min else: - clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * ( - initial_clamp_min - final_clamp_min - ) + clamp_min = (initial_clamp_min - + (self.batch_count / warmup_period) * (initial_clamp_min - final_clamp_min)) return self.bypass_scale.clamp(min=clamp_min, max=1.0) def get_dynamic_dropout_rate(self): @@ -427,9 +398,8 @@ class ZipformerEncoderLayer(nn.Module): if self.batch_count > warmup_period: return final_dropout_rate else: - return initial_dropout_rate - ( - initial_dropout_rate * final_dropout_rate - ) * (self.batch_count / warmup_period) + return (initial_dropout_rate - + (initial_dropout_rate * final_dropout_rate) * (self.batch_count / warmup_period)) def forward( self, @@ -538,14 +508,13 @@ class ZipformerEncoder(nn.Module): >>> src = torch.rand(10, 32, 512) >>> out = zipformer_encoder(src) """ - def __init__( - self, - encoder_layer: nn.Module, - num_layers: int, - dropout: float, - warmup_begin: float, - warmup_end: float, + self, + encoder_layer: nn.Module, + num_layers: int, + dropout: float, + warmup_begin: float, + warmup_end: float ) -> None: super().__init__() # will be written to, see set_batch_count() Note: in inference time this @@ -559,7 +528,8 @@ class ZipformerEncoder(nn.Module): # so that we can keep this consistent across worker tasks (for efficiency). self.module_seed = torch.randint(0, 1000, ()).item() - self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout) + self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, + dropout) self.layers = nn.ModuleList( [copy.deepcopy(encoder_layer) for i in range(num_layers)] @@ -568,13 +538,15 @@ class ZipformerEncoder(nn.Module): assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end) - delta = (1.0 / num_layers) * (warmup_end - warmup_begin) + + delta = (1. / num_layers) * (warmup_end - warmup_begin) cur_begin = warmup_begin for i in range(num_layers): self.layers[i].warmup_begin = cur_begin cur_begin += delta self.layers[i].warmup_end = cur_begin + def get_layers_to_drop(self, rnd_seed: int): ans = set() if not self.training: @@ -607,14 +579,12 @@ class ZipformerEncoder(nn.Module): # linearly interpolate t = (batch_count - layer_warmup_begin) / layer_warmup_end assert 0.0 <= t < 1.001, t - return initial_layerdrop_prob + t * ( - final_layerdrop_prob - initial_layerdrop_prob - ) + return initial_layerdrop_prob + t * (final_layerdrop_prob - initial_layerdrop_prob) shared_rng = random.Random(batch_count + self.module_seed) independent_rng = random.Random(rnd_seed) - layerdrop_probs = [get_layerdrop_prob(i) for i in range(num_layers)] + layerdrop_probs = [ get_layerdrop_prob(i) for i in range(num_layers) ] tot = sum(layerdrop_probs) # Instead of drawing the samples independently, we first randomly decide # how many layers to drop out, using the same random number generator between @@ -634,13 +604,11 @@ class ZipformerEncoder(nn.Module): if len(ans) == num_to_drop: break if shared_rng.random() < 0.005 or __name__ == "__main__": - logging.info( - f"warmup_begin={self.warmup_begin:.1f}," - f" warmup_end={self.warmup_end:.1f}, batch_count={batch_count:.1f}," - f" num_to_drop={num_to_drop}, layers_to_drop={ans}" - ) + logging.info(f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, " + f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}") return ans + def forward( self, src: Tensor, @@ -671,6 +639,7 @@ class ZipformerEncoder(nn.Module): pos_emb = self.encoder_pos(src) output = src + if torch.jit.is_scripting(): layers_to_drop = [] else: @@ -701,31 +670,28 @@ class DownsampledZipformerEncoder(nn.Module): after convolutional downsampling, and then upsampled again at the output, and combined with the origin input, so that the output has the same shape as the input. """ - - def __init__( - self, - encoder: nn.Module, - input_dim: int, - output_dim: int, - downsample: int, - ): + def __init__(self, + encoder: nn.Module, + input_dim: int, + output_dim: int, + downsample: int): super(DownsampledZipformerEncoder, self).__init__() self.downsample_factor = downsample self.downsample = AttentionDownsample(input_dim, output_dim, downsample) self.encoder = encoder self.upsample = SimpleUpsample(output_dim, downsample) - self.out_combiner = SimpleCombiner( - input_dim, output_dim, min_weight=(0.0, 0.25) - ) + self.out_combiner = SimpleCombiner(input_dim, + output_dim, + min_weight=(0.0, 0.25)) - def forward( - self, - src: Tensor, - # Note: the type of feature_mask should be Unino[float, Tensor], - # but to make torch.jit.script() happ, we use float here - feature_mask: float = 1.0, - mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, + + def forward(self, + src: Tensor, + # Note: the type of feature_mask should be Unino[float, Tensor], + # but to make torch.jit.script() happ, we use float here + feature_mask: float = 1.0, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: r"""Downsample, go through encoder, upsample. @@ -752,43 +718,42 @@ class DownsampledZipformerEncoder(nn.Module): src = self.downsample(src) ds = self.downsample_factor if mask is not None: - mask = mask[::ds, ::ds] + mask = mask[::ds,::ds] src = self.encoder( - src, - feature_mask=feature_mask, - mask=mask, - src_key_padding_mask=mask, + src, feature_mask=feature_mask, mask=mask, src_key_padding_mask=mask, ) src = self.upsample(src) # remove any extra frames that are not a multiple of downsample_factor - src = src[: src_orig.shape[0]] + src = src[:src_orig.shape[0]] return self.out_combiner(src_orig, src) - class AttentionDownsample(torch.nn.Module): """ Does downsampling with attention, by weighted sum, and a projection.. """ - - def __init__(self, in_channels: int, out_channels: int, downsample: int): + def __init__(self, + in_channels: int, + out_channels: int, + downsample: int): """ Require out_channels > in_channels. """ super(AttentionDownsample, self).__init__() - self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5)) + self.query = nn.Parameter(torch.randn(in_channels) * (in_channels ** -0.5)) # fill in the extra dimensions with a projection of the input if out_channels > in_channels: - self.extra_proj = nn.Linear( - in_channels * downsample, out_channels - in_channels, bias=False - ) + self.extra_proj = nn.Linear(in_channels * downsample, + out_channels - in_channels, + bias=False) else: self.extra_proj = None self.downsample = downsample - def forward(self, src: Tensor) -> Tensor: + def forward(self, + src: Tensor) -> Tensor: """ x: (seq_len, batch_size, in_channels) Returns a tensor of shape @@ -802,14 +767,16 @@ class AttentionDownsample(torch.nn.Module): if seq_len != d_seq_len * ds: # right-pad src, repeating the last element. pad = d_seq_len * ds - seq_len - src_extra = src[src.shape[0] - 1 :].expand(pad, src.shape[1], src.shape[2]) + src_extra = src[src.shape[0]-1:].expand(pad, src.shape[1], src.shape[2]) src = torch.cat((src, src_extra), dim=0) assert src.shape[0] == d_seq_len * ds, (src.shape[0], d_seq_len, ds) src = src.reshape(d_seq_len, ds, batch_size, in_channels) scores = (src * self.query).sum(dim=-1, keepdim=True) - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, + limit=10.0, + penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -828,12 +795,14 @@ class SimpleUpsample(torch.nn.Module): A very simple form of upsampling that mostly just repeats the input, but also adds a position-specific bias. """ - - def __init__(self, num_channels: int, upsample: int): + def __init__(self, + num_channels: int, + upsample: int): super(SimpleUpsample, self).__init__() self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01) - def forward(self, src: Tensor) -> Tensor: + def forward(self, + src: Tensor) -> Tensor: """ x: (seq_len, batch_size, num_channels) Returns a tensor of shape @@ -846,7 +815,6 @@ class SimpleUpsample(torch.nn.Module): src = src.reshape(seq_len * upsample, batch_size, num_channels) return src - class SimpleCombinerIdentity(nn.Module): def __init__(self, *args, **kwargs): super().__init__() @@ -854,7 +822,6 @@ class SimpleCombinerIdentity(nn.Module): def forward(self, src1: Tensor, src2: Tensor) -> Tensor: return src1 - class SimpleCombiner(torch.nn.Module): """ A very simple way of combining 2 vectors of 2 different dims, via a @@ -864,14 +831,18 @@ class SimpleCombiner(torch.nn.Module): dim2: the dimension of the second input, e.g. 384. The output will have the same dimension as dim2. """ - - def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)): + def __init__(self, + dim1: int, + dim2: int, + min_weight: Tuple[float] = (0., 0.)): super(SimpleCombiner, self).__init__() assert dim2 >= dim1, (dim2, dim1) self.weight1 = nn.Parameter(torch.zeros(())) self.min_weight = min_weight - def forward(self, src1: Tensor, src2: Tensor) -> Tensor: + def forward(self, + src1: Tensor, + src2: Tensor) -> Tensor: """ src1: (*, dim1) src2: (*, dim2) @@ -882,14 +853,10 @@ class SimpleCombiner(torch.nn.Module): weight1 = self.weight1 if not torch.jit.is_scripting(): - if ( - self.training - and random.random() < 0.25 - and self.min_weight != (0.0, 0.0) - ): - weight1 = weight1.clamp( - min=self.min_weight[0], max=1.0 - self.min_weight[1] - ) + if self.training and random.random() < 0.25 and self.min_weight != (0., 0.): + weight1 = weight1.clamp(min=self.min_weight[0], + max=1.0-self.min_weight[1]) + src1 = src1 * weight1 src2 = src2 * (1.0 - weight1) @@ -902,9 +869,12 @@ class SimpleCombiner(torch.nn.Module): else: src1 = src1[:src2_dim] + return src1 + src2 + + class RelPositionalEncoding(torch.nn.Module): """Relative positional encoding module. @@ -918,7 +888,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct a PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -933,7 +905,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x.size(0) * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vecotr and `j` means the @@ -981,6 +955,7 @@ class RelPositionalEncoding(torch.nn.Module): return self.dropout(pos_emb) + class RelPositionMultiheadAttention(nn.Module): r"""Multi-Head Attention layer with relative position encoding @@ -1017,46 +992,34 @@ class RelPositionMultiheadAttention(nn.Module): self.head_dim = attention_dim // num_heads self.pos_dim = pos_dim assert self.head_dim % 2 == 0, self.head_dim - assert self.head_dim * num_heads == attention_dim, ( - self.head_dim, - num_heads, - attention_dim, - ) + assert ( + self.head_dim * num_heads == attention_dim + ), (self.head_dim, num_heads, attention_dim) # the initial_scale is supposed to take over the "scaling" factor of # head_dim ** -0.5, dividing it between the query and key. - in_proj_dim = ( - 2 * attention_dim - + attention_dim // 2 # query, key - + pos_dim * num_heads # value - ) # positional encoding query + in_proj_dim = (2 * attention_dim + # query, key + attention_dim // 2 + # value + pos_dim * num_heads) # positional encoding query - self.in_proj = ScaledLinear( - embed_dim, - in_proj_dim, - bias=True, - initial_scale=self.head_dim**-0.25, - ) + self.in_proj = ScaledLinear(embed_dim, in_proj_dim, bias=True, + initial_scale=self.head_dim**-0.25) # self.whiten_values is applied on the values in forward(); # it just copies the keys but prevents low-rank distribution by modifying grads. - self.whiten_values = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) - self.whiten_keys = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) + self.whiten_values = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + self.whiten_keys = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + # linear transformation for positional encoding. - self.linear_pos = ScaledLinear( - embed_dim, num_heads * pos_dim, bias=False, initial_scale=0.05 - ) + self.linear_pos = ScaledLinear(embed_dim, num_heads * pos_dim, bias=False, + initial_scale=0.05) # the following are for diagnosics only, see --print-diagnostics option. # they only copy their inputs. @@ -1068,16 +1031,14 @@ class RelPositionMultiheadAttention(nn.Module): ) self.in_proj2 = nn.Linear(embed_dim, attention_dim // 2, bias=False) - self.out_proj2 = ScaledLinear( - attention_dim // 2, embed_dim, bias=True, initial_scale=0.05 - ) + self.out_proj2 = ScaledLinear(attention_dim // 2, embed_dim, bias=True, + initial_scale=0.05) # self.whiten_values2 is applied on the values in forward2() - self.whiten_values2 = Whiten( - num_groups=num_heads, - whitening_limit=2.0, - prob=(0.025, 0.25), - grad_scale=0.025, - ) + self.whiten_values2 = Whiten(num_groups=num_heads, + whitening_limit=2.0, + prob=(0.025, 0.25), + grad_scale=0.025) + def forward( self, @@ -1137,6 +1098,7 @@ class RelPositionMultiheadAttention(nn.Module): ) return x, weights + def multi_head_attention_forward( self, x_proj: Tensor, @@ -1194,24 +1156,26 @@ class RelPositionMultiheadAttention(nn.Module): head_dim = attention_dim // num_heads pos_dim = self.pos_dim # positional-encoding dim per head - assert head_dim * num_heads == attention_dim, ( - f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}," - f" {attention_dim}" - ) + assert ( + head_dim * num_heads == attention_dim + ), f"attention_dim must be divisible by num_heads: {head_dim}, {num_heads}, {attention_dim}" + # self-attention - q = x_proj[..., 0:attention_dim] - k = x_proj[..., attention_dim : 2 * attention_dim] + q = x_proj[...,0:attention_dim] + k = x_proj[...,attention_dim:2*attention_dim] value_dim = attention_dim // 2 - v = x_proj[..., 2 * attention_dim : 2 * attention_dim + value_dim] + v = x_proj[...,2*attention_dim:2*attention_dim+value_dim] # p is the position-encoding query, its dimension is num_heads*pos_dim.. - p = x_proj[..., 2 * attention_dim + value_dim :] + p = x_proj[...,2*attention_dim+value_dim:] + k = self.whiten_keys(k) # does nothing in the forward pass. v = self.whiten_values(v) # does nothing in the forward pass. q = self.copy_query(q) # for diagnostics only, does nothing. p = self.copy_pos_query(p) # for diagnostics only, does nothing. + if attn_mask is not None: assert ( attn_mask.dtype == torch.float32 @@ -1231,25 +1195,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, seq_len, seq_len]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, seq_len, seq_len, ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1258,6 +1230,7 @@ class RelPositionMultiheadAttention(nn.Module): k = k.reshape(seq_len, bsz, num_heads, head_dim) v = v.reshape(seq_len, bsz * num_heads, head_dim // 2).transpose(0, 1) + if key_padding_mask is not None: assert key_padding_mask.size(0) == bsz, "{} == {}".format( key_padding_mask.size(0), bsz @@ -1266,10 +1239,13 @@ class RelPositionMultiheadAttention(nn.Module): key_padding_mask.size(1), seq_len ) + + q = q.permute(1, 2, 0, 3) # (batch, head, time1, head_dim) p = p.permute(1, 2, 0, 3) # (batch, head, time1, pos_dim) k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) + seq_len2 = 2 * seq_len - 1 pos = pos.reshape(1, seq_len2, num_heads, pos_dim).permute(0, 2, 3, 1) # pos shape now: (batch, head, pos_dim, seq_len2) @@ -1280,16 +1256,13 @@ class RelPositionMultiheadAttention(nn.Module): # the following .as_strided() expression converts the last axis of pos_weights from relative # to absolute position. I don't know whether I might have got the time-offsets backwards or # not, but let this code define which way round it is supposed to be. - pos_weights = pos_weights.as_strided( - (bsz, num_heads, seq_len, seq_len), - ( - pos_weights.stride(0), - pos_weights.stride(1), - pos_weights.stride(2) - pos_weights.stride(3), - pos_weights.stride(3), - ), - storage_offset=pos_weights.stride(3) * (seq_len - 1), - ) + pos_weights = pos_weights.as_strided((bsz, num_heads, seq_len, seq_len), + (pos_weights.stride(0), + pos_weights.stride(1), + pos_weights.stride(2)-pos_weights.stride(3), + pos_weights.stride(3)), + storage_offset=pos_weights.stride(3) * (seq_len - 1)) + # caution: they are really scores at this point. attn_output_weights = torch.matmul(q, k) + pos_weights @@ -1302,9 +1275,10 @@ class RelPositionMultiheadAttention(nn.Module): # this mechanism instead of, say, a limit on entropy, because once the entropy # gets very small gradients through the softmax can become very small, and # some mechanisms like that become ineffective. - attn_output_weights = penalize_abs_values_gt( - attn_output_weights, limit=25.0, penalty=1.0e-04 - ) + attn_output_weights = penalize_abs_values_gt(attn_output_weights, + limit=25.0, + penalty=1.0e-04) + # attn_output_weights: (batch, head, time1, time2) attn_output_weights = attn_output_weights.view( @@ -1346,20 +1320,20 @@ class RelPositionMultiheadAttention(nn.Module): ) attn_output = torch.bmm(attn_output_weights, v) - assert list(attn_output.size()) == [ - bsz * num_heads, - seq_len, - head_dim // 2, - ] + assert list(attn_output.size()) == [bsz * num_heads, seq_len, + head_dim // 2] attn_output = ( attn_output.transpose(0, 1) .contiguous() .view(seq_len, bsz, attention_dim // 2) ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias + ) return attn_output, attn_output_weights + def forward2( self, x: Tensor, @@ -1398,7 +1372,11 @@ class RelPositionMultiheadAttention(nn.Module): # returned value is of shape (seq_len, bsz, embed_dim), like x. return self.out_proj2(attn_output) - def _print_attn_stats(self, attn_weights: Tensor, attn_output: Tensor): + + def _print_attn_stats( + self, + attn_weights: Tensor, + attn_output: Tensor): # attn_weights: (batch_size * num_heads, seq_len, seq_len) # attn_output: (bsz * num_heads, seq_len, head_dim) (n, seq_len, head_dim) = attn_output.shape @@ -1409,50 +1387,39 @@ class RelPositionMultiheadAttention(nn.Module): with torch.cuda.amp.autocast(enabled=False): attn_weights = attn_weights.to(torch.float32) attn_output = attn_output.to(torch.float32) - attn_weights_entropy = ( - -((attn_weights + 1.0e-20).log() * attn_weights) - .sum(dim=-1) - .reshape(bsz, num_heads, seq_len) - .mean(dim=(0, 2)) - ) + attn_weights_entropy = -((attn_weights + 1.0e-20).log() * attn_weights).sum( + dim=-1).reshape(bsz, num_heads, seq_len).mean(dim=(0,2)) attn_output = attn_output.reshape(bsz, num_heads, seq_len, head_dim) - attn_output = attn_output.permute(1, 0, 2, 3).reshape( - num_heads, bsz * seq_len, head_dim - ) + attn_output = attn_output.permute(1, 0, 2, 3).reshape(num_heads, bsz * seq_len, head_dim) attn_output_mean = attn_output.mean(dim=1, keepdim=True) attn_output = attn_output - attn_output_mean - attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / ( - bsz * seq_len - ) + attn_covar = torch.matmul(attn_output.transpose(1, 2), attn_output) / (bsz * seq_len) # attn_covar: (num_heads, head_dim, head_dim) - # eigs, _ = torch.symeig(attn_covar) - # logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") + #eigs, _ = torch.symeig(attn_covar) + #logging.info(f"attn_weights_entropy = {attn_weights_entropy}, output_eigs = {eigs}") attn_covar = _diag(attn_covar).mean(dim=1) # (num_heads,) embed_dim = self.in_proj2.weight.shape[1] - in_proj_covar = ( - self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2 - ).mean(dim=(1, 2)) - out_proj_covar = ( - self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2 - ).mean(dim=(0, 2)) - logging.info( - f"attn_weights_entropy = {attn_weights_entropy}," - f" covar={attn_covar}, in_proj_covar={in_proj_covar}," - f" out_proj_covar={out_proj_covar}" - ) + in_proj_covar = (self.in_proj2.weight.reshape(num_heads, head_dim, embed_dim) ** 2).mean(dim=(1,2)) + out_proj_covar = (self.out_proj2.weight.reshape(embed_dim, num_heads, head_dim) ** 2).mean(dim=(0,2)) + logging.info(f"attn_weights_entropy = {attn_weights_entropy}, covar={attn_covar}, in_proj_covar={in_proj_covar}, out_proj_covar={out_proj_covar}") + + class PoolingModule(nn.Module): """ Averages the input over the time dimension and project with a square matrix. """ - - def __init__(self, d_model: int): + def __init__(self, + d_model: int): super().__init__() - self.proj = ScaledLinear(d_model, d_model, initial_scale=0.1, bias=False) + self.proj = ScaledLinear(d_model, d_model, + initial_scale=0.1, bias=False) - def forward(self, x: Tensor, key_padding_mask: Optional[Tensor] = None): + def forward(self, + x: Tensor, + key_padding_mask: Optional[Tensor] = None): """ Args: x: a Tensor of shape (T, N, C) @@ -1463,7 +1430,7 @@ class PoolingModule(nn.Module): """ if key_padding_mask is not None: pooling_mask = key_padding_mask.logical_not().to(x.dtype) # (N, T) - pooling_mask = pooling_mask / pooling_mask.sum(dim=1, keepdim=True) + pooling_mask = (pooling_mask / pooling_mask.sum(dim=1, keepdim=True)) pooling_mask = pooling_mask.transpose(0, 1).contiguous().unsqueeze(-1) # now pooling_mask: (T, N, 1) x = (x * pooling_mask).sum(dim=0, keepdim=True) @@ -1477,19 +1444,24 @@ class PoolingModule(nn.Module): class FeedforwardModule(nn.Module): - """Feedforward module in Zipformer model.""" - - def __init__(self, d_model: int, feedforward_dim: int, dropout: float): + """Feedforward module in Zipformer model. + """ + def __init__(self, + d_model: int, + feedforward_dim: int, + dropout: float): super(FeedforwardModule, self).__init__() self.in_proj = nn.Linear(d_model, feedforward_dim) - self.balancer = ActivationBalancer( - feedforward_dim, channel_dim=-1, max_abs=10.0, min_prob=0.25 - ) + self.balancer = ActivationBalancer(feedforward_dim, + channel_dim=-1, max_abs=10.0, + min_prob=0.25) self.activation = DoubleSwish() self.dropout = nn.Dropout(dropout) - self.out_proj = ScaledLinear(feedforward_dim, d_model, initial_scale=0.01) + self.out_proj = ScaledLinear(feedforward_dim, d_model, + initial_scale=0.01) - def forward(self, x: Tensor): + def forward(self, + x: Tensor): x = self.in_proj(x) x = self.balancer(x) x = self.activation(x) @@ -1509,7 +1481,9 @@ class ConvolutionModule(nn.Module): """ - def __init__(self, channels: int, kernel_size: int, bias: bool = True) -> None: + def __init__( + self, channels: int, kernel_size: int, bias: bool = True + ) -> None: """Construct an ConvolutionModule object.""" super(ConvolutionModule, self).__init__() # kernerl_size should be a odd number for 'SAME' padding @@ -1539,10 +1513,7 @@ class ConvolutionModule(nn.Module): # the correct range. self.deriv_balancer1 = ActivationBalancer( 2 * channels, - channel_dim=1, - max_abs=10.0, - min_positive=0.05, - max_positive=1.0, + channel_dim=1, max_abs=10.0, min_positive=0.05, max_positive=1.0 ) self.depthwise_conv = nn.Conv1d( @@ -1556,10 +1527,8 @@ class ConvolutionModule(nn.Module): ) self.deriv_balancer2 = ActivationBalancer( - channels, - channel_dim=1, - min_positive=0.05, - max_positive=1.0, + channels, channel_dim=1, + min_positive=0.05, max_positive=1.0, max_abs=20.0, ) @@ -1575,10 +1544,9 @@ class ConvolutionModule(nn.Module): initial_scale=0.05, ) - def forward( - self, - x: Tensor, - src_key_padding_mask: Optional[Tensor] = None, + def forward(self, + x: Tensor, + src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: """Compute convolution module. @@ -1658,7 +1626,8 @@ class Conv2dSubsampling(nn.Module): kernel_size=3, padding=(0, 1), # (time, freq) ), - ActivationBalancer(layer1_channels, channel_dim=1), + ActivationBalancer(layer1_channels, + channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer1_channels, @@ -1667,21 +1636,24 @@ class Conv2dSubsampling(nn.Module): stride=2, padding=0, ), - ActivationBalancer(layer2_channels, channel_dim=1), + ActivationBalancer(layer2_channels, + channel_dim=1), DoubleSwish(), nn.Conv2d( in_channels=layer2_channels, out_channels=layer3_channels, kernel_size=3, - stride=(1, 2), # (time, freq) + stride=(1, 2), # (time, freq) ), - ActivationBalancer(layer3_channels, channel_dim=1), + ActivationBalancer(layer3_channels, + channel_dim=1), DoubleSwish(), ) out_height = (((in_channels - 1) // 2) - 1) // 2 self.out = ScaledLinear(out_height * layer3_channels, out_channels) self.dropout = nn.Dropout(dropout) + def forward(self, x: torch.Tensor) -> torch.Tensor: """Subsample x. @@ -1702,7 +1674,6 @@ class Conv2dSubsampling(nn.Module): x = self.dropout(x) return x - class AttentionCombine(nn.Module): """ This module combines a list of Tensors, all with the same shape, to @@ -1746,12 +1717,15 @@ class AttentionCombine(nn.Module): self.random_prob = random_prob self.single_prob = single_prob - self.weight = torch.nn.Parameter(torch.zeros(num_channels, num_inputs)) + self.weight = torch.nn.Parameter(torch.zeros(num_channels, + num_inputs)) self.bias = torch.nn.Parameter(torch.zeros(num_inputs)) assert 0 <= random_prob <= 1, random_prob assert 0 <= single_prob <= 1, single_prob + + def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. Args: @@ -1782,35 +1756,28 @@ class AttentionCombine(nn.Module): if self.training: # random masking.. - mask_start = torch.randint( - low=1, - high=int(num_inputs / self.random_prob), - size=(num_frames,), - device=scores.device, - ).unsqueeze(1) + mask_start = torch.randint(low=1, high=int(num_inputs / self.random_prob), + size=(num_frames,), device=scores.device).unsqueeze(1) # mask will have rows like: [ False, False, False, True, True, .. ] - arange = ( - torch.arange(num_inputs, device=scores.device) - .unsqueeze(0) - .expand(num_frames, num_inputs) - ) + arange = torch.arange(num_inputs, device=scores.device).unsqueeze(0).expand( + num_frames, num_inputs) mask = arange >= mask_start - apply_single_prob = torch.logical_and( - torch.rand(size=(num_frames, 1), device=scores.device) - < self.single_prob, - mask_start < num_inputs, - ) - single_prob_mask = torch.logical_and( - apply_single_prob, arange < mask_start - 1 - ) + apply_single_prob = torch.logical_and(torch.rand(size=(num_frames, 1), + device=scores.device) < self.single_prob, + mask_start < num_inputs) + single_prob_mask = torch.logical_and(apply_single_prob, + arange < mask_start - 1) - mask = torch.logical_or(mask, single_prob_mask) + mask = torch.logical_or(mask, + single_prob_mask) - scores = scores.masked_fill(mask, float("-inf")) + scores = scores.masked_fill(mask, float('-inf')) if self.training and random.random() < 0.1: - scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04) + scores = penalize_abs_values_gt(scores, + limit=10.0, + penalty=1.0e-04) weights = scores.softmax(dim=1) @@ -1825,6 +1792,7 @@ class AttentionCombine(nn.Module): return ans + def _test_random_combine(): print("_test_random_combine()") num_inputs = 3 @@ -1833,8 +1801,8 @@ def _test_random_combine(): num_channels=num_channels, num_inputs=num_inputs, random_prob=0.5, - single_prob=0.0, - ) + single_prob=0.0) + x = [torch.ones(3, 4, num_channels) for _ in range(num_inputs)] @@ -1851,10 +1819,7 @@ def _test_zipformer_main(): # Just make sure the forward pass runs. c = Zipformer( - num_features=feature_dim, - encoder_dims=(64, 96), - encoder_unmasked_dims=(48, 64), - nhead=(4, 4), + num_features=feature_dim, encoder_dims=(64,96), encoder_unmasked_dims=(48,64), nhead=(4,4) ) batch_size = 5 seq_len = 20 @@ -1872,18 +1837,19 @@ def _test_zipformer_main(): ) f # to remove flake8 warnings - def _test_conv2d_subsampling(): num_features = 80 encoder_dims = 384 dropout = 0.1 - encoder_embed = Conv2dSubsampling(num_features, encoder_dims, dropout=dropout) + encoder_embed = Conv2dSubsampling(num_features, encoder_dims, + dropout=dropout) for i in range(20, 40): x = torch.rand(2, i, num_features) y = encoder_embed(x) assert (x.shape[1] - 7) // 2 == y.shape[1], (x.shape[1], y.shape[1]) + if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) torch.set_num_threads(1) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py index 822f8e44b..9d7335e77 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/decode.py @@ -165,24 +165,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -277,7 +273,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -397,7 +394,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -456,7 +455,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -587,7 +589,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -620,7 +624,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -675,7 +680,9 @@ def main(): if "LG" in params.decoding_method: params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -712,12 +719,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -745,12 +753,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -779,7 +788,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -807,7 +816,9 @@ def main(): decoding_graph.scores *= params.ngram_lm_scale else: word_table = None - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None word_table = None diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py index 43eb0c1bc..49f469e29 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/export.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/export.py @@ -129,24 +129,20 @@ def get_parser(): "--avg", type=int, default=9, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -180,7 +176,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -220,12 +217,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -254,12 +252,13 @@ def main(): ) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -288,7 +287,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -327,7 +326,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py index ed920dc03..e79a3a3aa 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/jit_pretrained.py @@ -69,12 +69,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -95,9 +93,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -268,7 +267,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py index 39a360796..497b89136 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/model.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/model.py @@ -160,7 +160,9 @@ class Transducer(nn.Module): y_padded = y.pad(mode="constant", padding_value=0) y_padded = y_padded.to(torch.int64) - boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device) + boundary = torch.zeros( + (x.size(0), 4), dtype=torch.int64, device=x.device + ) boundary[:, 2] = y_lens boundary[:, 3] = x_lens diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py index 716136812..373a48fc1 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/pretrained.py @@ -100,11 +100,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -129,12 +127,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -181,7 +177,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -212,9 +209,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -277,11 +275,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -353,7 +355,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py index 381a86a67..2603bb854 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless8/train.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless8/train.py @@ -92,7 +92,9 @@ from icefall.env import get_env_info from icefall.hooks import register_inf_check_hooks from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: @@ -130,10 +132,7 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-dims", type=str, default="384,384,384,384,384", - help=( - "Embedding dimension in the 2 blocks of zipformer encoder layers, comma" - " separated" - ), + help="Embedding dimension in the 2 blocks of zipformer encoder layers, comma separated", ) parser.add_argument( @@ -148,11 +147,9 @@ def add_model_arguments(parser: argparse.ArgumentParser): "--encoder-unmasked-dims", type=str, default="256,256,256,256,256", - help=( - "Unmasked dimensions in the encoders, relates to augmentation during" - " training. Must be <= each of encoder_dims. Empirically, less than 256" - " seems to make performance worse." - ), + help="Unmasked dimensions in the encoders, relates to augmentation during training. " + "Must be <= each of encoder_dims. Empirically, less than 256 seems to make performance " + " worse.", ) parser.add_argument( @@ -217,7 +214,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -287,45 +285,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -696,7 +691,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -745,7 +744,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -951,7 +952,9 @@ def train_one_epoch( # of the grad scaler is configurable, but we can't configure it to have different # behavior depending on the current grad scale. cur_grad_scale = scaler._scale.item() - if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + if cur_grad_scale < 1.0 or ( + cur_grad_scale < 8.0 and batch_idx % 400 == 0 + ): scaler.update(cur_grad_scale * 2.0) if cur_grad_scale < 0.01: logging.warning(f"Grad scale is small: {cur_grad_scale}") @@ -972,7 +975,11 @@ def train_one_epoch( f"giga_tot_loss[{giga_tot_loss}], " f"batch size: {batch_size}, " f"lr: {cur_lr:.2e}, " - + (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "") + + ( + f"grad_scale: {scaler._scale.item()}" + if params.use_fp16 + else "" + ) ) if tb_writer is not None: @@ -985,8 +992,12 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -1000,7 +1011,10 @@ def train_one_epoch( params.batch_idx_train, ) - if batch_idx % params.valid_interval == 0 and not params.print_diagnostics: + if ( + batch_idx % params.valid_interval == 0 + and not params.print_diagnostics + ): logging.info("Computing validation loss") valid_info = compute_validation_loss( params=params, @@ -1012,8 +1026,7 @@ def train_one_epoch( model.train() logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) if tb_writer is not None: valid_info.write_summary( @@ -1041,7 +1054,8 @@ def filter_short_and_long_utterances( # the threshold if c.duration < 1.0 or c.duration > 20.0: logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + f"Exclude cut with ID {c.id} from training. " + f"Duration: {c.duration}" ) return False @@ -1138,7 +1152,9 @@ def run(rank, world_size, args): logging.info("Using DDP") model = DDP(model, device_ids=[rank], find_unused_parameters=True) - optimizer = ScaledAdam(model.parameters(), lr=params.base_lr, clipping_scale=2.0) + optimizer = ScaledAdam( + model.parameters(), lr=params.base_lr, clipping_scale=2.0 + ) scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) @@ -1156,7 +1172,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1191,7 +1207,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None @@ -1346,8 +1364,7 @@ def scan_pessimistic_batches_for_oom( display_and_save_batch(batch, params=params, sp=sp) raise logging.info( - "Maximum memory allocated so far is" - f" {torch.cuda.max_memory_allocated()//1000000}MB" + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" ) diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/README.md b/egs/librispeech/ASR/streaming_conformer_ctc/README.md index 53f383c99..01be7090b 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/README.md +++ b/egs/librispeech/ASR/streaming_conformer_ctc/README.md @@ -1,20 +1,20 @@ ## Train and Decode -Commands of data preparation/train/decode steps are almost the same with +Commands of data preparation/train/decode steps are almost the same with ../conformer_ctc experiment except some options. Please read the code and understand following new added options before running this experiment: For data preparation: - + Nothing new. For streaming_conformer_ctc/train.py: - + --dynamic-chunk-training --short-chunk-proportion For streaming_conformer_ctc/streaming_decode.py: - + --chunk-size --tailing-num-frames --simulate-streaming @@ -57,10 +57,10 @@ And check md5sum values again. Finally, following files will be downloaded: -streaming_models/ -|-- lang_bpe -| |-- L.pt -| |-- Linv.pt +streaming_models/ +|-- lang_bpe +| |-- L.pt +| |-- Linv.pt | |-- bpe.model | |-- tokens.txt | `-- words.txt diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py index 4f7427c1f..ff4c91446 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/conformer.py @@ -309,26 +309,36 @@ class Conformer(Transformer): # start chunk_by_chunk decoding offset = 0 - for cur in range(0, num_frames - embed_left_context + 1, stride): + for cur in range( + 0, num_frames - embed_left_context + 1, stride + ): end = min(cur + decoding_window, num_frames) cur_feature = feature[:, cur:end, :] cur_feature = self.encoder_embed(cur_feature) - cur_embed, cur_pos_emb = self.encoder_pos(cur_feature, offset) - cur_embed = cur_embed.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + cur_embed, cur_pos_emb = self.encoder_pos( + cur_feature, offset + ) + cur_embed = cur_embed.permute( + 1, 0, 2 + ) # (B, T, F) -> (T, B, F) cur_T = cur_feature.size(1) if cur == 0: # for first chunk extract the central pos embedding - pos_emb_central = cur_pos_emb[0, (chunk_size - 1), :].view( - 1, 1, -1 - ) + pos_emb_central = cur_pos_emb[ + 0, (chunk_size - 1), : + ].view(1, 1, -1) cur_T -= 1 pos_emb_positive.append(cur_pos_emb[0, :cur_T].flip(0)) pos_emb_negative.append(cur_pos_emb[0, -cur_T:]) assert pos_emb_positive[-1].size(0) == cur_T - pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze(0) - pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze(0) + pos_emb_pos = torch.cat(pos_emb_positive, dim=0).unsqueeze( + 0 + ) + pos_emb_neg = torch.cat(pos_emb_negative, dim=0).unsqueeze( + 0 + ) cur_pos_emb = torch.cat( [pos_emb_pos.flip(1), pos_emb_central, pos_emb_neg], dim=1, @@ -403,7 +413,9 @@ class ConformerEncoderLayer(nn.Module): causal: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -419,16 +431,22 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(dim_feedforward, d_model), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -462,7 +480,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -534,7 +554,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -714,7 +736,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -731,7 +755,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -757,7 +783,9 @@ class RelPositionalEncoding(torch.nn.Module): pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward(self, x: torch.Tensor, offset: int = 0) -> Tuple[Tensor, Tensor]: + def forward( + self, x: torch.Tensor, offset: int = 0 + ) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -785,7 +813,9 @@ class RelPositionalEncoding(torch.nn.Module): pos_emb = torch.cat( [ pos_emb[:, : (x_T - 1)], - self.pe[0, self.pe.size(1) // 2].view(1, 1, self.pe.size(-1)), + self.pe[0, self.pe.size(1) // 2].view( + 1, 1, self.pe.size(-1) + ), pos_emb[:, -(x_T - 1) :], # noqa: E203 ], dim=1, @@ -1020,9 +1050,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -1090,25 +1120,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1147,16 +1185,24 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) - matrix_bd = self.rel_shift(matrix_bd, offset=offset) # [B, head, time1, time2] + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) + matrix_bd = self.rel_shift( + matrix_bd, offset=offset + ) # [B, head, time1, time2] attn_output_weights = ( matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1190,9 +1236,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py index 5a8149aad..a74c51836 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/streaming_decode.py @@ -28,7 +28,6 @@ import torch import torch.nn as nn from asr_datamodule import LibriSpeechAsrDataModule from conformer import Conformer - from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.lexicon import Lexicon @@ -63,36 +62,32 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--chunk-size", type=int, default=8, - help=( - "Frames of right context" - "-1 for whole right context, i.e. non-streaming decoding" - ), + help="Frames of right context" + "-1 for whole right context, i.e. non-streaming decoding", ) parser.add_argument( "--tailing-num-frames", type=int, default=20, - help="tailing dummy frames padded to the right,only used during decoding", + help="tailing dummy frames padded to the right," + "only used during decoding", ) parser.add_argument( @@ -144,7 +139,8 @@ def get_parser(): "--avg-models", type=str, default=None, - help="Manually select models to average, seperated by comma;e.g. 60,62,63,72", + help="Manually select models to average, seperated by comma;" + "e.g. 60,62,63,72", ) return parser @@ -252,9 +248,13 @@ def decode_one_batch( maxlen = nnet_output.size(1) topk_prob, topk_index = nnet_output.topk(1, dim=2) # (B, maxlen, 1) topk_index = topk_index.view(batch_size, maxlen) # (B, maxlen) - topk_index = topk_index.masked_fill_(memory_key_padding_mask, 0) # (B, maxlen) + topk_index = topk_index.masked_fill_( + memory_key_padding_mask, 0 + ) # (B, maxlen) token_ids = [token_id.tolist() for token_id in topk_index] - token_ids = [remove_duplicates_and_blank(token_id) for token_id in token_ids] + token_ids = [ + remove_duplicates_and_blank(token_id) for token_id in token_ids + ] hyps = bpe_model.decode(token_ids) hyps = [s.split() for s in hyps] return {key: hyps} @@ -337,7 +337,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -355,18 +357,15 @@ def save_results( test_set_wers = dict() if params.avg_models is not None: avg_models = params.avg_models.replace(",", "_") - result_file_prefix = ( - f"epoch-avg-{avg_models}-chunksize " - f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" - ) + result_file_prefix = f"epoch-avg-{avg_models}-chunksize \ + -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" else: - result_file_prefix = ( - f"epoch-{params.epoch}-avg-{params.avg}-chunksize " - f" -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" - ) + result_file_prefix = f"epoch-{params.epoch}-avg-{params.avg}-chunksize \ + -{params.chunk_size}-tailing-num-frames-{params.tailing_num_frames}-" for key, results in results_dict.items(): recog_path = ( - params.exp_dir / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt" + params.exp_dir + / f"{result_file_prefix}recogs-{test_set_name}-{key}.txt" ) store_transcripts(filename=recog_path, texts=results) if enable_log: @@ -375,7 +374,8 @@ def save_results( # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = ( - params.exp_dir / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt" + params.exp_dir + / f"{result_file_prefix}-errs-{test_set_name}-{key}.txt" ) with open(errs_filename, "w") as f: wer = write_error_stats( @@ -384,7 +384,9 @@ def save_results( test_set_wers[key] = wer if enable_log: - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -472,7 +474,9 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) return model.to(device) @@ -503,7 +507,9 @@ def main(): simulate_streaming=params.simulate_streaming, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/train.py b/egs/librispeech/ASR/streaming_conformer_ctc/train.py index 553b7d092..e41b7ea78 100755 --- a/egs/librispeech/ASR/streaming_conformer_ctc/train.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/train.py @@ -405,7 +405,9 @@ def compute_loss( # # See https://github.com/k2-fsa/icefall/issues/97 # for more details - unsorted_token_ids = graph_compiler.texts_to_ids(supervisions["text"]) + unsorted_token_ids = graph_compiler.texts_to_ids( + supervisions["text"] + ) att_loss = mmodel.decoder_forward( encoder_memory, memory_mask, @@ -434,7 +436,9 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)).sum().item() + ((feature.size(1) - supervisions["num_frames"]) / feature.size(1)) + .sum() + .item() ) return loss, info @@ -547,7 +551,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -662,7 +668,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py index 0c87fdf1b..bc78e4a41 100644 --- a/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/streaming_conformer_ctc/transformer.py @@ -149,7 +149,9 @@ class Transformer(nn.Module): norm=decoder_norm, ) - self.decoder_output_layer = torch.nn.Linear(d_model, self.decoder_num_class) + self.decoder_output_layer = torch.nn.Linear( + d_model, self.decoder_num_class + ) self.decoder_criterion = LabelSmoothingLoss() else: @@ -284,17 +286,23 @@ class Transformer(nn.Module): """ ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device) ys_out_pad = ys_out_pad.to(device) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -355,17 +363,23 @@ class Transformer(nn.Module): ys_in = add_sos(token_ids, sos_id=sos_id) ys_in = [torch.tensor(y) for y in ys_in] - ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=float(eos_id)) + ys_in_pad = pad_sequence( + ys_in, batch_first=True, padding_value=float(eos_id) + ) ys_out = add_eos(token_ids, eos_id=eos_id) ys_out = [torch.tensor(y) for y in ys_out] - ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=float(-1)) + ys_out_pad = pad_sequence( + ys_out, batch_first=True, padding_value=float(-1) + ) device = memory.device ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) - tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to(device) + tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( + device + ) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) # TODO: Use length information to create the decoder padding mask @@ -638,7 +652,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): @@ -840,7 +856,9 @@ def encoder_padding_mask( 1, ).to(torch.int32) - lengths = [0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1)] + lengths = [ + 0 for _ in range(int(supervision_segments[:, 0].max().item()) + 1) + ] for idx in range(supervision_segments.size(0)): # Note: TorchScript doesn't allow to unpack tensors as tuples sequence_idx = supervision_segments[idx, 0].item() @@ -861,7 +879,9 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: torch.Tensor, ignore_id: int = -1) -> torch.Tensor: +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: """Generate a length mask for input. The masked position are filled with True, diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py index 63afd6be2..355ccc99a 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -77,18 +77,17 @@ class LibriSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) group.add_argument( "--manifest-dir", @@ -100,74 +99,59 @@ class LibriSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--drop-last", @@ -179,18 +163,17 @@ class LibriSpeechAsrDataModule: "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -204,22 +187,18 @@ class LibriSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) group.add_argument( @@ -245,16 +224,20 @@ class LibriSpeechAsrDataModule: if self.args.enable_musan: logging.info("Enable MUSAN") logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -269,7 +252,9 @@ class LibriSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -313,7 +298,9 @@ class LibriSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -369,7 +356,9 @@ class LibriSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 94ba0a4dc..7d0cd0bf3 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -57,19 +57,16 @@ def get_parser(): "--epoch", type=int, default=19, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=5, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--method", @@ -339,7 +336,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -401,7 +400,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -466,7 +467,9 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) return model.to(device) @@ -495,7 +498,9 @@ def main(): G=G, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py index 1731e1ebe..5e04c11b4 100644 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/model.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/model.py @@ -66,7 +66,10 @@ class TdnnLstm(nn.Module): nn.BatchNorm1d(num_features=500, affine=False), ) self.lstms = nn.ModuleList( - [nn.LSTM(input_size=500, hidden_size=500, num_layers=1) for _ in range(5)] + [ + nn.LSTM(input_size=500, hidden_size=500, num_layers=1) + for _ in range(5) + ] ) self.lstm_bnorms = nn.ModuleList( [nn.BatchNorm1d(num_features=500, affine=False) for _ in range(5)] diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py index 722e8f003..2baeb6bba 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/pretrained.py @@ -29,7 +29,11 @@ import torchaudio from model import TdnnLstm from torch.nn.utils.rnn import pad_sequence -from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_whole_lattice, +) from icefall.utils import AttributeDict, get_texts @@ -42,11 +46,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -56,7 +58,9 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) parser.add_argument( "--method", @@ -99,12 +103,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -142,9 +144,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -212,7 +215,9 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) features = features.permute(0, 2, 1) # now features is (N, C, T) with torch.no_grad(): @@ -264,7 +269,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index 071ac792b..6b37d5c23 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -355,7 +355,9 @@ def compute_loss( info["utt_duration"] = supervisions["num_frames"].sum().item() # averaged padding proportion over utterances info["utt_pad_proportion"] = ( - ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)).sum().item() + ((feature.size(2) - supervisions["num_frames"]) / feature.size(2)) + .sum() + .item() ) return loss, info @@ -468,7 +470,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: valid_info = compute_validation_loss( diff --git a/egs/librispeech/ASR/transducer/beam_search.py b/egs/librispeech/ASR/transducer/beam_search.py index b45b6a9d8..11032f31a 100644 --- a/egs/librispeech/ASR/transducer/beam_search.py +++ b/egs/librispeech/ASR/transducer/beam_search.py @@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: blank_id = model.decoder.blank_id device = model.device - sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1) + sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( + 1, 1 + ) decoder_out, (h, c) = model.decoder(sos) T = encoder_out.size(1) t = 0 @@ -121,7 +123,9 @@ def beam_search( max_u = 20000 # terminate after this number of steps u = 0 - cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {} + cache: Dict[ + str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = {} while t < T and u < max_u: # fmt: off @@ -153,9 +157,9 @@ def beam_search( cached_key = "_".join(map(str, y_star.ys)) if cached_key not in cache: - decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape( - 1, 1 - ) + decoder_input = torch.tensor( + [y_star.ys[-1]], device=device + ).reshape(1, 1) decoder_out, decoder_state = model.decoder( decoder_input, diff --git a/egs/librispeech/ASR/transducer/decode.py b/egs/librispeech/ASR/transducer/decode.py index f30332cea..5f233df87 100755 --- a/egs/librispeech/ASR/transducer/decode.py +++ b/egs/librispeech/ASR/transducer/decode.py @@ -71,19 +71,16 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=11, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -231,7 +228,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] batch_size = encoder_out.size(0) @@ -246,7 +245,9 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) hyps.append(sp.decode(hyp).split()) if params.decoding_method == "greedy_search": @@ -317,7 +318,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -350,7 +353,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/transducer/export.py b/egs/librispeech/ASR/transducer/export.py index 4d9f937f5..5a5db30c4 100755 --- a/egs/librispeech/ASR/transducer/export.py +++ b/egs/librispeech/ASR/transducer/export.py @@ -67,20 +67,17 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=11, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -241,7 +238,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer/pretrained.py b/egs/librispeech/ASR/transducer/pretrained.py index 7aadfbcd1..1db2df648 100755 --- a/egs/librispeech/ASR/transducer/pretrained.py +++ b/egs/librispeech/ASR/transducer/pretrained.py @@ -60,11 +60,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -89,12 +87,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -192,9 +188,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -252,7 +249,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -288,7 +287,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer/rnn.py b/egs/librispeech/ASR/transducer/rnn.py index fe8732301..2a165b0c1 100644 --- a/egs/librispeech/ASR/transducer/rnn.py +++ b/egs/librispeech/ASR/transducer/rnn.py @@ -117,8 +117,12 @@ class LayerNormLSTMCell(nn.Module): ) if bias: - self.bias_ih = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs)) - self.bias_hh = nn.Parameter(torch.empty(4 * hidden_size, **factory_kwargs)) + self.bias_ih = nn.Parameter( + torch.empty(4 * hidden_size, **factory_kwargs) + ) + self.bias_hh = nn.Parameter( + torch.empty(4 * hidden_size, **factory_kwargs) + ) else: self.register_parameter("bias_ih", None) self.register_parameter("bias_hh", None) @@ -344,7 +348,9 @@ class LayerNormLSTM(nn.Module): device=device, dtype=dtype, ) - first_layer = LayerNormLSTMLayer(input_size=input_size, **factory_kwargs) + first_layer = LayerNormLSTMLayer( + input_size=input_size, **factory_kwargs + ) layers = [first_layer] for i in range(1, num_layers): layers.append( @@ -379,7 +385,9 @@ class LayerNormLSTM(nn.Module): - List[(next_h, next_c)] containing the hidden states for all layers """ - output_states = torch.jit.annotate(List[Tuple[torch.Tensor, torch.Tensor]], []) + output_states = torch.jit.annotate( + List[Tuple[torch.Tensor, torch.Tensor]], [] + ) output = input for i, rnn_layer in enumerate(self.layers): state = states[i] @@ -448,8 +456,12 @@ class LayerNormGRUCell(nn.Module): ) if bias: - self.bias_ih = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs)) - self.bias_hh = nn.Parameter(torch.empty(3 * hidden_size, **factory_kwargs)) + self.bias_ih = nn.Parameter( + torch.empty(3 * hidden_size, **factory_kwargs) + ) + self.bias_hh = nn.Parameter( + torch.empty(3 * hidden_size, **factory_kwargs) + ) else: self.register_parameter("bias_ih", None) self.register_parameter("bias_hh", None) diff --git a/egs/librispeech/ASR/transducer/test_rnn.py b/egs/librispeech/ASR/transducer/test_rnn.py index 74c94cc70..8591e2d8a 100755 --- a/egs/librispeech/ASR/transducer/test_rnn.py +++ b/egs/librispeech/ASR/transducer/test_rnn.py @@ -254,7 +254,9 @@ def test_layernorm_lstm_layer_with_projection_forward(device="cpu"): for name, self_param in self_layer.cell.named_parameters(): getattr(torch_layer, f"{name}_l0").copy_(self_param) - torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0))) + torch_y, (torch_h, torch_c) = torch_layer( + x_clone, (h.unsqueeze(0), c.unsqueeze(0)) + ) assert_allclose(self_y, torch_y) assert_allclose(self_h, torch_h) assert_allclose(self_c, torch_c) @@ -301,7 +303,9 @@ def test_layernorm_lstm_layer_forward(device="cpu"): for name, self_param in self_layer.cell.named_parameters(): getattr(torch_layer, f"{name}_l0").copy_(self_param) - torch_y, (torch_h, torch_c) = torch_layer(x_clone, (h.unsqueeze(0), c.unsqueeze(0))) + torch_y, (torch_h, torch_c) = torch_layer( + x_clone, (h.unsqueeze(0), c.unsqueeze(0)) + ) assert_allclose(self_y, torch_y) assert_allclose(self_h, torch_h) assert_allclose(self_c, torch_c) @@ -590,7 +594,9 @@ def test_layernorm_gru_cell_forward(device="cpu"): assert_allclose(self_h, torch_h, atol=1e-5) - (self_h.reshape(-1) * torch.arange(self_h.numel(), device=device)).sum().backward() + ( + self_h.reshape(-1) * torch.arange(self_h.numel(), device=device) + ).sum().backward() ( torch_h.reshape(-1) * torch.arange(torch_h.numel(), device=device) ).sum().backward() @@ -712,7 +718,9 @@ def test_layernorm_gru_forward(device="cpu"): T = torch.randint(low=2, high=100, size=(1,)) x = torch.rand(N, T, input_size, device=device).requires_grad_() - states = [torch.rand(N, hidden_size, device=device) for _ in range(num_layers)] + states = [ + torch.rand(N, hidden_size, device=device) for _ in range(num_layers) + ] x_clone = x.detach().clone().requires_grad_() diff --git a/egs/librispeech/ASR/transducer/train.py b/egs/librispeech/ASR/transducer/train.py index 674ea10a6..1dd65eddb 100755 --- a/egs/librispeech/ASR/transducer/train.py +++ b/egs/librispeech/ASR/transducer/train.py @@ -396,7 +396,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -518,7 +520,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -655,7 +659,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/transducer_lstm/beam_search.py b/egs/librispeech/ASR/transducer_lstm/beam_search.py index 5342c3e8c..3531a9633 100644 --- a/egs/librispeech/ASR/transducer_lstm/beam_search.py +++ b/egs/librispeech/ASR/transducer_lstm/beam_search.py @@ -38,7 +38,9 @@ def greedy_search(model: Transducer, encoder_out: torch.Tensor) -> List[int]: blank_id = model.decoder.blank_id device = model.device - sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape(1, 1) + sos = torch.tensor([blank_id], device=device, dtype=torch.int64).reshape( + 1, 1 + ) decoder_out, (h, c) = model.decoder(sos) T = encoder_out.size(1) t = 0 @@ -122,7 +124,9 @@ def beam_search( max_u = 20000 # terminate after this number of steps u = 0 - cache: Dict[str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {} + cache: Dict[ + str, Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] + ] = {} while t < T and u < max_u: # fmt: off @@ -154,9 +158,9 @@ def beam_search( cached_key = "_".join(map(str, y_star.ys)) if cached_key not in cache: - decoder_input = torch.tensor([y_star.ys[-1]], device=device).reshape( - 1, 1 - ) + decoder_input = torch.tensor( + [y_star.ys[-1]], device=device + ).reshape(1, 1) decoder_out, decoder_state = model.decoder( decoder_input, diff --git a/egs/librispeech/ASR/transducer_lstm/decode.py b/egs/librispeech/ASR/transducer_lstm/decode.py index 61b9de504..604235e2a 100755 --- a/egs/librispeech/ASR/transducer_lstm/decode.py +++ b/egs/librispeech/ASR/transducer_lstm/decode.py @@ -71,19 +71,16 @@ def get_parser(): "--epoch", type=int, default=77, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=55, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -228,7 +225,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] batch_size = encoder_out.size(0) @@ -243,7 +242,9 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) hyps.append(sp.decode(hyp).split()) if params.decoding_method == "greedy_search": @@ -314,7 +315,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -347,7 +350,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/librispeech/ASR/transducer_lstm/encoder.py b/egs/librispeech/ASR/transducer_lstm/encoder.py index 038d80077..3dc992dd2 100644 --- a/egs/librispeech/ASR/transducer_lstm/encoder.py +++ b/egs/librispeech/ASR/transducer_lstm/encoder.py @@ -48,7 +48,9 @@ class LstmEncoder(EncoderInterface): if vgg_frontend: self.encoder_embed = VggSubsampling(num_features, real_hidden_size) else: - self.encoder_embed = Conv2dSubsampling(num_features, real_hidden_size) + self.encoder_embed = Conv2dSubsampling( + num_features, real_hidden_size + ) self.rnn = nn.LSTM( input_size=hidden_size, diff --git a/egs/librispeech/ASR/transducer_lstm/train.py b/egs/librispeech/ASR/transducer_lstm/train.py index 57bda63fd..cdb801e79 100755 --- a/egs/librispeech/ASR/transducer_lstm/train.py +++ b/egs/librispeech/ASR/transducer_lstm/train.py @@ -400,7 +400,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -522,7 +524,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -661,7 +665,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/transducer_stateless/alignment.py b/egs/librispeech/ASR/transducer_stateless/alignment.py index 65f2c58d8..f143611ea 100644 --- a/egs/librispeech/ASR/transducer_stateless/alignment.py +++ b/egs/librispeech/ASR/transducer_stateless/alignment.py @@ -193,7 +193,9 @@ def force_alignment( decoder_out = model.decoder(decoder_input, need_pad=False) # decoder_output is of shape (num_active_items, 1, decoder_output_dim) - current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1) + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, -1 + ) logits = model.joiner( current_encoder_out, diff --git a/egs/librispeech/ASR/transducer_stateless/beam_search.py b/egs/librispeech/ASR/transducer_stateless/beam_search.py index 1d79eef9d..ea985f30d 100644 --- a/egs/librispeech/ASR/transducer_stateless/beam_search.py +++ b/egs/librispeech/ASR/transducer_stateless/beam_search.py @@ -316,9 +316,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id: hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -478,7 +478,9 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -494,7 +496,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -782,7 +786,9 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) @@ -881,7 +887,9 @@ def _deprecated_modified_beam_search( decoder_out = model.decoder(decoder_input, need_pad=False) # decoder_output is of shape (num_hyps, 1, decoder_output_dim) - current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1) + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, -1 + ) logits = model.joiner( current_encoder_out, @@ -951,9 +959,9 @@ def beam_search( device = model.device - decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) diff --git a/egs/librispeech/ASR/transducer_stateless/compute_ali.py b/egs/librispeech/ASR/transducer_stateless/compute_ali.py index 89992856d..48769e9d1 100755 --- a/egs/librispeech/ASR/transducer_stateless/compute_ali.py +++ b/egs/librispeech/ASR/transducer_stateless/compute_ali.py @@ -54,19 +54,16 @@ def get_parser(): "--epoch", type=int, default=34, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -127,7 +124,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -164,7 +162,9 @@ def compute_alignments( feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) batch_size = encoder_out.size(0) @@ -204,7 +204,9 @@ def compute_alignments( if batch_idx % 2 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return CutSet.from_cuts(cuts) diff --git a/egs/librispeech/ASR/transducer_stateless/conformer.py b/egs/librispeech/ASR/transducer_stateless/conformer.py index d279eae85..cde52c9fc 100644 --- a/egs/librispeech/ASR/transducer_stateless/conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/conformer.py @@ -209,7 +209,10 @@ class Conformer(Transformer): NOTE: the returned tensors are on the given device. """ - if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -418,7 +421,9 @@ class ConformerEncoderLayer(nn.Module): causal: bool = False, ) -> None: super(ConformerEncoderLayer, self).__init__() - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( nn.Linear(d_model, dim_feedforward), @@ -434,16 +439,22 @@ class ConformerEncoderLayer(nn.Module): nn.Linear(dim_feedforward, d_model), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) - self.norm_ff_macaron = nn.LayerNorm(d_model) # for the macaron style FNN module + self.norm_ff_macaron = nn.LayerNorm( + d_model + ) # for the macaron style FNN module self.norm_ff = nn.LayerNorm(d_model) # for the FNN module self.norm_mha = nn.LayerNorm(d_model) # for the MHA module self.ff_scale = 0.5 self.norm_conv = nn.LayerNorm(d_model) # for the CNN module - self.norm_final = nn.LayerNorm(d_model) # for the final output of the block + self.norm_final = nn.LayerNorm( + d_model + ) # for the final output of the block self.dropout = nn.Dropout(dropout) @@ -475,7 +486,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -501,7 +514,9 @@ class ConformerEncoderLayer(nn.Module): if self.normalize_before: src = self.norm_conv(src) - src, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + src, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = residual + self.dropout(src) if not self.normalize_before: @@ -566,7 +581,9 @@ class ConformerEncoderLayer(nn.Module): residual = src if self.normalize_before: src = self.norm_ff_macaron(src) - src = residual + self.ff_scale * self.dropout(self.feed_forward_macaron(src)) + src = residual + self.ff_scale * self.dropout( + self.feed_forward_macaron(src) + ) if not self.normalize_before: src = self.norm_ff_macaron(src) @@ -608,7 +625,9 @@ class ConformerEncoderLayer(nn.Module): if self.normalize_before: src = self.norm_conv(src) - src, conv_cache = self.conv_module(src, states[1], right_context=right_context) + src, conv_cache = self.conv_module( + src, states[1], right_context=right_context + ) states[1] = conv_cache src = residual + self.dropout(src) @@ -760,7 +779,9 @@ class RelPositionalEncoding(torch.nn.Module): """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -777,7 +798,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -803,7 +826,9 @@ class RelPositionalEncoding(torch.nn.Module): pe = torch.cat([pe_positive, pe_negative], dim=1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward(self, x: torch.Tensor, left_context: int = 0) -> Tuple[Tensor, Tensor]: + def forward( + self, x: torch.Tensor, left_context: int = 0 + ) -> Tuple[Tensor, Tensor]: """Add positional encoding. Args: @@ -1067,9 +1092,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -1138,25 +1163,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1195,10 +1228,14 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context=left_context) @@ -1206,7 +1243,9 @@ class RelPositionMultiheadAttention(nn.Module): matrix_ac + matrix_bd ) * scaling # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1251,7 +1290,9 @@ class RelPositionMultiheadAttention(nn.Module): attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.masked_fill( + combined_mask, 0.0 + ) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1263,9 +1304,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1373,12 +1418,16 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert not self.training, "Cache should be None in training time" + assert ( + not self.training + ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : (-right_context), # noqa + -(self.lorder + right_context) : ( # noqa + -right_context + ), ..., ] else: diff --git a/egs/librispeech/ASR/transducer_stateless/decode.py b/egs/librispeech/ASR/transducer_stateless/decode.py index 314f49154..74bba9cad 100755 --- a/egs/librispeech/ASR/transducer_stateless/decode.py +++ b/egs/librispeech/ASR/transducer_stateless/decode.py @@ -94,19 +94,16 @@ def get_parser(): "--epoch", type=int, default=29, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=13, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -174,7 +171,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -232,7 +230,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -248,7 +248,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -294,7 +297,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -367,7 +374,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -400,7 +409,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -440,7 +450,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/librispeech/ASR/transducer_stateless/decoder.py b/egs/librispeech/ASR/transducer_stateless/decoder.py index a182d91e2..fbc2373a9 100644 --- a/egs/librispeech/ASR/transducer_stateless/decoder.py +++ b/egs/librispeech/ASR/transducer_stateless/decoder.py @@ -87,7 +87,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/librispeech/ASR/transducer_stateless/export.py b/egs/librispeech/ASR/transducer_stateless/export.py index 7c10b4348..8bd0bdea1 100755 --- a/egs/librispeech/ASR/transducer_stateless/export.py +++ b/egs/librispeech/ASR/transducer_stateless/export.py @@ -68,20 +68,17 @@ def get_parser(): "--epoch", type=int, default=20, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -112,7 +109,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -246,7 +244,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless/joiner.py b/egs/librispeech/ASR/transducer_stateless/joiner.py index e1625992d..93cccbd8c 100644 --- a/egs/librispeech/ASR/transducer_stateless/joiner.py +++ b/egs/librispeech/ASR/transducer_stateless/joiner.py @@ -60,9 +60,13 @@ class Joiner(nn.Module): encoder_out_len: List[int] = encoder_out_len.tolist() decoder_out_len: List[int] = decoder_out_len.tolist() - encoder_out_list = [encoder_out[i, : encoder_out_len[i], :] for i in range(N)] + encoder_out_list = [ + encoder_out[i, : encoder_out_len[i], :] for i in range(N) + ] - decoder_out_list = [decoder_out[i, : decoder_out_len[i], :] for i in range(N)] + decoder_out_list = [ + decoder_out[i, : decoder_out_len[i], :] for i in range(N) + ] x = [ e.unsqueeze(1) + d.unsqueeze(0) diff --git a/egs/librispeech/ASR/transducer_stateless/pretrained.py b/egs/librispeech/ASR/transducer_stateless/pretrained.py index bd7eeff28..b64521801 100755 --- a/egs/librispeech/ASR/transducer_stateless/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless/pretrained.py @@ -90,11 +90,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -119,12 +117,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -171,7 +167,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -200,9 +197,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -261,7 +259,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -334,7 +334,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py index 9af46846a..b00fc34f1 100755 --- a/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py +++ b/egs/librispeech/ASR/transducer_stateless/test_compute_ali.py @@ -140,13 +140,16 @@ def main(): token_alignment[i, : token_alignment_length[i]].tolist(), sp=sp ) word_starting_time = [ - "{:.2f}".format(i * frame_shift_in_second) for i in word_starting_frames + "{:.2f}".format(i * frame_shift_in_second) + for i in word_starting_frames ] words = supervisions["text"][i].split() assert len(word_starting_frames) == len(words) - word_starting_time_dict[cuts[i].id] = list(zip(words, word_starting_time)) + word_starting_time_dict[cuts[i].id] = list( + zip(words, word_starting_time) + ) # This is a demo script and we exit here after processing # one batch. @@ -157,7 +160,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless/test_conformer.py b/egs/librispeech/ASR/transducer_stateless/test_conformer.py index 65b08d425..d1350c8ab 100755 --- a/egs/librispeech/ASR/transducer_stateless/test_conformer.py +++ b/egs/librispeech/ASR/transducer_stateless/test_conformer.py @@ -29,7 +29,9 @@ from conformer import Conformer def test_conformer(): feature_dim = 50 - c = Conformer(num_features=feature_dim, output_dim=256, d_model=128, nhead=4) + c = Conformer( + num_features=feature_dim, output_dim=256, d_model=128, nhead=4 + ) batch_size = 5 seq_len = 20 # Just make sure the forward pass runs. diff --git a/egs/librispeech/ASR/transducer_stateless/train.py b/egs/librispeech/ASR/transducer_stateless/train.py index bcb883fa5..ae93f3348 100755 --- a/egs/librispeech/ASR/transducer_stateless/train.py +++ b/egs/librispeech/ASR/transducer_stateless/train.py @@ -136,7 +136,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -421,7 +422,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -542,7 +545,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -659,9 +664,13 @@ def run(rank, world_size, args): num_removed = num_in_total - num_left removed_percent = num_removed / num_in_total * 100 - logging.info(f"Before removing short and long utterances: {num_in_total}") + logging.info( + f"Before removing short and long utterances: {num_in_total}" + ) logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + logging.info( + f"Removed {num_removed} utterances ({removed_percent:.5f}%)" + ) except TypeError as e: # You can ignore this error as previous versions of Lhotse work fine # for the above code. In recent versions of Lhotse, it uses @@ -689,7 +698,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/transducer_stateless/transformer.py b/egs/librispeech/ASR/transducer_stateless/transformer.py index b3ff153c1..e851dcc32 100644 --- a/egs/librispeech/ASR/transducer_stateless/transformer.py +++ b/egs/librispeech/ASR/transducer_stateless/transformer.py @@ -250,7 +250,9 @@ def _get_activation_fn(activation: str): elif activation == "gelu": return nn.functional.gelu - raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + raise RuntimeError( + "activation should be relu/gelu, not {}".format(activation) + ) class PositionalEncoding(nn.Module): diff --git a/egs/librispeech/ASR/transducer_stateless2/decode.py b/egs/librispeech/ASR/transducer_stateless2/decode.py index 86ef9e5b6..ac2807241 100755 --- a/egs/librispeech/ASR/transducer_stateless2/decode.py +++ b/egs/librispeech/ASR/transducer_stateless2/decode.py @@ -94,19 +94,16 @@ def get_parser(): "--epoch", type=int, default=29, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=13, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -174,7 +171,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -232,7 +230,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -248,7 +248,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -294,7 +297,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -367,7 +374,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -400,7 +409,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -440,7 +450,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/librispeech/ASR/transducer_stateless2/export.py b/egs/librispeech/ASR/transducer_stateless2/export.py index d95eeb1f4..57c1a6094 100755 --- a/egs/librispeech/ASR/transducer_stateless2/export.py +++ b/egs/librispeech/ASR/transducer_stateless2/export.py @@ -63,20 +63,17 @@ def get_parser(): "--epoch", type=int, default=20, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -107,7 +104,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -178,7 +176,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless2/pretrained.py b/egs/librispeech/ASR/transducer_stateless2/pretrained.py index 793931e3b..292f77f03 100755 --- a/egs/librispeech/ASR/transducer_stateless2/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless2/pretrained.py @@ -90,11 +90,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -119,12 +117,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -171,7 +167,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -200,9 +197,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -261,7 +259,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -334,7 +334,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless2/train.py b/egs/librispeech/ASR/transducer_stateless2/train.py index 68e247f23..ea15c9040 100755 --- a/egs/librispeech/ASR/transducer_stateless2/train.py +++ b/egs/librispeech/ASR/transducer_stateless2/train.py @@ -136,7 +136,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -409,7 +410,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -530,7 +533,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -647,9 +652,13 @@ def run(rank, world_size, args): num_removed = num_in_total - num_left removed_percent = num_removed / num_in_total * 100 - logging.info(f"Before removing short and long utterances: {num_in_total}") + logging.info( + f"Before removing short and long utterances: {num_in_total}" + ) logging.info(f"After removing short and long utterances: {num_left}") - logging.info(f"Removed {num_removed} utterances ({removed_percent:.5f}%)") + logging.info( + f"Removed {num_removed} utterances ({removed_percent:.5f}%)" + ) except TypeError as e: # You can ignore this error as previous versions of Lhotse work fine # for the above code. In recent versions of Lhotse, it uses @@ -677,7 +686,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py index 22b6ab911..d596e05cb 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/decode.py @@ -95,19 +95,16 @@ def get_parser(): "--epoch", type=int, default=29, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=13, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -175,7 +172,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -233,7 +231,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -249,7 +249,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -295,7 +298,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -368,7 +375,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -401,7 +410,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -441,7 +451,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py index fad9a6977..b6b69d932 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/export.py @@ -69,20 +69,17 @@ def get_parser(): "--epoch", type=int, default=20, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -113,7 +110,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -249,7 +247,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py index efd257b5d..f297fa2b2 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/pretrained.py @@ -90,11 +90,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -119,12 +117,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -171,7 +167,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -200,9 +197,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -261,7 +259,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -334,7 +334,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py index 1e1188ca6..ef51a7811 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/test_asr_datamodule.py @@ -41,7 +41,9 @@ def test_dataset(): print(args) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None diff --git a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py index 88987d91c..27912738c 100755 --- a/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py +++ b/egs/librispeech/ASR/transducer_stateless_multi_datasets/train.py @@ -114,7 +114,8 @@ def get_parser(): "--full-libri", type=str2bool, default=True, - help="When enabled, use 960h LibriSpeech. Otherwise, use 100h subset.", + help="When enabled, use 960h LibriSpeech. " + "Otherwise, use 100h subset.", ) parser.add_argument( @@ -169,7 +170,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -467,7 +469,9 @@ def compute_loss( info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -631,7 +635,9 @@ def train_one_epoch( f"train/current_{prefix}_", params.batch_idx_train, ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) libri_tot_loss.write_summary( tb_writer, "train/libri_tot_", params.batch_idx_train ) @@ -778,7 +784,9 @@ def run(rank, world_size, args): train_giga_cuts = train_giga_cuts.repeat(times=None) if args.enable_musan: - cuts_musan = load_manifest(Path(args.manifest_dir) / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + Path(args.manifest_dir) / "musan_cuts.jsonl.gz" + ) else: cuts_musan = None @@ -817,7 +825,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/ptb/LM/local/sort_lm_training_data.py b/egs/ptb/LM/local/sort_lm_training_data.py index bed3856e4..af54dbd07 100755 --- a/egs/ptb/LM/local/sort_lm_training_data.py +++ b/egs/ptb/LM/local/sort_lm_training_data.py @@ -135,7 +135,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/ptb/LM/local/test_prepare_lm_training_data.py b/egs/ptb/LM/local/test_prepare_lm_training_data.py index 3790045fa..877720e7b 100755 --- a/egs/ptb/LM/local/test_prepare_lm_training_data.py +++ b/egs/ptb/LM/local/test_prepare_lm_training_data.py @@ -54,7 +54,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/spgispeech/ASR/local/compute_fbank_musan.py b/egs/spgispeech/ASR/local/compute_fbank_musan.py index 9bea28a41..6cb8b65ae 100755 --- a/egs/spgispeech/ASR/local/compute_fbank_musan.py +++ b/egs/spgispeech/ASR/local/compute_fbank_musan.py @@ -87,7 +87,9 @@ def compute_fbank_musan(): # create chunks of Musan with duration 5 - 10 seconds musan_cuts = ( CutSet.from_manifests( - recordings=combine(part["recordings"] for part in manifests.values()) + recordings=combine( + part["recordings"] for part in manifests.values() + ) ) .cut_into_windows(10.0) .filter(lambda c: c.duration > 5) @@ -106,6 +108,8 @@ def compute_fbank_musan(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py index 20ff6d7ab..8116e7605 100755 --- a/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py +++ b/egs/spgispeech/ASR/local/compute_fbank_spgispeech.py @@ -103,7 +103,11 @@ def compute_fbank_spgispeech(args): chunk_size=chunk_size, ) start = args.start - stop = min(args.stop, args.num_splits) if args.stop > 0 else args.num_splits + stop = ( + min(args.stop, args.num_splits) + if args.stop > 0 + else args.num_splits + ) num_digits = len(str(args.num_splits)) for i in range(start, stop): idx = f"{i + 1}".zfill(num_digits) @@ -125,7 +129,9 @@ def compute_fbank_spgispeech(args): logging.info(f"{partition} already exists - skipping.") continue logging.info(f"Processing {partition}") - cut_set = load_manifest_lazy(src_dir / f"cuts_{partition}_raw.jsonl.gz") + cut_set = load_manifest_lazy( + src_dir / f"cuts_{partition}_raw.jsonl.gz" + ) cut_set = cut_set.compute_and_store_features_batch( extractor=extractor, storage_path=output_dir / f"feats_{partition}", @@ -138,7 +144,9 @@ def compute_fbank_spgispeech(args): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) args = get_args() diff --git a/egs/spgispeech/ASR/local/prepare_splits.py b/egs/spgispeech/ASR/local/prepare_splits.py index 508d4acd8..8c8f1c133 100755 --- a/egs/spgispeech/ASR/local/prepare_splits.py +++ b/egs/spgispeech/ASR/local/prepare_splits.py @@ -55,7 +55,9 @@ def split_spgispeech_train(): # Add speed perturbation train_cuts = ( - train_cuts + train_cuts.perturb_speed(0.9) + train_cuts.perturb_speed(1.1) + train_cuts + + train_cuts.perturb_speed(0.9) + + train_cuts.perturb_speed(1.1) ) # Write the manifests to disk. @@ -71,7 +73,9 @@ def split_spgispeech_train(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) split_spgispeech_train() diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 83f95d123..f165f6e60 100644 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -70,12 +70,10 @@ class SPGISpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -87,81 +85,67 @@ class SPGISpeechAsrDataModule: "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it " - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it " + "with training dataset. ", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--max-duration", type=int, default=100.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the BucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the BucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--num-workers", type=int, default=8, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( "--enable-spec-aug", @@ -173,12 +157,10 @@ class SPGISpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) def train_dataloaders( @@ -194,20 +176,24 @@ class SPGISpeechAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "cuts_musan.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "cuts_musan.jsonl.gz" + ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -222,7 +208,9 @@ class SPGISpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) input_transforms.append( SpecAugment( time_warp_factor=self.args.spec_aug_time_warp_factor, @@ -239,7 +227,9 @@ class SPGISpeechAsrDataModule: if self.args.on_the_fly_feats: train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, ) else: @@ -292,7 +282,9 @@ class SPGISpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), ) else: validate = K2SpeechRecognitionDataset( @@ -336,7 +328,9 @@ class SPGISpeechAsrDataModule: @lru_cache() def train_cuts(self) -> CutSet: logging.info("About to get SPGISpeech train cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_train_shuf.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "cuts_train_shuf.jsonl.gz" + ) @lru_cache() def dev_cuts(self) -> CutSet: diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py index 72a7cd1c1..c39bd0530 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/decode.py @@ -76,7 +76,11 @@ from beam_search import ( ) from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import ( AttributeDict, setup_logger, @@ -113,11 +117,9 @@ def get_parser(): "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( @@ -185,7 +187,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -243,7 +246,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -258,7 +263,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -304,7 +312,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -377,7 +389,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -410,7 +424,9 @@ def save_results( # we also compute CER for spgispeech dataset. results_char = [] for res in results: - results_char.append((res[0], list("".join(res[1])), list("".join(res[2])))) + results_char.append( + (res[0], list("".join(res[1])), list("".join(res[2]))) + ) cers_filename = ( params.res_dir / f"cers-{test_set_name}-{key}-{params.suffix}.txt" ) @@ -422,23 +438,32 @@ def save_results( logging.info("Wrote detailed error stats to {}".format(wers_filename)) - test_set_wers = {k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1])} - test_set_cers = {k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1])} + test_set_wers = { + k: v for k, v in sorted(test_set_wers.items(), key=lambda x: x[1]) + } + test_set_cers = { + k: v for k, v in sorted(test_set_cers.items(), key=lambda x: x[1]) + } errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER\tCER", file=f) for key in test_set_wers: print( - "{}\t{}\t{}".format(key, test_set_wers[key], test_set_cers[key]), + "{}\t{}\t{}".format( + key, test_set_wers[key], test_set_cers[key] + ), file=f, ) s = "\nFor {}, WER/CER of different settings are:\n".format(test_set_name) note = "\tbest for {}".format(test_set_name) for key in test_set_wers: - s += "{}\t{}\t{}{}\n".format(key, test_set_wers[key], test_set_cers[key], note) + s += "{}\t{}\t{}{}\n".format( + key, test_set_wers[key], test_set_cers[key], note + ) note = "" logging.info(s) @@ -471,7 +496,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -503,7 +530,8 @@ def main(): ] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py index 1f18ae2f3..77faa3c0e 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/export.py @@ -50,7 +50,11 @@ import sentencepiece as spm import torch from train import get_params, get_transducer_model -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.utils import str2bool @@ -63,20 +67,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -118,7 +119,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -194,7 +196,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py index cd835a7b4..dda29b3e5 100755 --- a/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/spgispeech/ASR/pruned_transducer_stateless2/train.py @@ -77,7 +77,9 @@ from icefall.dist import cleanup_dist, setup_dist from icefall.env import get_env_info from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -153,7 +155,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need to be " + "changed.", ) parser.add_argument( @@ -176,45 +179,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -554,16 +554,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -726,7 +733,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") diff --git a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py index 602e50d29..4582609ac 100755 --- a/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py +++ b/egs/tal_csasr/ASR/local/compute_fbank_tal_csasr.py @@ -84,7 +84,9 @@ def compute_fbank_tal_csasr(num_mel_bins: int = 80): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -110,7 +112,9 @@ def get_args(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/tal_csasr/ASR/local/prepare_char.py b/egs/tal_csasr/ASR/local/prepare_char.py index 1262baf63..2c5b8b8b3 100755 --- a/egs/tal_csasr/ASR/local/prepare_char.py +++ b/egs/tal_csasr/ASR/local/prepare_char.py @@ -87,7 +87,9 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[""] for i in pieces] + pieces = [ + token2id[i] if i in token2id else token2id[" "] for i in pieces + ] for i in range(len(pieces) - 1): w = word if i == 0 else eps diff --git a/egs/tal_csasr/ASR/local/prepare_lang.py b/egs/tal_csasr/ASR/local/prepare_lang.py index c8cf9b881..e5ae89ec4 100755 --- a/egs/tal_csasr/ASR/local/prepare_lang.py +++ b/egs/tal_csasr/ASR/local/prepare_lang.py @@ -317,7 +317,9 @@ def lexicon_to_fst( def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("--lang-dir", type=str, help="The lang dir, data/lang_phone") + parser.add_argument( + "--lang-dir", type=str, help="The lang dir, data/lang_phone" + ) return parser.parse_args() diff --git a/egs/tal_csasr/ASR/local/test_prepare_lang.py b/egs/tal_csasr/ASR/local/test_prepare_lang.py index 74e025ad7..d4cf62bba 100755 --- a/egs/tal_csasr/ASR/local/test_prepare_lang.py +++ b/egs/tal_csasr/ASR/local/test_prepare_lang.py @@ -88,7 +88,9 @@ def test_read_lexicon(filename: str): fsa.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa.draw("L.pdf", title="L") - fsa_disambig = lexicon_to_fst(lexicon_disambig, phone2id=phone2id, word2id=word2id) + fsa_disambig = lexicon_to_fst( + lexicon_disambig, phone2id=phone2id, word2id=word2id + ) fsa_disambig.labels_sym = k2.SymbolTable.from_file("phones.txt") fsa_disambig.aux_labels_sym = k2.SymbolTable.from_file("words.txt") fsa_disambig.draw("L_disambig.pdf", title="L_disambig") diff --git a/egs/tal_csasr/ASR/local/text2token.py b/egs/tal_csasr/ASR/local/text2token.py index 2be639b7a..71be2a613 100755 --- a/egs/tal_csasr/ASR/local/text2token.py +++ b/egs/tal_csasr/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help=( - "number of characters to split, i.e., aabb -> a a b" - " b with -n 1 and aa bb with -n 2" - ), + help="number of characters to split, i.e., \ + aabb -> a a b b with -n 1 and aa bb with -n 2", ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument("--space", default=" ", type=str, help="space symbol") + parser.add_argument( + "--space", default=" ", type=str, help="space symbol" + ) parser.add_argument( "--non-lang-syms", "-l", @@ -66,7 +66,9 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument("text", type=str, default=False, nargs="?", help="input text") + parser.add_argument( + "text", type=str, default=False, nargs="?", help="input text" + ) parser.add_argument( "--trans_type", "-t", @@ -106,7 +108,8 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id for txt in text + token_table[txt] if txt in token_table else oov_id + for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -132,7 +135,9 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) + f = codecs.getreader("utf-8")( + sys.stdin if is_python2 else sys.stdin.buffer + ) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py index 02bd6e2cc..49bfb148b 100644 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/asr_datamodule.py @@ -74,12 +74,10 @@ class TAL_CSASRAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( @@ -93,81 +91,66 @@ class TAL_CSASRAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=300, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( @@ -181,18 +164,17 @@ class TAL_CSASRAsrDataModule: "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -206,22 +188,18 @@ class TAL_CSASRAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) group.add_argument( @@ -244,20 +222,24 @@ class TAL_CSASRAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -272,7 +254,9 @@ class TAL_CSASRAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -316,7 +300,9 @@ class TAL_CSASRAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -374,7 +360,9 @@ class TAL_CSASRAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py index b2aef7e86..b624913f5 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/decode.py @@ -124,24 +124,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -212,7 +208,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -271,7 +268,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] zh_hyps = [] en_hyps = [] @@ -304,7 +303,10 @@ def decode_one_batch( hyps.append(chars_new) zh_hyps.append(zh_text) en_hyps.append(en_text) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -373,7 +375,9 @@ def decode_one_batch( f"Unsupported decoding method: {params.decoding_method}" ) for i in range(encoder_out.size(0)): - hyp = sp.decode([lexicon.token_table[idx] for idx in hyp_tokens[i]]) + hyp = sp.decode( + [lexicon.token_table[idx] for idx in hyp_tokens[i]] + ) chars = pattern.split(hyp.upper()) chars_new = [] zh_text = [] @@ -392,11 +396,11 @@ def decode_one_batch( return {"greedy_search": (hyps, zh_hyps, en_hyps)} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": ( - hyps, - zh_hyps, - en_hyps, - ) + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): (hyps, zh_hyps, en_hyps) } else: return {f"beam_size_{params.beam_size}": (hyps, zh_hyps, en_hyps)} @@ -502,7 +506,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results, zh_results, en_results @@ -535,7 +541,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -578,7 +585,9 @@ def main(): params.suffix += f"-max-contexts-{params.max_contexts}" params.suffix += f"-max-states-{params.max_states}" elif "beam_search" in params.decoding_method: - params.suffix += f"-{params.decoding_method}-beam-size-{params.beam_size}" + params.suffix += ( + f"-{params.decoding_method}-beam-size-{params.beam_size}" + ) else: params.suffix += f"-context-{params.context_size}" params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}" @@ -610,12 +619,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -638,12 +648,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -671,7 +682,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py index 94a4c7a2e..8f900208a 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/export.py @@ -92,24 +92,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=False, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -143,7 +139,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -179,12 +176,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -207,12 +205,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -240,7 +239,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -278,7 +277,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py index 198242129..dbe213b24 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/pretrained.py @@ -84,11 +84,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -117,12 +115,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -169,7 +165,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -200,9 +197,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -265,11 +263,15 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) - encoder_out, encoder_out_lens = model.encoder(x=features, x_lens=feature_lengths) + encoder_out, encoder_out_lens = model.encoder( + x=features, x_lens=feature_lengths + ) num_waves = encoder_out.size(0) hyps = [] @@ -365,7 +367,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py index 676e8c904..ca35eba45 100755 --- a/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py +++ b/egs/tal_csasr/ASR/pruned_transducer_stateless5/train.py @@ -86,7 +86,9 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -212,7 +214,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -235,45 +238,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -600,7 +600,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -630,15 +634,22 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -817,7 +828,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -931,7 +944,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py index 733ebf235..327962a79 100755 --- a/egs/tedlium3/ASR/local/compute_fbank_tedlium.py +++ b/egs/tedlium3/ASR/local/compute_fbank_tedlium.py @@ -83,7 +83,9 @@ def compute_fbank_tedlium(): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cur_num_jobs = num_jobs if ex is None else 80 cur_num_jobs = min(cur_num_jobs, len(cut_set)) @@ -102,7 +104,9 @@ def compute_fbank_tedlium(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py index 9dbcc9d9e..49544ccb3 100644 --- a/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py +++ b/egs/tedlium3/ASR/local/convert_transcript_words_to_bpe_ids.py @@ -25,7 +25,9 @@ import sentencepiece as spm def get_args(): parser = argparse.ArgumentParser() - parser.add_argument("--texts", type=List[str], help="The input transcripts list.") + parser.add_argument( + "--texts", type=List[str], help="The input transcripts list." + ) parser.add_argument( "--bpe-model", type=str, diff --git a/egs/tedlium3/ASR/local/prepare_lexicon.py b/egs/tedlium3/ASR/local/prepare_lexicon.py index b9160b6d4..35dd332e8 100755 --- a/egs/tedlium3/ASR/local/prepare_lexicon.py +++ b/egs/tedlium3/ASR/local/prepare_lexicon.py @@ -23,12 +23,11 @@ consisting of supervisions_train.json and does the following: 1. Generate lexicon_words.txt. """ +import lhotse import argparse import logging from pathlib import Path -import lhotse - def get_args(): parser = argparse.ArgumentParser() @@ -62,7 +61,9 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str): words = set() lexicon = Path(lang_dir) / "lexicon_words.txt" - sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz") + sups = lhotse.load_manifest( + f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz" + ) for s in sups: # list the words units and filter the empty item words_list = list(filter(None, s.text.split())) @@ -87,7 +88,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/tedlium3/ASR/local/prepare_transcripts.py b/egs/tedlium3/ASR/local/prepare_transcripts.py index 7ea4e89a4..1039ac5bb 100755 --- a/egs/tedlium3/ASR/local/prepare_transcripts.py +++ b/egs/tedlium3/ASR/local/prepare_transcripts.py @@ -23,12 +23,11 @@ consisting of supervisions_train.json and does the following: 1. Generate train.text. """ +import lhotse import argparse import logging from pathlib import Path -import lhotse - def get_args(): parser = argparse.ArgumentParser() @@ -62,7 +61,9 @@ def prepare_transcripts(manifests_dir: str, lang_dir: str): texts = [] train_text = Path(lang_dir) / "train.text" - sups = lhotse.load_manifest(f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz") + sups = lhotse.load_manifest( + f"{manifests_dir}/tedlium_supervisions_train.jsonl.gz" + ) for s in sups: texts.append(s.text) @@ -82,7 +83,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py index 6bae33e65..2b294e601 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/decode.py @@ -94,20 +94,17 @@ def get_parser(): "--epoch", type=int, default=29, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=13, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -175,7 +172,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -233,7 +231,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -248,7 +248,10 @@ def decode_one_batch( ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -294,7 +297,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -367,7 +374,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -400,7 +409,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py index 244740932..a1c3bcea3 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/export.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/export.py @@ -65,20 +65,17 @@ def get_parser(): "--epoch", type=int, default=30, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=13, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -109,7 +106,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -181,7 +179,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py index 00545f107..8480ac029 100644 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/pretrained.py @@ -93,11 +93,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -124,12 +122,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -169,7 +165,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -206,9 +203,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -273,7 +271,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -298,7 +298,10 @@ def main(): ) for hyp in sp.decode(hyp_tokens): hyps.append(hyp.split()) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -350,7 +353,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py index 70c5e290f..8d5cdf683 100755 --- a/egs/tedlium3/ASR/pruned_transducer_stateless/train.py +++ b/egs/tedlium3/ASR/pruned_transducer_stateless/train.py @@ -133,45 +133,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -559,7 +556,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -679,7 +678,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py index f90f79d8c..94784c4c4 100644 --- a/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py +++ b/egs/tedlium3/ASR/transducer_stateless/asr_datamodule.py @@ -18,6 +18,7 @@ import argparse import logging + from functools import lru_cache from pathlib import Path from typing import Any, Dict, Optional @@ -62,12 +63,10 @@ class TedLiumAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -79,90 +78,74 @@ class TedLiumAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( "--enable-spec-aug", @@ -174,25 +157,23 @@ class TedLiumAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset.", ) def train_dataloaders( - self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None ) -> DataLoader: """ Args: @@ -205,7 +186,9 @@ class TedLiumAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) input_transforms.append( SpecAugment( @@ -225,16 +208,20 @@ class TedLiumAsrDataModule: transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -260,7 +247,9 @@ class TedLiumAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -317,7 +306,9 @@ class TedLiumAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -348,7 +339,9 @@ class TedLiumAsrDataModule: logging.debug("About to create test dataset") if self.args.on_the_fly_feats: test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -382,9 +375,13 @@ class TedLiumAsrDataModule: @lru_cache() def dev_cuts(self) -> CutSet: logging.info("About to get dev cuts") - return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "tedlium_cuts_dev.jsonl.gz" + ) @lru_cache() def test_cuts(self) -> CutSet: logging.info("About to get test cuts") - return load_manifest_lazy(self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "tedlium_cuts_test.jsonl.gz" + ) diff --git a/egs/tedlium3/ASR/transducer_stateless/beam_search.py b/egs/tedlium3/ASR/transducer_stateless/beam_search.py index 1f99edaf3..77caf6460 100644 --- a/egs/tedlium3/ASR/transducer_stateless/beam_search.py +++ b/egs/tedlium3/ASR/transducer_stateless/beam_search.py @@ -87,9 +87,9 @@ def greedy_search( y = logits.argmax().item() if y != blank_id and y != unk_id: hyp.append(y) - decoder_input = torch.tensor([hyp[-context_size:]], device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [hyp[-context_size:]], device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -148,7 +148,9 @@ class HypothesisList(object): key = hyp.key if key in self: old_hyp = self._data[key] # shallow copy - torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + torch.logaddexp( + old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob + ) else: self._data[key] = hyp @@ -164,7 +166,9 @@ class HypothesisList(object): Return the hypothesis that has the largest `log_prob`. """ if length_norm: - return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + return max( + self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys) + ) else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) @@ -340,9 +344,9 @@ def modified_beam_search( device = model.device - decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) @@ -379,7 +383,9 @@ def modified_beam_search( decoder_out = model.decoder(decoder_input, need_pad=False) # decoder_output is of shape (num_hyps, 1, decoder_output_dim) - current_encoder_out = current_encoder_out.expand(decoder_out.size(0), 1, -1) + current_encoder_out = current_encoder_out.expand( + decoder_out.size(0), 1, -1 + ) logits = model.joiner( current_encoder_out, @@ -448,9 +454,9 @@ def beam_search( device = model.device - decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape( - 1, context_size - ) + decoder_input = torch.tensor( + [blank_id] * context_size, device=device + ).reshape(1, context_size) decoder_out = model.decoder(decoder_input, need_pad=False) diff --git a/egs/tedlium3/ASR/transducer_stateless/decode.py b/egs/tedlium3/ASR/transducer_stateless/decode.py index 12d0e2652..d3e9e55e7 100755 --- a/egs/tedlium3/ASR/transducer_stateless/decode.py +++ b/egs/tedlium3/ASR/transducer_stateless/decode.py @@ -81,19 +81,16 @@ def get_parser(): "--epoch", type=int, default=29, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=13, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -133,7 +130,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -252,7 +250,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] batch_size = encoder_out.size(0) @@ -275,7 +275,9 @@ def decode_one_batch( model=model, encoder_out=encoder_out_i, beam=params.beam_size ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) hyps.append(sp.decode(hyp).split()) if params.decoding_method == "greedy_search": @@ -346,7 +348,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -379,7 +383,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) diff --git a/egs/tedlium3/ASR/transducer_stateless/decoder.py b/egs/tedlium3/ASR/transducer_stateless/decoder.py index f9a3814c6..f0c6f32b6 100644 --- a/egs/tedlium3/ASR/transducer_stateless/decoder.py +++ b/egs/tedlium3/ASR/transducer_stateless/decoder.py @@ -90,7 +90,9 @@ class Decoder(nn.Module): if self.context_size > 1: embedding_out = embedding_out.permute(0, 2, 1) if need_pad is True: - embedding_out = F.pad(embedding_out, pad=(self.context_size - 1, 0)) + embedding_out = F.pad( + embedding_out, pad=(self.context_size - 1, 0) + ) else: # During inference time, there is no need to do extra padding # as we only need one output diff --git a/egs/tedlium3/ASR/transducer_stateless/export.py b/egs/tedlium3/ASR/transducer_stateless/export.py index 0b2ae970b..c32b1d002 100644 --- a/egs/tedlium3/ASR/transducer_stateless/export.py +++ b/egs/tedlium3/ASR/transducer_stateless/export.py @@ -69,20 +69,17 @@ def get_parser(): "--epoch", type=int, default=20, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=10, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -113,7 +110,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -249,7 +247,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/tedlium3/ASR/transducer_stateless/pretrained.py b/egs/tedlium3/ASR/transducer_stateless/pretrained.py index 912d65497..c0e3bb844 100644 --- a/egs/tedlium3/ASR/transducer_stateless/pretrained.py +++ b/egs/tedlium3/ASR/transducer_stateless/pretrained.py @@ -82,11 +82,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -112,12 +110,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -131,7 +127,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -225,9 +222,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -287,7 +285,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -335,7 +335,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/tedlium3/ASR/transducer_stateless/train.py b/egs/tedlium3/ASR/transducer_stateless/train.py index 6fed32e81..09cbf4a00 100755 --- a/egs/tedlium3/ASR/transducer_stateless/train.py +++ b/egs/tedlium3/ASR/transducer_stateless/train.py @@ -133,7 +133,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -524,7 +525,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -644,7 +647,9 @@ def run(rank, world_size, args): cur_lr = optimizer._rate if tb_writer is not None: - tb_writer.add_scalar("train/learning_rate", cur_lr, params.batch_idx_train) + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) if rank == 0: diff --git a/egs/timit/ASR/RESULTS.md b/egs/timit/ASR/RESULTS.md index d8ceb82b6..b78c16b88 100644 --- a/egs/timit/ASR/RESULTS.md +++ b/egs/timit/ASR/RESULTS.md @@ -71,4 +71,4 @@ python tdnn_ligru_ctc/decode.py --epoch 25 \ --avg 17 \ --max-duration 20 \ --lang-dir data/lang_phone -``` +``` \ No newline at end of file diff --git a/egs/timit/ASR/local/compile_hlg.py b/egs/timit/ASR/local/compile_hlg.py index 32c248d7e..58cab4cf2 100644 --- a/egs/timit/ASR/local/compile_hlg.py +++ b/egs/timit/ASR/local/compile_hlg.py @@ -146,7 +146,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/timit/ASR/local/compute_fbank_timit.py b/egs/timit/ASR/local/compute_fbank_timit.py index ecdf10ba9..f25786a0c 100644 --- a/egs/timit/ASR/local/compute_fbank_timit.py +++ b/egs/timit/ASR/local/compute_fbank_timit.py @@ -85,7 +85,9 @@ def compute_fbank_timit(): ) if partition == "TRAIN": cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -99,7 +101,9 @@ def compute_fbank_timit(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py index 0cf0f0deb..04023a9ab 100644 --- a/egs/timit/ASR/local/prepare_lexicon.py +++ b/egs/timit/ASR/local/prepare_lexicon.py @@ -62,7 +62,9 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str): phones = set() - supervisions_train = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz" + supervisions_train = ( + Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz" + ) lexicon = Path(lang_dir) / "lexicon.txt" logging.info(f"Loading {supervisions_train}!") @@ -95,7 +97,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/timit/ASR/prepare.sh b/egs/timit/ASR/prepare.sh index d11cd3a05..ae1b96a68 100644 --- a/egs/timit/ASR/prepare.sh +++ b/egs/timit/ASR/prepare.sh @@ -20,9 +20,9 @@ stop_stage=100 # - $dl_dir/lm # This directory contains the language model(LM) downloaded from # https://huggingface.co/luomingshuang/timit_lm, and the LM is based -# on 39 phones. About how to get these LM files, you can know it +# on 39 phones. About how to get these LM files, you can know it # from https://github.com/luomingshuang/Train_LM_with_kaldilm. -# +# # - lm_3_gram.arpa # - lm_4_gram.arpa # diff --git a/egs/timit/ASR/tdnn_ligru_ctc/decode.py b/egs/timit/ASR/tdnn_ligru_ctc/decode.py index 5a59a13ce..4f2aa2340 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/decode.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/decode.py @@ -57,19 +57,16 @@ def get_parser(): "--epoch", type=int, default=19, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=5, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--method", @@ -339,7 +336,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -401,7 +400,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -461,7 +462,9 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) return model.to(device) @@ -482,7 +485,9 @@ def main(): G=G, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/timit/ASR/tdnn_ligru_ctc/model.py b/egs/timit/ASR/tdnn_ligru_ctc/model.py index 9a594a969..4d2199ace 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/model.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/model.py @@ -16,11 +16,11 @@ # limitations under the License. -from typing import Optional - import torch import torch.nn as nn + from torch import Tensor +from typing import Optional class TdnnLiGRU(nn.Module): @@ -261,7 +261,9 @@ class LiGRU(torch.nn.Module): h = [] if hx is not None: if self.bidirectional: - hx = hx.reshape(self.num_layers, self.batch_size * 2, self.hidden_size) + hx = hx.reshape( + self.num_layers, self.batch_size * 2, self.hidden_size + ) # Processing the different layers for i, ligru_lay in enumerate(self.rnn): if hx is not None: @@ -443,7 +445,9 @@ class LiGRU_Layer(torch.nn.Module): if self.drop_mask_cnt + self.batch_size > self.N_drop_masks: self.drop_mask_cnt = 0 self.drop_masks = self.drop( - torch.ones(self.N_drop_masks, self.hidden_size, device=w.device) + torch.ones( + self.N_drop_masks, self.hidden_size, device=w.device + ) ).data # Sampling the mask diff --git a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py index da669bc39..7da285944 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/pretrained.py @@ -29,7 +29,11 @@ import torchaudio from model import TdnnLiGRU from torch.nn.utils.rnn import pad_sequence -from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_whole_lattice, +) from icefall.utils import AttributeDict, get_texts @@ -42,11 +46,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -56,7 +58,9 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) parser.add_argument( "--method", @@ -99,12 +103,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -142,9 +144,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -212,7 +215,9 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) features = features.permute(0, 2, 1) # now features is (N, C, T) with torch.no_grad(): @@ -264,7 +269,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/timit/ASR/tdnn_ligru_ctc/train.py b/egs/timit/ASR/tdnn_ligru_ctc/train.py index 48b7feda0..452c2a7cb 100644 --- a/egs/timit/ASR/tdnn_ligru_ctc/train.py +++ b/egs/timit/ASR/tdnn_ligru_ctc/train.py @@ -449,7 +449,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: valid_info = compute_validation_loss( diff --git a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py index d957c22e1..1554e987f 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/asr_datamodule.py @@ -63,12 +63,10 @@ class TimitAsrDataModule(DataModule): super().add_arguments(parser) group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--feature-dir", @@ -80,91 +78,75 @@ class TimitAsrDataModule(DataModule): "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=30, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) def train_dataloaders(self) -> DataLoader: @@ -172,13 +154,15 @@ class TimitAsrDataModule(DataModule): cuts_train = self.train_cuts() logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.feature_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.feature_dir / "musan_cuts.jsonl.gz" + ) logging.info("About to create train dataset") transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))] if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -194,9 +178,9 @@ class TimitAsrDataModule(DataModule): # In different Lhotse's versions, the default of num_frame_masks is # different. num_frame_masks = 10 - num_frame_masks_parameter = inspect.signature(SpecAugment.__init__).parameters[ - "num_frame_masks" - ] + num_frame_masks_parameter = inspect.signature( + SpecAugment.__init__ + ).parameters["num_frame_masks"] if num_frame_masks_parameter.default == 1: num_frame_masks = 2 logging.info(f"Num frame mask: {num_frame_masks}") @@ -228,7 +212,9 @@ class TimitAsrDataModule(DataModule): # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -277,7 +263,9 @@ class TimitAsrDataModule(DataModule): if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -311,14 +299,20 @@ class TimitAsrDataModule(DataModule): for cuts_test in cuts: logging.debug("About to create test dataset") test = K2SpeechRecognitionDataset( - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ) if self.args.on_the_fly_feats else PrecomputedFeatures(), return_cuts=self.args.return_cuts, ) - sampler = SingleCutSampler(cuts_test, max_duration=self.args.max_duration) + sampler = SingleCutSampler( + cuts_test, max_duration=self.args.max_duration + ) logging.debug("About to create test dataloader") - test_dl = DataLoader(test, batch_size=None, sampler=sampler, num_workers=1) + test_dl = DataLoader( + test, batch_size=None, sampler=sampler, num_workers=1 + ) test_loaders.append(test_dl) if is_list: diff --git a/egs/timit/ASR/tdnn_lstm_ctc/decode.py b/egs/timit/ASR/tdnn_lstm_ctc/decode.py index 319ee5515..5e7300cf2 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/decode.py @@ -56,19 +56,16 @@ def get_parser(): "--epoch", type=int, default=25, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=5, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--method", @@ -338,7 +335,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -400,7 +399,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -460,7 +461,9 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) return model.to(device) @@ -480,7 +483,9 @@ def main(): G=G, ) - save_results(params=params, test_set_name=test_set, results_dict=results_dict) + save_results( + params=params, test_set_name=test_set, results_dict=results_dict + ) logging.info("Done!") diff --git a/egs/timit/ASR/tdnn_lstm_ctc/model.py b/egs/timit/ASR/tdnn_lstm_ctc/model.py index e211ad80d..51edb97e2 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/model.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/model.py @@ -74,7 +74,10 @@ class TdnnLstm(nn.Module): nn.BatchNorm1d(num_features=512, affine=False), ) self.lstms = nn.ModuleList( - [nn.LSTM(input_size=512, hidden_size=512, num_layers=1) for _ in range(4)] + [ + nn.LSTM(input_size=512, hidden_size=512, num_layers=1) + for _ in range(4) + ] ) self.lstm_bnorms = nn.ModuleList( [nn.BatchNorm1d(num_features=512, affine=False) for _ in range(5)] diff --git a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py index 0c72c973b..5f478da1c 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/pretrained.py @@ -29,7 +29,11 @@ import torchaudio from model import TdnnLstm from torch.nn.utils.rnn import pad_sequence -from icefall.decode import get_lattice, one_best_decoding, rescore_with_whole_lattice +from icefall.decode import ( + get_lattice, + one_best_decoding, + rescore_with_whole_lattice, +) from icefall.utils import AttributeDict, get_texts @@ -42,11 +46,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -56,7 +58,9 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) parser.add_argument( "--method", @@ -99,12 +103,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -142,9 +144,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -212,7 +215,9 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) features = features.permute(0, 2, 1) # now features is (N, C, T) with torch.no_grad(): @@ -264,7 +269,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/timit/ASR/tdnn_lstm_ctc/train.py b/egs/timit/ASR/tdnn_lstm_ctc/train.py index be1ecffaa..849256b98 100644 --- a/egs/timit/ASR/tdnn_lstm_ctc/train.py +++ b/egs/timit/ASR/tdnn_lstm_ctc/train.py @@ -449,7 +449,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: valid_info = compute_validation_loss( diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py index bd73e520e..8a9f6ed30 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_dev_test.py @@ -20,7 +20,12 @@ import logging from pathlib import Path import torch -from lhotse import CutSet, KaldifeatFbank, KaldifeatFbankConfig, LilcomHdf5Writer +from lhotse import ( + CutSet, + KaldifeatFbank, + KaldifeatFbankConfig, + LilcomHdf5Writer, +) # Torch's multithreaded behavior needs to be disabled or # it wastes a lot of CPU and slow things down. @@ -78,7 +83,9 @@ def compute_fbank_wenetspeech_dev_test(): def main(): - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_wenetspeech_dev_test() diff --git a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py index c228597b8..a882b6113 100755 --- a/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py +++ b/egs/wenetspeech/ASR/local/compute_fbank_wenetspeech_splits.py @@ -62,10 +62,8 @@ def get_parser(): "--batch-duration", type=float, default=600.0, - help=( - "The maximum number of audio seconds in a batch." - "Determines batch size dynamically." - ), + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", ) parser.add_argument( @@ -154,7 +152,9 @@ def main(): date_time = now.strftime("%Y-%m-%d-%H-%M-%S") log_filename = "log-compute_fbank_wenetspeech_splits" - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) log_filename = f"{log_filename}-{date_time}" logging.basicConfig( diff --git a/egs/wenetspeech/ASR/local/prepare_char.py b/egs/wenetspeech/ASR/local/prepare_char.py index d8622842f..8bc073c75 100755 --- a/egs/wenetspeech/ASR/local/prepare_char.py +++ b/egs/wenetspeech/ASR/local/prepare_char.py @@ -83,7 +83,9 @@ def lexicon_to_fst_no_sil( cur_state = loop_state word = word2id[word] - pieces = [token2id[i] if i in token2id else token2id[" "] for i in pieces] + pieces = [ + token2id[i] if i in token2id else token2id[" "] for i in pieces + ] for i in range(len(pieces) - 1): w = word if i == 0 else eps @@ -136,7 +138,9 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool: return False -def generate_lexicon(token_sym_table: Dict[str, int], words: List[str]) -> Lexicon: +def generate_lexicon( + token_sym_table: Dict[str, int], words: List[str] +) -> Lexicon: """Generate a lexicon from a word list and token_sym_table. Args: token_sym_table: diff --git a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py index 93ce750f8..817969c47 100755 --- a/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py +++ b/egs/wenetspeech/ASR/local/preprocess_wenetspeech.py @@ -115,7 +115,11 @@ def preprocess_wenet_speech(): f"Speed perturb for {partition} with factors 0.9 and 1.1 " "(Perturbing may take 8 minutes and saving may take 20 minutes)" ) - cut_set = cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set = ( + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) + ) logging.info(f"Saving to {raw_cuts_path}") cut_set.to_file(raw_cuts_path) diff --git a/egs/wenetspeech/ASR/local/text2token.py b/egs/wenetspeech/ASR/local/text2token.py index e121d842c..1c463cf1c 100755 --- a/egs/wenetspeech/ASR/local/text2token.py +++ b/egs/wenetspeech/ASR/local/text2token.py @@ -50,15 +50,15 @@ def get_parser(): "-n", default=1, type=int, - help=( - "number of characters to split, i.e., aabb -> a a b" - " b with -n 1 and aa bb with -n 2" - ), + help="number of characters to split, i.e., \ + aabb -> a a b b with -n 1 and aa bb with -n 2", ) parser.add_argument( "--skip-ncols", "-s", default=0, type=int, help="skip first n columns" ) - parser.add_argument("--space", default=" ", type=str, help="space symbol") + parser.add_argument( + "--space", default=" ", type=str, help="space symbol" + ) parser.add_argument( "--non-lang-syms", "-l", @@ -66,7 +66,9 @@ def get_parser(): type=str, help="list of non-linguistic symobles, e.g., etc.", ) - parser.add_argument("text", type=str, default=False, nargs="?", help="input text") + parser.add_argument( + "text", type=str, default=False, nargs="?", help="input text" + ) parser.add_argument( "--trans_type", "-t", @@ -106,7 +108,8 @@ def token2id( if token_type == "lazy_pinyin": text = lazy_pinyin(chars_list) sub_ids = [ - token_table[txt] if txt in token_table else oov_id for txt in text + token_table[txt] if txt in token_table else oov_id + for txt in text ] ids.append(sub_ids) else: # token_type = "pinyin" @@ -132,7 +135,9 @@ def main(): if args.text: f = codecs.open(args.text, encoding="utf-8") else: - f = codecs.getreader("utf-8")(sys.stdin if is_python2 else sys.stdin.buffer) + f = codecs.getreader("utf-8")( + sys.stdin if is_python2 else sys.stdin.buffer + ) sys.stdout = codecs.getwriter("utf-8")( sys.stdout if is_python2 else sys.stdout.buffer diff --git a/egs/wenetspeech/ASR/prepare.sh b/egs/wenetspeech/ASR/prepare.sh index da7d7e061..755fbb2d7 100755 --- a/egs/wenetspeech/ASR/prepare.sh +++ b/egs/wenetspeech/ASR/prepare.sh @@ -190,7 +190,7 @@ if [ $stage -le 15 ] && [ $stop_stage -ge 15 ]; then mkdir -p $lang_char_dir if ! which jq; then - echo "This script is intended to be used with jq but you have not installed jq + echo "This script is intended to be used with jq but you have not installed jq Note: in Linux, you can install jq with the following command: 1. wget -O jq https://github.com/stedolan/jq/releases/download/jq-1.6/jq-linux64 2. chmod +x ./jq diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index bd92ac115..10c953e3b 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -81,12 +81,10 @@ class WenetSpeechAsrDataModule: def add_arguments(cls, parser: argparse.ArgumentParser): group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--manifest-dir", @@ -98,91 +96,75 @@ class WenetSpeechAsrDataModule: "--max-duration", type=int, default=200.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=True, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=300, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) group.add_argument( @@ -196,22 +178,18 @@ class WenetSpeechAsrDataModule: "--spec-aug-time-warp-factor", type=int, default=80, - help=( - "Used only when --enable-spec-aug is True. " - "It specifies the factor for time warping in SpecAugment. " - "Larger values mean more warping. " - "A value less than 1 means to disable time warp." - ), + help="Used only when --enable-spec-aug is True. " + "It specifies the factor for time warping in SpecAugment. " + "Larger values mean more warping. " + "A value less than 1 means to disable time warp.", ) group.add_argument( "--enable-musan", type=str2bool, default=True, - help=( - "When enabled, select noise from MUSAN and mix it" - "with training dataset. " - ), + help="When enabled, select noise from MUSAN and mix it" + "with training dataset. ", ) group.add_argument( @@ -234,20 +212,24 @@ class WenetSpeechAsrDataModule: The state dict for the training sampler. """ logging.info("About to get Musan cuts") - cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz") + cuts_musan = load_manifest( + self.args.manifest_dir / "musan_cuts.jsonl.gz" + ) transforms = [] if self.args.enable_musan: logging.info("Enable MUSAN") transforms.append( - CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True) + CutMix( + cuts=cuts_musan, prob=0.5, snr=(10, 20), preserve_id=True + ) ) else: logging.info("Disable MUSAN") if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, @@ -262,7 +244,9 @@ class WenetSpeechAsrDataModule: input_transforms = [] if self.args.enable_spec_aug: logging.info("Enable SpecAugment") - logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}") + logging.info( + f"Time warp factor: {self.args.spec_aug_time_warp_factor}" + ) # Set the value of num_frame_masks according to Lhotse's version. # In different Lhotse's versions, the default of num_frame_masks is # different. @@ -305,7 +289,9 @@ class WenetSpeechAsrDataModule: # Drop feats to be on the safe side. train = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), input_transforms=input_transforms, return_cuts=self.args.return_cuts, ) @@ -362,7 +348,9 @@ class WenetSpeechAsrDataModule: if self.args.on_the_fly_feats: validate = K2SpeechRecognitionDataset( cut_transforms=transforms, - input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + input_strategy=OnTheFlyFeatures( + Fbank(FbankConfig(num_mel_bins=80)) + ), return_cuts=self.args.return_cuts, ) else: @@ -426,7 +414,8 @@ class WenetSpeechAsrDataModule: def train_cuts(self) -> CutSet: logging.info("About to get train cuts") cuts_train = load_manifest_lazy( - self.args.manifest_dir / f"cuts_{self.args.training_subset}.jsonl.gz" + self.args.manifest_dir + / f"cuts_{self.args.training_subset}.jsonl.gz" ) return cuts_train @@ -438,9 +427,13 @@ class WenetSpeechAsrDataModule: @lru_cache() def test_net_cuts(self) -> List[CutSet]: logging.info("About to get TEST_NET cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "cuts_TEST_NET.jsonl.gz" + ) @lru_cache() def test_meeting_cuts(self) -> List[CutSet]: logging.info("About to get TEST_MEETING cuts") - return load_manifest_lazy(self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz") + return load_manifest_lazy( + self.args.manifest_dir / "cuts_TEST_MEETING.jsonl.gz" + ) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py index 6e856248c..f0c9bebec 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/decode.py @@ -114,7 +114,11 @@ from beam_search import ( from train import get_params, get_transducer_model from icefall.char_graph_compiler import CharCtcTrainingGraphCompiler -from icefall.checkpoint import average_checkpoints, find_checkpoints, load_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + find_checkpoints, + load_checkpoint, +) from icefall.lexicon import Lexicon from icefall.utils import ( AttributeDict, @@ -133,30 +137,25 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--batch", type=int, default=None, - help=( - "It specifies the batch checkpoint to use for decoding." - "Note: Epoch counts from 0." - ), + help="It specifies the batch checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -253,7 +252,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -328,7 +328,9 @@ def decode_one_batch( supervisions = batch["supervisions"] feature_lens = supervisions["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] if params.decoding_method == "fast_beam_search": @@ -387,7 +389,10 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -433,7 +438,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -506,7 +515,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -539,7 +550,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -651,7 +663,9 @@ def main(): ) decoding_graph.scores *= params.ngram_lm_scale else: - decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + decoding_graph = k2.trivial_graph( + params.vocab_size - 1, device=device + ) else: decoding_graph = None @@ -702,7 +716,8 @@ def main(): ) dev_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) + str(path) + for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) ] cuts_dev_webdataset = CutSet.from_webdataset( dev_shards, @@ -712,7 +727,8 @@ def main(): ) test_net_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) + str(path) + for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) ] cuts_test_net_webdataset = CutSet.from_webdataset( test_net_shards, @@ -723,7 +739,9 @@ def main(): test_meeting_shards = [ str(path) - for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar"))) + for path in sorted( + glob.glob(os.path.join(test_meeting, "shared-*.tar")) + ) ] cuts_test_meeting_webdataset = CutSet.from_webdataset( test_meeting_shards, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py index c742593df..933642a0f 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/export.py @@ -126,20 +126,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -208,7 +205,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) return parser @@ -470,9 +468,13 @@ def export_joiner_model_onnx( - projected_decoder_out: a tensor of shape (N, joiner_dim) """ - encoder_proj_filename = str(joiner_filename).replace(".onnx", "_encoder_proj.onnx") + encoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_encoder_proj.onnx" + ) - decoder_proj_filename = str(joiner_filename).replace(".onnx", "_decoder_proj.onnx") + decoder_proj_filename = str(joiner_filename).replace( + ".onnx", "_decoder_proj.onnx" + ) encoder_out_dim = joiner_model.encoder_proj.weight.shape[1] decoder_out_dim = joiner_model.decoder_proj.weight.shape[1] @@ -643,7 +645,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py index ed9020c67..e5cc47bfe 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/jit_pretrained.py @@ -107,12 +107,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -147,9 +145,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -332,7 +331,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py index a46ff5a07..c396c50ef 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_check.py @@ -219,7 +219,9 @@ def test_joiner( ) # Now test encoder_proj - joiner_encoder_proj_inputs = {encoder_proj_input_name: encoder_out.numpy()} + joiner_encoder_proj_inputs = { + encoder_proj_input_name: encoder_out.numpy() + } joiner_encoder_proj_out = joiner_encoder_proj_session.run( [encoder_proj_output_name], joiner_encoder_proj_inputs )[0] @@ -228,10 +230,16 @@ def test_joiner( torch_joiner_encoder_proj_out = model.joiner.encoder_proj(encoder_out) assert torch.allclose( joiner_encoder_proj_out, torch_joiner_encoder_proj_out, atol=1e-5 - ), ((joiner_encoder_proj_out - torch_joiner_encoder_proj_out).abs().max()) + ), ( + (joiner_encoder_proj_out - torch_joiner_encoder_proj_out) + .abs() + .max() + ) # Now test decoder_proj - joiner_decoder_proj_inputs = {decoder_proj_input_name: decoder_out.numpy()} + joiner_decoder_proj_inputs = { + decoder_proj_input_name: decoder_out.numpy() + } joiner_decoder_proj_out = joiner_decoder_proj_session.run( [decoder_proj_output_name], joiner_decoder_proj_inputs )[0] @@ -240,7 +248,11 @@ def test_joiner( torch_joiner_decoder_proj_out = model.joiner.decoder_proj(decoder_out) assert torch.allclose( joiner_decoder_proj_out, torch_joiner_decoder_proj_out, atol=1e-5 - ), ((joiner_decoder_proj_out - torch_joiner_decoder_proj_out).abs().max()) + ), ( + (joiner_decoder_proj_out - torch_joiner_decoder_proj_out) + .abs() + .max() + ) @torch.no_grad() @@ -292,7 +304,9 @@ def main(): if __name__ == "__main__": torch.manual_seed(20220727) - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py index f7d962008..3770fbbb4 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/onnx_pretrained.py @@ -111,12 +111,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -151,9 +149,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -201,7 +200,11 @@ def greedy_search( projected_encoder_out = joiner_encoder_proj.run( [joiner_encoder_proj.get_outputs()[0].name], - {joiner_encoder_proj.get_inputs()[0].name: packed_encoder_out.data.numpy()}, + { + joiner_encoder_proj.get_inputs()[ + 0 + ].name: packed_encoder_out.data.numpy() + }, )[0] blank_id = 0 # hard-code to 0 @@ -386,7 +389,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py index 26c9c2b8c..9a549efd9 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/pretrained.py @@ -80,11 +80,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -109,12 +107,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -162,7 +158,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -192,9 +189,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -255,7 +253,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -280,7 +280,10 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -332,7 +335,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py index e020c4c05..d3cc7c9c9 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/train.py @@ -115,7 +115,9 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def get_parser(): @@ -217,45 +219,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -591,15 +590,22 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -756,7 +762,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -856,7 +864,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py index 1023c931a..dd27c17f0 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/conformer.py @@ -210,7 +210,10 @@ class Conformer(EncoderInterface): (num_encoder_layers, cnn_module_kernel - 1, encoder_dim). NOTE: the returned tensors are on the given device. """ - if len(self._init_state) == 2 and self._init_state[0].size(1) == left_context: + if ( + len(self._init_state) == 2 + and self._init_state[0].size(1) == left_context + ): # Note: It is OK to share the init state as it is # not going to be modified by the model return self._init_state @@ -430,7 +433,9 @@ class ConformerEncoderLayer(nn.Module): self.d_model = d_model - self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0) + self.self_attn = RelPositionMultiheadAttention( + d_model, nhead, dropout=0.0 + ) self.feed_forward = nn.Sequential( ScaledLinear(d_model, dim_feedforward), @@ -448,7 +453,9 @@ class ConformerEncoderLayer(nn.Module): ScaledLinear(dim_feedforward, d_model, initial_scale=0.25), ) - self.conv_module = ConvolutionModule(d_model, cnn_module_kernel, causal=causal) + self.conv_module = ConvolutionModule( + d_model, cnn_module_kernel, causal=causal + ) self.norm_final = BasicNorm(d_model) @@ -513,7 +520,9 @@ class ConformerEncoderLayer(nn.Module): src = src + self.dropout(src_att) # convolution module - conv, _ = self.conv_module(src, src_key_padding_mask=src_key_padding_mask) + conv, _ = self.conv_module( + src, src_key_padding_mask=src_key_padding_mask + ) src = src + self.dropout(conv) # feed forward module @@ -757,7 +766,9 @@ class RelPositionalEncoding(torch.nn.Module): max_len: Maximum input length. """ - def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None: + def __init__( + self, d_model: int, dropout_rate: float, max_len: int = 5000 + ) -> None: """Construct an PositionalEncoding object.""" super(RelPositionalEncoding, self).__init__() self.d_model = d_model @@ -773,7 +784,9 @@ class RelPositionalEncoding(torch.nn.Module): # the length of self.pe is 2 * input_len - 1 if self.pe.size(1) >= x_size_1 * 2 - 1: # Note: TorchScript doesn't implement operator== for torch.Device - if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device): + if self.pe.dtype != x.dtype or str(self.pe.device) != str( + x.device + ): self.pe = self.pe.to(dtype=x.dtype, device=x.device) return # Suppose `i` means to the position of query vector and `j` means the @@ -1060,9 +1073,9 @@ class RelPositionMultiheadAttention(nn.Module): if torch.equal(query, key) and torch.equal(key, value): # self-attention - q, k, v = nn.functional.linear(query, in_proj_weight, in_proj_bias).chunk( - 3, dim=-1 - ) + q, k, v = nn.functional.linear( + query, in_proj_weight, in_proj_bias + ).chunk(3, dim=-1) elif torch.equal(key, value): # encoder-decoder attention @@ -1131,25 +1144,33 @@ class RelPositionMultiheadAttention(nn.Module): if attn_mask.dim() == 2: attn_mask = attn_mask.unsqueeze(0) if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: - raise RuntimeError("The size of the 2D attn_mask is not correct.") + raise RuntimeError( + "The size of the 2D attn_mask is not correct." + ) elif attn_mask.dim() == 3: if list(attn_mask.size()) != [ bsz * num_heads, query.size(0), key.size(0), ]: - raise RuntimeError("The size of the 3D attn_mask is not correct.") + raise RuntimeError( + "The size of the 3D attn_mask is not correct." + ) else: raise RuntimeError( - "attn_mask's dimension {} is not supported".format(attn_mask.dim()) + "attn_mask's dimension {} is not supported".format( + attn_mask.dim() + ) ) # attn_mask's dim is 3 now. # convert ByteTensor key_padding_mask to bool - if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8: + if ( + key_padding_mask is not None + and key_padding_mask.dtype == torch.uint8 + ): warnings.warn( - "Byte tensor for key_padding_mask is deprecated. Use bool tensor" - " instead." + "Byte tensor for key_padding_mask is deprecated. Use bool tensor instead." ) key_padding_mask = key_padding_mask.to(torch.bool) @@ -1187,15 +1208,23 @@ class RelPositionMultiheadAttention(nn.Module): # first compute matrix a and matrix c # as described in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context" Section 3.3 k = k.permute(1, 2, 3, 0) # (batch, head, d_k, time2) - matrix_ac = torch.matmul(q_with_bias_u, k) # (batch, head, time1, time2) + matrix_ac = torch.matmul( + q_with_bias_u, k + ) # (batch, head, time1, time2) # compute matrix b and matrix d - matrix_bd = torch.matmul(q_with_bias_v, p) # (batch, head, time1, 2*time1-1) + matrix_bd = torch.matmul( + q_with_bias_v, p + ) # (batch, head, time1, 2*time1-1) matrix_bd = self.rel_shift(matrix_bd, left_context) - attn_output_weights = matrix_ac + matrix_bd # (batch, head, time1, time2) + attn_output_weights = ( + matrix_ac + matrix_bd + ) # (batch, head, time1, time2) - attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, -1) + attn_output_weights = attn_output_weights.view( + bsz * num_heads, tgt_len, -1 + ) assert list(attn_output_weights.size()) == [ bsz * num_heads, @@ -1236,17 +1265,21 @@ class RelPositionMultiheadAttention(nn.Module): ): if attn_mask.size(0) != 1: attn_mask = attn_mask.view(bsz, num_heads, tgt_len, src_len) - combined_mask = attn_mask | key_padding_mask.unsqueeze(1).unsqueeze(2) - else: - # attn_mask.shape == (1, tgt_len, src_len) - combined_mask = attn_mask.unsqueeze(0) | key_padding_mask.unsqueeze( + combined_mask = attn_mask | key_padding_mask.unsqueeze( 1 ).unsqueeze(2) + else: + # attn_mask.shape == (1, tgt_len, src_len) + combined_mask = attn_mask.unsqueeze( + 0 + ) | key_padding_mask.unsqueeze(1).unsqueeze(2) attn_output_weights = attn_output_weights.view( bsz, num_heads, tgt_len, src_len ) - attn_output_weights = attn_output_weights.masked_fill(combined_mask, 0.0) + attn_output_weights = attn_output_weights.masked_fill( + combined_mask, 0.0 + ) attn_output_weights = attn_output_weights.view( bsz * num_heads, tgt_len, src_len ) @@ -1258,9 +1291,13 @@ class RelPositionMultiheadAttention(nn.Module): attn_output = torch.bmm(attn_output_weights, v) assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] attn_output = ( - attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn_output.transpose(0, 1) + .contiguous() + .view(tgt_len, bsz, embed_dim) + ) + attn_output = nn.functional.linear( + attn_output, out_proj_weight, out_proj_bias ) - attn_output = nn.functional.linear(attn_output, out_proj_weight, out_proj_bias) if need_weights: # average attention weights over heads @@ -1393,12 +1430,16 @@ class ConvolutionModule(nn.Module): # manualy padding self.lorder zeros to the left x = nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) else: - assert not self.training, "Cache should be None in training time" + assert ( + not self.training + ), "Cache should be None in training time" assert cache.size(0) == self.lorder x = torch.cat([cache.permute(1, 2, 0), x], dim=2) if right_context > 0: cache = x.permute(2, 0, 1)[ - -(self.lorder + right_context) : (-right_context), # noqa + -(self.lorder + right_context) : ( # noqa + -right_context + ), ..., ] else: diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py index 3d66f9dc9..344e31283 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode.py @@ -160,24 +160,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -248,7 +244,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--max-sym-per-frame", @@ -345,7 +342,9 @@ def decode_one_batch( simulate_streaming=True, ) else: - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] @@ -361,7 +360,10 @@ def decode_one_batch( ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -407,7 +409,11 @@ def decode_one_batch( return {"greedy_search": hyps} elif params.decoding_method == "fast_beam_search": return { - f"beam_{params.beam}_max_contexts_{params.max_contexts}_max_states_{params.max_states}": hyps + ( + f"beam_{params.beam}_" + f"max_contexts_{params.max_contexts}_" + f"max_states_{params.max_states}" + ): hyps } else: return {f"beam_size_{params.beam_size}": hyps} @@ -478,7 +484,9 @@ def decode_dataset( if batch_idx % log_interval == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -511,7 +519,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -580,12 +589,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -608,12 +618,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -641,7 +652,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) @@ -709,7 +720,8 @@ def main(): ) dev_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) + str(path) + for path in sorted(glob.glob(os.path.join(dev, "shared-*.tar"))) ] cuts_dev_webdataset = CutSet.from_webdataset( dev_shards, @@ -719,7 +731,8 @@ def main(): ) test_net_shards = [ - str(path) for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) + str(path) + for path in sorted(glob.glob(os.path.join(test_net, "shared-*.tar"))) ] cuts_test_net_webdataset = CutSet.from_webdataset( test_net_shards, @@ -730,7 +743,9 @@ def main(): test_meeting_shards = [ str(path) - for path in sorted(glob.glob(os.path.join(test_meeting, "shared-*.tar"))) + for path in sorted( + glob.glob(os.path.join(test_meeting, "shared-*.tar")) + ) ] cuts_test_meeting_webdataset = CutSet.from_webdataset( test_meeting_shards, diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py index e522943c0..386248554 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/decode_stream.py @@ -75,7 +75,9 @@ class DecodeStream(object): # encoder.streaming_forward self.done_frames: int = 0 - self.pad_length = (params.right_context + 2) * params.subsampling_factor + 3 + self.pad_length = ( + params.right_context + 2 + ) * params.subsampling_factor + 3 if params.decoding_method == "greedy_search": self.hyp = [params.blank_id] * params.context_size @@ -89,11 +91,13 @@ class DecodeStream(object): ) elif params.decoding_method == "fast_beam_search": # The rnnt_decoding_stream for fast_beam_search. - self.rnnt_decoding_stream: k2.RnntDecodingStream = k2.RnntDecodingStream( - decoding_graph + self.rnnt_decoding_stream: k2.RnntDecodingStream = ( + k2.RnntDecodingStream(decoding_graph) ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) @property def done(self) -> bool: @@ -122,10 +126,13 @@ class DecodeStream(object): """Consume chunk_size frames of features""" chunk_length = chunk_size + self.pad_length - ret_length = min(self.num_frames - self.num_processed_frames, chunk_length) + ret_length = min( + self.num_frames - self.num_processed_frames, chunk_length + ) ret_features = self.features[ - self.num_processed_frames : self.num_processed_frames + ret_length # noqa + self.num_processed_frames : self.num_processed_frames # noqa + + ret_length ] self.num_processed_frames += chunk_size diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py index fb53f70ab..d0a7fd69f 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/export.py @@ -90,20 +90,17 @@ def get_parser(): "--epoch", type=int, default=28, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -134,7 +131,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) add_model_arguments(parser) @@ -203,7 +201,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py index 9834189d8..1b064c874 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/pretrained.py @@ -80,11 +80,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -109,12 +107,10 @@ def get_parser(): "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) parser.add_argument( @@ -161,7 +157,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -192,9 +189,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -255,7 +253,9 @@ def main(): features = fbank(waves) feature_lengths = [f.size(0) for f in features] - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) feature_lengths = torch.tensor(feature_lengths, device=device) @@ -280,7 +280,10 @@ def main(): ) for i in range(encoder_out.size(0)): hyps.append([lexicon.token_table[idx] for idx in hyp_tokens[i]]) - elif params.decoding_method == "greedy_search" and params.max_sym_per_frame == 1: + elif ( + params.decoding_method == "greedy_search" + and params.max_sym_per_frame == 1 + ): hyp_tokens = greedy_search_batch( model=model, encoder_out=encoder_out, @@ -332,7 +335,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py index 810d94135..651aff6c9 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_beam_search.py @@ -173,10 +173,14 @@ def modified_beam_search( log_probs_shape = k2.ragged.create_ragged_shape2( row_splits=row_splits, cached_tot_size=log_probs.numel() ) - ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + ragged_log_probs = k2.RaggedTensor( + shape=log_probs_shape, value=log_probs + ) for i in range(batch_size): - topk_log_probs, topk_indexes = ragged_log_probs[i].topk(num_active_paths) + topk_log_probs, topk_indexes = ragged_log_probs[i].topk( + num_active_paths + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py index 31a7fe605..ff96c6487 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/streaming_decode.py @@ -119,24 +119,20 @@ def get_parser(): "--avg", type=int, default=15, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch' and '--iter'" - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch' and '--iter'", ) parser.add_argument( "--use-averaged-model", type=str2bool, default=True, - help=( - "Whether to load averaged model. Currently it only supports " - "using --epoch. If True, it would decode with the averaged model " - "over the epoch range from `epoch-avg` (excluded) to `epoch`." - "Actually only the models with epoch number of `epoch-avg` and " - "`epoch` are loaded for averaging. " - ), + help="Whether to load averaged model. Currently it only supports " + "using --epoch. If True, it would decode with the averaged model " + "over the epoch range from `epoch-avg` (excluded) to `epoch`." + "Actually only the models with epoch number of `epoch-avg` and " + "`epoch` are loaded for averaging. ", ) parser.add_argument( @@ -205,7 +201,8 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( @@ -314,7 +311,9 @@ def decode_one_chunk( encoder_out = model.joiner.encoder_proj(encoder_out) if params.decoding_method == "greedy_search": - greedy_search(model=model, encoder_out=encoder_out, streams=decode_streams) + greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) elif params.decoding_method == "fast_beam_search": processed_lens = processed_lens + encoder_out_lens fast_beam_search_one_best( @@ -334,7 +333,9 @@ def decode_one_chunk( num_active_paths=params.num_active_paths, ) else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) states = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)] @@ -388,7 +389,9 @@ def decode_dataset( decode_results = [] # Contain decode streams currently running. decode_streams = [] - initial_states = model.encoder.get_init_state(params.left_context, device=device) + initial_states = model.encoder.get_init_state( + params.left_context, device=device + ) for num, cut in enumerate(cuts): # each utterance has a DecodeStream. decode_stream = DecodeStream( @@ -458,7 +461,9 @@ def decode_dataset( elif params.decoding_method == "modified_beam_search": key = f"num_active_paths_{params.num_active_paths}" else: - raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + raise ValueError( + f"Unsupported decoding method: {params.decoding_method}" + ) return {key: decode_results} @@ -494,7 +499,8 @@ def save_results( test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = ( - params.res_dir / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" + params.res_dir + / f"wer-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(errs_info, "w") as f: print("settings\tWER", file=f) @@ -559,12 +565,13 @@ def main(): if not params.use_averaged_model: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg: raise ValueError( @@ -587,12 +594,13 @@ def main(): model.load_state_dict(average_checkpoints(filenames, device=device)) else: if params.iter > 0: - filenames = find_checkpoints(params.exp_dir, iteration=-params.iter)[ - : params.avg + 1 - ] + filenames = find_checkpoints( + params.exp_dir, iteration=-params.iter + )[: params.avg + 1] if len(filenames) == 0: raise ValueError( - f"No checkpoints found for --iter {params.iter}, --avg {params.avg}" + f"No checkpoints found for" + f" --iter {params.iter}, --avg {params.avg}" ) elif len(filenames) < params.avg + 1: raise ValueError( @@ -620,7 +628,7 @@ def main(): filename_start = f"{params.exp_dir}/epoch-{start}.pt" filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt" logging.info( - "Calculating the averaged model over epoch range from " + f"Calculating the averaged model over epoch range from " f"{start} (excluded) to {params.epoch}" ) model.to(device) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py index 40c9665f7..2052e9da7 100755 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless5/train.py @@ -98,7 +98,9 @@ from icefall.env import get_env_info from icefall.lexicon import Lexicon from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool -LRSchedulerType = Union[torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler] +LRSchedulerType = Union[ + torch.optim.lr_scheduler._LRScheduler, optim.LRScheduler +] def add_model_arguments(parser: argparse.ArgumentParser): @@ -258,7 +260,8 @@ def get_parser(): "--initial-lr", type=float, default=0.003, - help="The initial learning rate. This value should not need to be changed.", + help="The initial learning rate. This value should not need " + "to be changed.", ) parser.add_argument( @@ -281,45 +284,42 @@ def get_parser(): "--context-size", type=int, default=2, - help="The context size in the decoder. 1 means bigram; 2 means tri-gram", + help="The context size in the decoder. 1 means bigram; " + "2 means tri-gram", ) parser.add_argument( "--prune-range", type=int, default=5, - help=( - "The prune range for rnnt loss, it means how many symbols(context)" - "we are using to compute the loss" - ), + help="The prune range for rnnt loss, it means how many symbols(context)" + "we are using to compute the loss", ) parser.add_argument( "--lm-scale", type=float, default=0.25, - help=( - "The scale to smooth the loss with lm (output of prediction network) part." - ), + help="The scale to smooth the loss with lm " + "(output of prediction network) part.", ) parser.add_argument( "--am-scale", type=float, default=0.0, - help="The scale to smooth the loss with am (output of encoder network)part.", + help="The scale to smooth the loss with am (output of encoder network)" + "part.", ) parser.add_argument( "--simple-loss-scale", type=float, default=0.5, - help=( - "To get pruning ranges, we will calculate a simple version" - "loss(joiner is just addition), this simple loss also uses for" - "training (as a regularization item). We will scale the simple loss" - "with this parameter before adding to the final loss." - ), + help="To get pruning ranges, we will calculate a simple version" + "loss(joiner is just addition), this simple loss also uses for" + "training (as a regularization item). We will scale the simple loss" + "with this parameter before adding to the final loss.", ) parser.add_argument( @@ -665,7 +665,11 @@ def compute_loss( warmup: a floating point value which increases throughout training; values >= 1.0 are fully warmed up and have all modules present. """ - device = model.device if isinstance(model, DDP) else next(model.parameters()).device + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) feature = batch["inputs"] # at entry, feature is (N, T, C) assert feature.ndim == 3 @@ -697,16 +701,23 @@ def compute_loss( # overwhelming the simple_loss and causing it to diverge, # in case it had not fully learned the alignment yet. pruned_loss_scale = ( - 0.0 if warmup < 1.0 else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + 0.0 + if warmup < 1.0 + else (0.1 if warmup > 1.0 and warmup < 2.0 else 1.0) + ) + loss = ( + params.simple_loss_scale * simple_loss + + pruned_loss_scale * pruned_loss ) - loss = params.simple_loss_scale * simple_loss + pruned_loss_scale * pruned_loss assert loss.requires_grad == is_training info = MetricsTracker() with warnings.catch_warnings(): warnings.simplefilter("ignore") - info["frames"] = (feature_lens // params.subsampling_factor).sum().item() + info["frames"] = ( + (feature_lens // params.subsampling_factor).sum().item() + ) # Note: We use reduction=sum while computing the loss. info["loss"] = loss.detach().cpu().item() @@ -830,7 +841,9 @@ def train_one_epoch( scaler.update() optimizer.zero_grad() except: # noqa - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch( + batch, params=params, graph_compiler=graph_compiler + ) raise if params.print_diagnostics and batch_idx == 5: @@ -888,7 +901,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -1001,7 +1016,7 @@ def run(rank, world_size, args): if params.print_diagnostics: opts = diagnostics.TensorDiagnosticOptions( - 2**22 + 2 ** 22 ) # allow 4 megabytes per sub-module diagnostic = diagnostics.attach_diagnostics(model, opts) @@ -1169,7 +1184,9 @@ def scan_pessimistic_batches_for_oom( f"Failing criterion: {criterion} " f"(={crit_values[criterion]}) ..." ) - display_and_save_batch(batch, params=params, graph_compiler=graph_compiler) + display_and_save_batch( + batch, params=params, graph_compiler=graph_compiler + ) raise diff --git a/egs/yesno/ASR/local/compile_hlg.py b/egs/yesno/ASR/local/compile_hlg.py index 7234ca929..f83be05cf 100755 --- a/egs/yesno/ASR/local/compile_hlg.py +++ b/egs/yesno/ASR/local/compile_hlg.py @@ -128,7 +128,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/yesno/ASR/local/compute_fbank_yesno.py b/egs/yesno/ASR/local/compute_fbank_yesno.py index 75d95df68..9a4e8a36f 100755 --- a/egs/yesno/ASR/local/compute_fbank_yesno.py +++ b/egs/yesno/ASR/local/compute_fbank_yesno.py @@ -54,7 +54,9 @@ def compute_fbank_yesno(): dataset_parts, ) - extractor = Fbank(FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins)) + extractor = Fbank( + FbankConfig(sampling_rate=8000, num_mel_bins=num_mel_bins) + ) with get_executor() as ex: # Initialize the executor only once. for partition, m in manifests.items(): @@ -69,7 +71,9 @@ def compute_fbank_yesno(): ) if "train" in partition: cut_set = ( - cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1) + cut_set + + cut_set.perturb_speed(0.9) + + cut_set.perturb_speed(1.1) ) cut_set = cut_set.compute_and_store_features( extractor=extractor, @@ -83,7 +87,9 @@ def compute_fbank_yesno(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) diff --git a/egs/yesno/ASR/tdnn/asr_datamodule.py b/egs/yesno/ASR/tdnn/asr_datamodule.py index 21860d2f5..85e5f1358 100644 --- a/egs/yesno/ASR/tdnn/asr_datamodule.py +++ b/egs/yesno/ASR/tdnn/asr_datamodule.py @@ -56,12 +56,10 @@ class YesNoAsrDataModule(DataModule): super().add_arguments(parser) group = parser.add_argument_group( title="ASR data related options", - description=( - "These options are used for the preparation of " - "PyTorch DataLoaders from Lhotse CutSet's -- they control the " - "effective batch sizes, sampling strategies, applied data " - "augmentations, etc." - ), + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", ) group.add_argument( "--feature-dir", @@ -73,91 +71,75 @@ class YesNoAsrDataModule(DataModule): "--max-duration", type=int, default=30.0, - help=( - "Maximum pooled recordings duration (seconds) in a " - "single batch. You can reduce it if it causes CUDA OOM." - ), + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", ) group.add_argument( "--bucketing-sampler", type=str2bool, default=False, - help=( - "When enabled, the batches will come from buckets of " - "similar duration (saves padding frames)." - ), + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", ) group.add_argument( "--num-buckets", type=int, default=10, - help=( - "The number of buckets for the DynamicBucketingSampler" - "(you might want to increase it for larger datasets)." - ), + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", ) group.add_argument( "--concatenate-cuts", type=str2bool, default=False, - help=( - "When enabled, utterances (cuts) will be concatenated " - "to minimize the amount of padding." - ), + help="When enabled, utterances (cuts) will be concatenated " + "to minimize the amount of padding.", ) group.add_argument( "--duration-factor", type=float, default=1.0, - help=( - "Determines the maximum duration of a concatenated cut " - "relative to the duration of the longest cut in a batch." - ), + help="Determines the maximum duration of a concatenated cut " + "relative to the duration of the longest cut in a batch.", ) group.add_argument( "--gap", type=float, default=1.0, - help=( - "The amount of padding (in seconds) inserted between " - "concatenated cuts. This padding is filled with noise when " - "noise augmentation is used." - ), + help="The amount of padding (in seconds) inserted between " + "concatenated cuts. This padding is filled with noise when " + "noise augmentation is used.", ) group.add_argument( "--on-the-fly-feats", type=str2bool, default=False, - help=( - "When enabled, use on-the-fly cut mixing and feature " - "extraction. Will drop existing precomputed feature manifests " - "if available." - ), + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", ) group.add_argument( "--shuffle", type=str2bool, default=True, - help=( - "When enabled (=default), the examples will be shuffled for each epoch." - ), + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", ) group.add_argument( "--return-cuts", type=str2bool, default=True, - help=( - "When enabled, each batch will have the " - "field: batch['supervisions']['cut'] with the cuts that " - "were used to construct it." - ), + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", ) group.add_argument( "--num-workers", type=int, default=2, - help="The number of training dataloader workers that collect the batches.", + help="The number of training dataloader workers that " + "collect the batches.", ) def train_dataloaders(self) -> DataLoader: @@ -168,7 +150,7 @@ class YesNoAsrDataModule(DataModule): transforms = [] if self.args.concatenate_cuts: logging.info( - "Using cut concatenation with duration factor " + f"Using cut concatenation with duration factor " f"{self.args.duration_factor} and gap {self.args.gap}." ) # Cut concatenation should be the first transform in the list, diff --git a/egs/yesno/ASR/tdnn/decode.py b/egs/yesno/ASR/tdnn/decode.py index 41afe0404..9d4ab4b61 100755 --- a/egs/yesno/ASR/tdnn/decode.py +++ b/egs/yesno/ASR/tdnn/decode.py @@ -35,19 +35,16 @@ def get_parser(): "--epoch", type=int, default=14, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=2, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -204,7 +201,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -275,7 +274,9 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu")) + HLG = k2.Fsa.from_dict( + torch.load(f"{params.lang_dir}/HLG.pt", map_location="cpu") + ) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -296,7 +297,9 @@ def main(): if params.export: logging.info(f"Export averaged model to {params.exp_dir}/pretrained.pt") - torch.save({"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt") + torch.save( + {"model": model.state_dict()}, f"{params.exp_dir}/pretrained.pt" + ) return model.to(device) @@ -314,7 +317,9 @@ def main(): word_table=lexicon.word_table, ) - save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results) + save_results( + exp_dir=params.exp_dir, test_set_name="test_set", results=results + ) logging.info("Done!") diff --git a/egs/yesno/ASR/tdnn/pretrained.py b/egs/yesno/ASR/tdnn/pretrained.py index 09a8672ae..14220be19 100755 --- a/egs/yesno/ASR/tdnn/pretrained.py +++ b/egs/yesno/ASR/tdnn/pretrained.py @@ -41,11 +41,9 @@ def get_parser(): "--checkpoint", type=str, required=True, - help=( - "Path to the checkpoint. " - "The checkpoint is assumed to be saved by " - "icefall.checkpoint.save_checkpoint()." - ), + help="Path to the checkpoint. " + "The checkpoint is assumed to be saved by " + "icefall.checkpoint.save_checkpoint().", ) parser.add_argument( @@ -55,18 +53,18 @@ def get_parser(): help="Path to words.txt", ) - parser.add_argument("--HLG", type=str, required=True, help="Path to HLG.pt.") + parser.add_argument( + "--HLG", type=str, required=True, help="Path to HLG.pt." + ) parser.add_argument( "sound_files", type=str, nargs="+", - help=( - "The input sound file(s) to transcribe. " - "Supported formats are those supported by torchaudio.load(). " - "For example, wav and flac are supported. " - "The sample rate has to be 16kHz." - ), + help="The input sound file(s) to transcribe. " + "Supported formats are those supported by torchaudio.load(). " + "For example, wav and flac are supported. " + "The sample rate has to be 16kHz.", ) return parser @@ -103,9 +101,10 @@ def read_sound_files( ans = [] for f in filenames: wave, sample_rate = torchaudio.load(f) - assert ( - sample_rate == expected_sample_rate - ), f"expected sample rate: {expected_sample_rate}. Given: {sample_rate}" + assert sample_rate == expected_sample_rate, ( + f"expected sample rate: {expected_sample_rate}. " + f"Given: {sample_rate}" + ) # We use only the first channel ans.append(wave[0]) return ans @@ -160,7 +159,9 @@ def main(): logging.info("Decoding started") features = fbank(waves) - features = pad_sequence(features, batch_first=True, padding_value=math.log(1e-10)) + features = pad_sequence( + features, batch_first=True, padding_value=math.log(1e-10) + ) # Note: We don't use key padding mask for attention during decoding with torch.no_grad(): @@ -200,7 +201,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/egs/yesno/ASR/tdnn/train.py b/egs/yesno/ASR/tdnn/train.py index 335493491..f32a27f35 100755 --- a/egs/yesno/ASR/tdnn/train.py +++ b/egs/yesno/ASR/tdnn/train.py @@ -430,7 +430,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: valid_info = compute_validation_loss( diff --git a/egs/yesno/ASR/transducer/decode.py b/egs/yesno/ASR/transducer/decode.py index de478334e..6714180db 100755 --- a/egs/yesno/ASR/transducer/decode.py +++ b/egs/yesno/ASR/transducer/decode.py @@ -48,19 +48,16 @@ def get_parser(): "--epoch", type=int, default=125, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( "--exp-dir", @@ -119,7 +116,9 @@ def decode_one_batch( # at entry, feature is (N, T, C) feature_lens = batch["supervisions"]["num_frames"].to(device) - encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + encoder_out, encoder_out_lens = model.encoder( + x=feature, x_lens=feature_lens + ) hyps = [] batch_size = encoder_out.size(0) @@ -187,7 +186,9 @@ def decode_dataset( if batch_idx % 100 == 0: batch_str = f"{batch_idx}/{num_batches}" - logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}") + logging.info( + f"batch {batch_str}, cuts processed until now is {num_cuts}" + ) return results @@ -302,7 +303,9 @@ def main(): model=model, ) - save_results(exp_dir=params.exp_dir, test_set_name="test_set", results=results) + save_results( + exp_dir=params.exp_dir, test_set_name="test_set", results=results + ) logging.info("Done!") diff --git a/egs/yesno/ASR/transducer/train.py b/egs/yesno/ASR/transducer/train.py index 88866ae81..deb92107d 100755 --- a/egs/yesno/ASR/transducer/train.py +++ b/egs/yesno/ASR/transducer/train.py @@ -430,7 +430,9 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: valid_info = compute_validation_loss( diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py index c31db6e4c..235160e14 100644 --- a/icefall/char_graph_compiler.py +++ b/icefall/char_graph_compiler.py @@ -71,7 +71,9 @@ class CharCtcTrainingGraphCompiler(object): for text in texts: text = re.sub(whitespace, "", text) sub_ids = [ - self.token_table[txt] if txt in self.token_table else self.oov_id + self.token_table[txt] + if txt in self.token_table + else self.oov_id for txt in text ] ids.append(sub_ids) @@ -94,7 +96,9 @@ class CharCtcTrainingGraphCompiler(object): for text in texts: text = text.split("/") sub_ids = [ - self.token_table[txt] if txt in self.token_table else self.oov_id + self.token_table[txt] + if txt in self.token_table + else self.oov_id for txt in text ] ids.append(sub_ids) diff --git a/icefall/checkpoint.py b/icefall/checkpoint.py index 8aa0a8eeb..5069b78e8 100644 --- a/icefall/checkpoint.py +++ b/icefall/checkpoint.py @@ -292,11 +292,15 @@ def find_checkpoints(out_dir: Path, iteration: int = 0) -> List[str]: """ checkpoints = list(glob.glob(f"{out_dir}/checkpoint-[0-9]*.pt")) pattern = re.compile(r"checkpoint-([0-9]+).pt") - iter_checkpoints = [(int(pattern.search(c).group(1)), c) for c in checkpoints] + iter_checkpoints = [ + (int(pattern.search(c).group(1)), c) for c in checkpoints + ] # iter_checkpoints is a list of tuples. Each tuple contains # two elements: (iteration_number, checkpoint-iteration_number.pt) - iter_checkpoints = sorted(iter_checkpoints, reverse=True, key=lambda x: x[0]) + iter_checkpoints = sorted( + iter_checkpoints, reverse=True, key=lambda x: x[0] + ) if iteration >= 0: ans = [ic[1] for ic in iter_checkpoints if ic[0] >= iteration] else: @@ -465,5 +469,7 @@ def average_state_dict( v = state_dict_1[k] if torch.is_floating_point(v): v *= weight_1 - v += state_dict_2[k].to(device=state_dict_1[k].device) * weight_2 + v += ( + state_dict_2[k].to(device=state_dict_1[k].device) * weight_2 + ) v *= scaling_factor diff --git a/icefall/decode.py b/icefall/decode.py index e4c614c4e..099e2d171 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -334,9 +334,13 @@ class Nbest(object): if hasattr(lattice, "aux_labels"): # delete token IDs as it is not needed del word_fsa.aux_labels - word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops(word_fsa) + word_fsa_with_epsilon_loops = k2.linear_fsa_with_self_loops( + word_fsa + ) else: - word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops(word_fsa) + word_fsa_with_epsilon_loops = k2.linear_fst_with_self_loops( + word_fsa + ) path_to_utt_map = self.shape.row_ids(1) @@ -366,7 +370,9 @@ class Nbest(object): # path_lattice has word IDs as labels and token IDs as aux_labels path_lattice = k2.top_sort(k2.connect(path_lattice)) - one_best = k2.shortest_path(path_lattice, use_double_scores=use_double_scores) + one_best = k2.shortest_path( + path_lattice, use_double_scores=use_double_scores + ) one_best = k2.invert(one_best) # Now one_best has token IDs as labels and word IDs as aux_labels @@ -436,7 +442,9 @@ class Nbest(object): scores_shape = self.fsa.arcs.shape().remove_axis(1) # scores_shape has axes [path][arc] - ragged_scores = k2.RaggedTensor(scores_shape, self.fsa.scores.contiguous()) + ragged_scores = k2.RaggedTensor( + scores_shape, self.fsa.scores.contiguous() + ) tot_scores = ragged_scores.sum() @@ -475,7 +483,9 @@ def one_best_decoding( am_scores = saved_am_scores / lm_scale lattice.scores = am_scores + lattice.lm_scores - best_path = k2.shortest_path(lattice, use_double_scores=use_double_scores) + best_path = k2.shortest_path( + lattice, use_double_scores=use_double_scores + ) key = f"lm_scale_{lm_scale}" ans[key] = best_path return ans @@ -686,7 +696,9 @@ def rescore_with_n_best_list( logging.info(f"num_paths before decreasing: {num_paths}") num_paths = int(num_paths / 2) if loop_count >= max_loop_count or num_paths <= 0: - logging.info("Return None as the resulting lattice is too large.") + logging.info( + "Return None as the resulting lattice is too large." + ) return None logging.info( "This OOM is not an error. You can ignore it. " @@ -793,9 +805,13 @@ def rescore_with_whole_lattice( except RuntimeError as e: logging.info(f"Caught exception:\n{e}\n") if loop_count >= max_loop_count: - logging.info("Return None as the resulting lattice is too large.") + logging.info( + "Return None as the resulting lattice is too large." + ) return None - logging.info(f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}") + logging.info( + f"num_arcs before pruning: {inv_lattice.arcs.num_elements()}" + ) logging.info( "This OOM is not an error. You can ignore it. " "If your model does not converge well, or --max-duration " @@ -807,7 +823,9 @@ def rescore_with_whole_lattice( prune_th_list[loop_count], True, ) - logging.info(f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}") + logging.info( + f"num_arcs after pruning: {inv_lattice.arcs.num_elements()}" + ) loop_count += 1 # lat has token IDs as labels @@ -894,7 +912,9 @@ def rescore_with_attention_decoder( logging.info(f"num_paths before decreasing: {num_paths}") num_paths = int(num_paths / 2) if loop_count >= max_loop_count or num_paths <= 0: - logging.info("Return None as the resulting lattice is too large.") + logging.info( + "Return None as the resulting lattice is too large." + ) return None logging.info( "This OOM is not an error. You can ignore it. " diff --git a/icefall/diagnostics.py b/icefall/diagnostics.py index 7b58ffbd4..b075aceac 100644 --- a/icefall/diagnostics.py +++ b/icefall/diagnostics.py @@ -19,7 +19,7 @@ import random from dataclasses import dataclass -from typing import List, Optional, Tuple +from typing import Optional, Tuple, List import torch from torch import Tensor, nn @@ -78,11 +78,11 @@ def get_tensor_stats( elif stats_type == "abs": x = x.abs() elif stats_type == "rms": - x = x**2 + x = x ** 2 elif stats_type == "positive": x = (x > 0).to(dtype=torch.float) else: - assert stats_type in ["value", "max", "min"] + assert stats_type in [ "value", "max", "min" ] sum_dims = [d for d in range(x.ndim) if d != dim] if len(sum_dims) > 0: @@ -121,9 +121,7 @@ class TensorDiagnostic(object): self.name = name self.class_name = None # will assign in accumulate() - self.stats = ( - None # we'll later assign a list to this data member. It's a list of dict. - ) + self.stats = None # we'll later assign a list to this data member. It's a list of dict. # the keys into self.stats[dim] are strings, whose values can be # "abs", "max", "min" ,"value", "positive", "rms", "value". @@ -135,6 +133,7 @@ class TensorDiagnostic(object): # only adding a new element to the list if there was a different dim. # if the string in the key is "eigs", if we detect a length mismatch we put None as the value. + def accumulate(self, x, class_name: Optional[str] = None): """ Accumulate tensors. @@ -186,12 +185,17 @@ class TensorDiagnostic(object): done = True break if not done: - if this_dim_stats[stats_type] != [] and stats_type == "eigs": + if ( + this_dim_stats[stats_type] != [] + and stats_type == "eigs" + ): # >1 size encountered on this dim, e.g. it's a batch or time dimension, # don't accumulat "eigs" stats type, it uses too much memory this_dim_stats[stats_type] = None else: - this_dim_stats[stats_type].append(TensorAndCount(stats, count)) + this_dim_stats[stats_type].append( + TensorAndCount(stats, count) + ) def print_diagnostics(self): """Print diagnostics for each dimension of the tensor.""" @@ -207,6 +211,7 @@ class TensorDiagnostic(object): assert stats_type == "eigs" continue + def get_count(count): return 1 if stats_type in ["max", "min"] else count @@ -216,8 +221,7 @@ class TensorDiagnostic(object): # a dimension that has variable size in different nnet # forwards, e.g. a time dimension in an ASR model. stats = torch.cat( - [x.tensor / get_count(x.count) for x in stats_list], - dim=0, + [x.tensor / get_count(x.count) for x in stats_list], dim=0 ) if stats_type == "eigs": @@ -225,7 +229,9 @@ class TensorDiagnostic(object): eigs, _ = torch.symeig(stats) stats = eigs.abs().sqrt() except: # noqa - print("Error getting eigenvalues, trying another method.") + print( + "Error getting eigenvalues, trying another method." + ) eigs, _ = torch.eig(stats) stats = eigs.abs().sqrt() # sqrt so it reflects data magnitude, like stddev- not variance @@ -236,9 +242,9 @@ class TensorDiagnostic(object): # if `summarize` we print percentiles of the stats; else, # we print out individual elements. - summarize = (len(stats_list) > 1) or self.opts.dim_is_summarized( - stats.numel() - ) + summarize = ( + len(stats_list) > 1 + ) or self.opts.dim_is_summarized(stats.numel()) if summarize: # usually `summarize` will be true # print out percentiles. stats = stats.sort()[0] @@ -255,15 +261,15 @@ class TensorDiagnostic(object): ans = stats.tolist() ans = ["%.2g" % x for x in ans] ans = "[" + " ".join(ans) + "]" - if stats_type in ["value", "rms", "eigs"]: + if stats_type in [ "value", "rms", "eigs" ]: # This norm is useful because it is strictly less than the largest # sqrt(eigenvalue) of the variance, which we print out, and shows, # speaking in an approximate way, how much of that largest eigenvalue # can be attributed to the mean of the distribution. - norm = (stats**2).sum().sqrt().item() + norm = (stats ** 2).sum().sqrt().item() ans += f", norm={norm:.2g}" mean = stats.mean().item() - rms = (stats**2).mean().sqrt().item() + rms = (stats ** 2).mean().sqrt().item() ans += f", mean={mean:.3g}, rms={rms:.3g}" # OK, "ans" contains the actual stats, e.g. @@ -271,17 +277,17 @@ class TensorDiagnostic(object): sizes = [x.tensor.shape[0] for x in stats_list] size_str = ( - f"{sizes[0]}" if len(sizes) == 1 else f"{min(sizes)}..{max(sizes)}" - ) - maybe_class_name = ( - f" type={self.class_name}," if self.class_name is not None else "" + f"{sizes[0]}" + if len(sizes) == 1 + else f"{min(sizes)}..{max(sizes)}" ) + maybe_class_name = f" type={self.class_name}," if self.class_name is not None else "" print( - f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}," - f" {stats_type} {ans}" + f"module={self.name},{maybe_class_name} dim={dim}, size={size_str}, {stats_type} {ans}" ) + class ModelDiagnostic(object): """This class stores diagnostics for all tensors in the torch.nn.Module. @@ -339,32 +345,32 @@ def attach_diagnostics( # (matters for name, since the variable gets overwritten). # These closures don't really capture by value, only by # "the final value the variable got in the function" :-( - def forward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): + def forward_hook( + _module, _input, _output, _model_diagnostic=ans, _name=name + ): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] if isinstance(_output, Tensor): - _model_diagnostic[f"{_name}.output"].accumulate( - _output, class_name=type(_module).__name__ - ) + _model_diagnostic[f"{_name}.output"].accumulate(_output, + class_name=type(_module).__name__) elif isinstance(_output, tuple): for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.output[{i}]"].accumulate( - o, class_name=type(_module).__name__ - ) + _model_diagnostic[f"{_name}.output[{i}]"].accumulate(o, + class_name=type(_module).__name__) - def backward_hook(_module, _input, _output, _model_diagnostic=ans, _name=name): + def backward_hook( + _module, _input, _output, _model_diagnostic=ans, _name=name + ): if isinstance(_output, tuple) and len(_output) == 1: _output = _output[0] if isinstance(_output, Tensor): - _model_diagnostic[f"{_name}.grad"].accumulate( - _output, class_name=type(_module).__name__ - ) + _model_diagnostic[f"{_name}.grad"].accumulate(_output, + class_name=type(_module).__name__) elif isinstance(_output, tuple): for i, o in enumerate(_output): - _model_diagnostic[f"{_name}.grad[{i}]"].accumulate( - o, class_name=type(_module).__name__ - ) + _model_diagnostic[f"{_name}.grad[{i}]"].accumulate(o, + class_name=type(_module).__name__) module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) diff --git a/icefall/dist.py b/icefall/dist.py index 9df1c5bd1..7016beafb 100644 --- a/icefall/dist.py +++ b/icefall/dist.py @@ -29,7 +29,9 @@ def setup_dist(rank, world_size, master_port=None, use_ddp_launch=False): os.environ["MASTER_ADDR"] = "localhost" if "MASTER_PORT" not in os.environ: - os.environ["MASTER_PORT"] = "12354" if master_port is None else str(master_port) + os.environ["MASTER_PORT"] = ( + "12354" if master_port is None else str(master_port) + ) if use_ddp_launch is False: dist.init_process_group("nccl", rank=rank, world_size=world_size) diff --git a/icefall/env.py b/icefall/env.py index 373e9a9ff..8aeda6be2 100644 --- a/icefall/env.py +++ b/icefall/env.py @@ -53,7 +53,9 @@ def get_git_sha1(): ) > 0 ) - git_commit = git_commit + "-dirty" if dirty_commit else git_commit + "-clean" + git_commit = ( + git_commit + "-dirty" if dirty_commit else git_commit + "-clean" + ) except: # noqa return None diff --git a/icefall/graph_compiler.py b/icefall/graph_compiler.py index e2ff03f61..570ed7d7a 100644 --- a/icefall/graph_compiler.py +++ b/icefall/graph_compiler.py @@ -75,7 +75,9 @@ class CtcTrainingGraphCompiler(object): # NOTE: k2.compose runs on CUDA only when treat_epsilons_specially # is False, so we add epsilon self-loops here - fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops(transcript_fsa) + fsa_with_self_loops = k2.remove_epsilon_and_add_self_loops( + transcript_fsa + ) fsa_with_self_loops = k2.arc_sort(fsa_with_self_loops) diff --git a/icefall/hooks.py b/icefall/hooks.py index 398a5f689..fbcf5e148 100644 --- a/icefall/hooks.py +++ b/icefall/hooks.py @@ -14,11 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import logging import random - import torch from torch import Tensor, nn +import logging def register_inf_check_hooks(model: nn.Module) -> None: @@ -57,7 +56,7 @@ def register_inf_check_hooks(model: nn.Module) -> None: if isinstance(_output, Tensor): if not torch.isfinite(_output.to(torch.float32).sum()): logging.warning( - f"The sum of {_name}.grad is not finite" # ": {_output}" + f"The sum of {_name}.grad is not finite" # ": {_output}" ) elif isinstance(_output, tuple): for i, o in enumerate(_output): @@ -66,20 +65,28 @@ def register_inf_check_hooks(model: nn.Module) -> None: if not isinstance(o, Tensor): continue if not torch.isfinite(o.to(torch.float32).sum()): - logging.warning(f"The sum of {_name}.grad[{i}] is not finite") + logging.warning( + f"The sum of {_name}.grad[{i}] is not finite" + ) module.register_forward_hook(forward_hook) module.register_backward_hook(backward_hook) + for name, parameter in model.named_parameters(): - def param_backward_hook(grad, _name=name): + def param_backward_hook( + grad, _name=name + ): if not torch.isfinite(grad.to(torch.float32).sum()): - logging.warning(f"The sum of {_name}.param_grad is not finite") + logging.warning( + f"The sum of {_name}.param_grad is not finite" + ) parameter.register_hook(param_backward_hook) + def _test_inf_check_hooks(): model = nn.Sequential(nn.Linear(100, 50), nn.Linear(50, 80)) diff --git a/icefall/lexicon.py b/icefall/lexicon.py index 22e1b78bb..80bd7c1ee 100644 --- a/icefall/lexicon.py +++ b/icefall/lexicon.py @@ -49,12 +49,18 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]: continue if len(a) < 2: - logging.info(f"Found bad line {line} in lexicon file {filename}") - logging.info("Every line is expected to contain at least 2 fields") + logging.info( + f"Found bad line {line} in lexicon file {filename}" + ) + logging.info( + "Every line is expected to contain at least 2 fields" + ) sys.exit(1) word = a[0] if word == " ": - logging.info(f"Found bad line {line} in lexicon file {filename}") + logging.info( + f"Found bad line {line} in lexicon file {filename}" + ) logging.info(" should not be a valid word") sys.exit(1) @@ -113,7 +119,9 @@ def convert_lexicon_to_ragged( lexicon_tmp = read_lexicon(filename) lexicon = dict(lexicon_tmp) if len(lexicon_tmp) != len(lexicon): - raise RuntimeError("It's assumed that each word has a unique pronunciation") + raise RuntimeError( + "It's assumed that each word has a unique pronunciation" + ) for i in range(disambig_id): w = word_table[i] diff --git a/icefall/mmi.py b/icefall/mmi.py index 16ed6e032..2c479fc2c 100644 --- a/icefall/mmi.py +++ b/icefall/mmi.py @@ -63,7 +63,10 @@ def _compute_mmi_loss_exact_optimized( # [0, num_fsas, 1, num_fsas, 2, num_fsas, ... ] num_den_graphs_indexes = ( - torch.stack([num_graphs_indexes, den_graphs_indexes]).t().reshape(-1).to(device) + torch.stack([num_graphs_indexes, den_graphs_indexes]) + .t() + .reshape(-1) + .to(device) ) num_den_reordered_graphs = k2.index(num_den_graphs, num_den_graphs_indexes) @@ -112,12 +115,20 @@ def _compute_mmi_loss_exact_non_optimized( num_graphs, den_graphs = graph_compiler.compile(texts, replicate_den=True) # TODO: pass output_beam as function argument - num_lats = k2.intersect_dense(num_graphs, dense_fsa_vec, output_beam=beam_size) - den_lats = k2.intersect_dense(den_graphs, dense_fsa_vec, output_beam=beam_size) + num_lats = k2.intersect_dense( + num_graphs, dense_fsa_vec, output_beam=beam_size + ) + den_lats = k2.intersect_dense( + den_graphs, dense_fsa_vec, output_beam=beam_size + ) - num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True) + num_tot_scores = num_lats.get_tot_scores( + log_semiring=True, use_double_scores=True + ) - den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True) + den_tot_scores = den_lats.get_tot_scores( + log_semiring=True, use_double_scores=True + ) tot_scores = num_tot_scores - den_scale * den_tot_scores @@ -157,9 +168,13 @@ def _compute_mmi_loss_pruned( max_active_states=10000, ) - num_tot_scores = num_lats.get_tot_scores(log_semiring=True, use_double_scores=True) + num_tot_scores = num_lats.get_tot_scores( + log_semiring=True, use_double_scores=True + ) - den_tot_scores = den_lats.get_tot_scores(log_semiring=True, use_double_scores=True) + den_tot_scores = den_lats.get_tot_scores( + log_semiring=True, use_double_scores=True + ) tot_scores = num_tot_scores - den_scale * den_tot_scores diff --git a/icefall/mmi_graph_compiler.py b/icefall/mmi_graph_compiler.py index 9f680f83d..0d901227d 100644 --- a/icefall/mmi_graph_compiler.py +++ b/icefall/mmi_graph_compiler.py @@ -137,7 +137,9 @@ class MmiTrainingGraphCompiler(object): transcript_fsa ) - transcript_fsa_with_self_loops = k2.arc_sort(transcript_fsa_with_self_loops) + transcript_fsa_with_self_loops = k2.arc_sort( + transcript_fsa_with_self_loops + ) num = k2.compose( self.ctc_topo_P, @@ -153,7 +155,9 @@ class MmiTrainingGraphCompiler(object): ctc_topo_P_vec = k2.create_fsa_vec([self.ctc_topo_P]) if replicate_den: - indexes = torch.zeros(len(texts), dtype=torch.int32, device=self.device) + indexes = torch.zeros( + len(texts), dtype=torch.int32, device=self.device + ) den = k2.index_fsa(ctc_topo_P_vec, indexes) else: den = ctc_topo_P_vec diff --git a/icefall/rnn_lm/compute_perplexity.py b/icefall/rnn_lm/compute_perplexity.py index 9a275bf28..550801a8f 100755 --- a/icefall/rnn_lm/compute_perplexity.py +++ b/icefall/rnn_lm/compute_perplexity.py @@ -46,19 +46,16 @@ def get_parser(): "--epoch", type=int, default=49, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=20, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -197,7 +194,7 @@ def main(): logging.info(f"Number of model parameters: {num_param}") logging.info( - "Number of model parameters (requires_grad): " + f"Number of model parameters (requires_grad): " f"{num_param_requires_grad} " f"({num_param_requires_grad/num_param_requires_grad*100}%)" ) diff --git a/icefall/rnn_lm/dataset.py b/icefall/rnn_lm/dataset.py index 4bf982503..598e329c4 100644 --- a/icefall/rnn_lm/dataset.py +++ b/icefall/rnn_lm/dataset.py @@ -155,8 +155,12 @@ class LmDatasetCollate: sentence_tokens_with_sos = add_sos(sentence_tokens, self.sos_id) sentence_tokens_with_eos = add_eos(sentence_tokens, self.eos_id) - x = sentence_tokens_with_sos.pad(mode="constant", padding_value=self.blank_id) - y = sentence_tokens_with_eos.pad(mode="constant", padding_value=self.blank_id) + x = sentence_tokens_with_sos.pad( + mode="constant", padding_value=self.blank_id + ) + y = sentence_tokens_with_eos.pad( + mode="constant", padding_value=self.blank_id + ) sentence_token_lengths += 1 # plus 1 since we added a SOS return x.to(torch.int64), y.to(torch.int64), sentence_token_lengths diff --git a/icefall/rnn_lm/export.py b/icefall/rnn_lm/export.py index 2e878f5c8..094035fce 100644 --- a/icefall/rnn_lm/export.py +++ b/icefall/rnn_lm/export.py @@ -38,20 +38,17 @@ def get_parser(): "--epoch", type=int, default=29, - help=( - "It specifies the checkpoint to use for decoding.Note: Epoch counts from 0." - ), + help="It specifies the checkpoint to use for decoding." + "Note: Epoch counts from 0.", ) parser.add_argument( "--avg", type=int, default=5, - help=( - "Number of checkpoints to average. Automatically select " - "consecutive checkpoints before the checkpoint specified by " - "'--epoch'. " - ), + help="Number of checkpoints to average. Automatically select " + "consecutive checkpoints before the checkpoint specified by " + "'--epoch'. ", ) parser.add_argument( @@ -162,7 +159,9 @@ def main(): if __name__ == "__main__": - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) logging.basicConfig(format=formatter, level=logging.INFO) main() diff --git a/icefall/rnn_lm/model.py b/icefall/rnn_lm/model.py index 9eef88840..a6144727a 100644 --- a/icefall/rnn_lm/model.py +++ b/icefall/rnn_lm/model.py @@ -129,7 +129,9 @@ class RnnLmModel(torch.nn.Module): tokens_eos = add_eos(tokens, eos_id) sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + sentence_lengths = ( + sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + ) x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) @@ -159,12 +161,12 @@ class RnnLmModel(torch.nn.Module): if state: h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to( - device - ) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size).to( - device - ) + h = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ).to(device) + c = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ).to(device) embedding = self.input_embedding(tokens) rnn_out, states = self.rnn(embedding, (h, c)) @@ -179,8 +181,12 @@ class RnnLmModel(torch.nn.Module): if state: h, c = state else: - h = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) - c = torch.zeros(self.rnn.num_layers, batch_size, self.rnn.input_size) + h = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ) + c = torch.zeros( + self.rnn.num_layers, batch_size, self.rnn.input_size + ) device = next(self.parameters()).device @@ -188,7 +194,9 @@ class RnnLmModel(torch.nn.Module): tokens_eos = add_eos(tokens, eos_id) sos_tokens_row_splits = sos_tokens.shape.row_splits(1) - sentence_lengths = sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + sentence_lengths = ( + sos_tokens_row_splits[1:] - sos_tokens_row_splits[:-1] + ) x_tokens = sos_tokens.pad(mode="constant", padding_value=blank_id) y_tokens = tokens_eos.pad(mode="constant", padding_value=blank_id) diff --git a/icefall/rnn_lm/train.py b/icefall/rnn_lm/train.py index e17b50332..bb5f03fb9 100755 --- a/icefall/rnn_lm/train.py +++ b/icefall/rnn_lm/train.py @@ -446,13 +446,17 @@ def train_one_epoch( loss_info.write_summary( tb_writer, "train/current_", params.batch_idx_train ) - tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary( + tb_writer, "train/tot_", params.batch_idx_train + ) tb_writer.add_scalar( "train/current_ppl", this_batch_ppl, params.batch_idx_train ) - tb_writer.add_scalar("train/tot_ppl", tot_ppl, params.batch_idx_train) + tb_writer.add_scalar( + "train/tot_ppl", tot_ppl, params.batch_idx_train + ) if batch_idx > 0 and batch_idx % params.valid_interval == 0: logging.info("Computing validation loss") @@ -467,7 +471,8 @@ def train_one_epoch( valid_ppl = math.exp(valid_info["loss"] / valid_info["frames"]) logging.info( - f"Epoch {params.cur_epoch}, validation: {valid_info}, ppl: {valid_ppl}" + f"Epoch {params.cur_epoch}, validation: {valid_info}, " + f"ppl: {valid_ppl}" ) if tb_writer is not None: diff --git a/icefall/shared/make_kn_lm.py b/icefall/shared/make_kn_lm.py index a3bf1ef4c..c2edd823e 100755 --- a/icefall/shared/make_kn_lm.py +++ b/icefall/shared/make_kn_lm.py @@ -15,50 +15,30 @@ # The data structure is based on: kaldi/egs/wsj/s5/utils/lang/make_phone_lm.py # The smoothing algorithm is based on: http://www.speech.sri.com/projects/srilm/manpages/ngram-discount.7.html -import argparse -import io -import math +import sys import os import re -import sys +import io +import math +import argparse from collections import Counter, defaultdict -parser = argparse.ArgumentParser( - description=""" + +parser = argparse.ArgumentParser(description=""" Generate kneser-ney language model as arpa format. By default, it will read the corpus from standard input, and output to standard output. - """ -) -parser.add_argument( - "-ngram-order", - type=int, - default=4, - choices=[2, 3, 4, 5, 6, 7], - help="Order of n-gram", -) + """) +parser.add_argument("-ngram-order", type=int, default=4, choices=[2, 3, 4, 5, 6, 7], help="Order of n-gram") parser.add_argument("-text", type=str, default=None, help="Path to the corpus file") -parser.add_argument( - "-lm", - type=str, - default=None, - help="Path to output arpa file for language models", -) -parser.add_argument( - "-verbose", - type=int, - default=0, - choices=[0, 1, 2, 3, 4, 5], - help="Verbose level", -) +parser.add_argument("-lm", type=str, default=None, help="Path to output arpa file for language models") +parser.add_argument("-verbose", type=int, default=0, choices=[0, 1, 2, 3, 4, 5], help="Verbose level") args = parser.parse_args() -default_encoding = ( - "latin-1" # For encoding-agnostic scripts, we assume byte stream as input. -) -# Need to be very careful about the use of strip() and split() -# in this case, because there is a latin-1 whitespace character -# (nbsp) which is part of the unicode encoding range. -# Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717 +default_encoding = "latin-1" # For encoding-agnostic scripts, we assume byte stream as input. + # Need to be very careful about the use of strip() and split() + # in this case, because there is a latin-1 whitespace character + # (nbsp) which is part of the unicode encoding range. + # Ref: kaldi/egs/wsj/s5/utils/lang/bpe/prepend_words.py @ 69cd717 strip_chars = " \t\r\n" whitespace = re.compile("[ \t]+") @@ -72,9 +52,7 @@ class CountsForHistory: # The 'lambda: defaultdict(float)' is an anonymous function taking no # arguments that returns a new defaultdict(float). self.word_to_count = defaultdict(int) - self.word_to_context = defaultdict( - set - ) # using a set to count the number of unique contexts + self.word_to_context = defaultdict(set) # using a set to count the number of unique contexts self.word_to_f = dict() # discounted probability self.word_to_bow = dict() # back-off weight self.total_count = 0 @@ -84,15 +62,10 @@ class CountsForHistory: def __str__(self): # e.g. returns ' total=12: 3->4, 4->6, -1->2' - return " total={0}: {1}".format( + return ' total={0}: {1}'.format( str(self.total_count), - ", ".join( - [ - "{0} -> {1}".format(word, count) - for word, count in self.word_to_count.items() - ] - ), - ) + ', '.join(['{0} -> {1}'.format(word, count) + for word, count in self.word_to_count.items()])) def add_count(self, predicted_word, context_word, count): assert count >= 0 @@ -112,7 +85,7 @@ class NgramCounts: # accumulating the 4-gram count for the '8' in the sequence '5 6 7 8', we'd # do as follows: self.counts[3][[5,6,7]][8] += 1.0 where the [3] indexes an # array, the [[5,6,7]] indexes a dict, and the [8] indexes a dict. - def __init__(self, ngram_order, bos_symbol=" ", eos_symbol=""): + def __init__(self, ngram_order, bos_symbol='', eos_symbol=''): assert ngram_order >= 2 self.ngram_order = ngram_order @@ -130,48 +103,39 @@ class NgramCounts: # would be (6,7,8) and 'predicted_word' would be 9; 'count' would be # 1. def add_count(self, history, predicted_word, context_word, count): - self.counts[len(history)][history].add_count( - predicted_word, context_word, count - ) + self.counts[len(history)][history].add_count(predicted_word, context_word, count) # 'line' is a string containing a sequence of integer word-ids. # This function adds the un-smoothed counts from this line of text. def add_raw_counts_from_line(self, line): - if line == "": + if line == '': words = [self.bos_symbol, self.eos_symbol] else: words = [self.bos_symbol] + whitespace.split(line) + [self.eos_symbol] for i in range(len(words)): - for n in range(1, self.ngram_order + 1): + for n in range(1, self.ngram_order+1): if i + n > len(words): break - ngram = words[i : i + n] + ngram = words[i: i + n] predicted_word = ngram[-1] - history = tuple(ngram[:-1]) + history = tuple(ngram[: -1]) if i == 0 or n == self.ngram_order: context_word = None else: - context_word = words[i - 1] + context_word = words[i-1] self.add_count(history, predicted_word, context_word, 1) def add_raw_counts_from_standard_input(self): lines_processed = 0 - infile = io.TextIOWrapper( - sys.stdin.buffer, encoding=default_encoding - ) # byte stream as input + infile = io.TextIOWrapper(sys.stdin.buffer, encoding=default_encoding) # byte stream as input for line in infile: line = line.strip(strip_chars) self.add_raw_counts_from_line(line) lines_processed += 1 if lines_processed == 0 or args.verbose > 0: - print( - "make_phone_lm.py: processed {0} lines of input".format( - lines_processed - ), - file=sys.stderr, - ) + print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr) def add_raw_counts_from_file(self, filename): lines_processed = 0 @@ -181,12 +145,7 @@ class NgramCounts: self.add_raw_counts_from_line(line) lines_processed += 1 if lines_processed == 0 or args.verbose > 0: - print( - "make_phone_lm.py: processed {0} lines of input".format( - lines_processed - ), - file=sys.stderr, - ) + print("make_phone_lm.py: processed {0} lines of input".format(lines_processed), file=sys.stderr) def cal_discounting_constants(self): # For each order N of N-grams, we calculate discounting constant D_N = n1_N / (n1_N + 2 * n2_N), @@ -194,11 +153,9 @@ class NgramCounts: # This constant is used similarly to absolute discounting. # Return value: d is a list of floats, where d[N+1] = D_N - self.d = [ - 0 - ] # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0 - # This is a special case: as we currently assumed having seen all vocabularies in the dictionary, - # but perhaps this is not the case for some other scenarios. + self.d = [0] # for the lowest order, i.e., 1-gram, we do not need to discount, thus the constant is 0 + # This is a special case: as we currently assumed having seen all vocabularies in the dictionary, + # but perhaps this is not the case for some other scenarios. for n in range(1, self.ngram_order): this_order_counts = self.counts[n] n1 = 0 @@ -208,11 +165,9 @@ class NgramCounts: n1 += stat[1] n2 += stat[2] assert n1 + 2 * n2 > 0 - self.d.append( - max(0.1, n1 * 1.0) / (n1 + 2 * n2) - ) # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0, - # which could happen if the number of symbols is small. - # Otherwise, zero discounting constant can cause division by zero in computing BOW. + self.d.append(max(0.1, n1 * 1.0) / (n1 + 2 * n2)) # We are doing this max(0.001, xxx) to avoid zero discounting constant D due to n1=0, + # which could happen if the number of symbols is small. + # Otherwise, zero discounting constant can cause division by zero in computing BOW. def cal_f(self): # f(a_z) is a probability distribution of word sequence a_z. @@ -227,9 +182,7 @@ class NgramCounts: this_order_counts = self.counts[n] for hist, counts_for_hist in this_order_counts.items(): for w, c in counts_for_hist.word_to_count.items(): - counts_for_hist.word_to_f[w] = ( - max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count - ) + counts_for_hist.word_to_f[w] = max((c - self.d[n]), 0) * 1.0 / counts_for_hist.total_count # lower order N-grams for n in range(0, self.ngram_order - 1): @@ -243,17 +196,11 @@ class NgramCounts: if n_star_star != 0: for w in counts_for_hist.word_to_count.keys(): n_star_z = len(counts_for_hist.word_to_context[w]) - counts_for_hist.word_to_f[w] = ( - max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star - ) + counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / n_star_star else: # patterns begin with, they do not have "modified count", so use raw count instead for w in counts_for_hist.word_to_count.keys(): n_star_z = counts_for_hist.word_to_count[w] - counts_for_hist.word_to_f[w] = ( - max((n_star_z - self.d[n]), 0) - * 1.0 - / counts_for_hist.total_count - ) + counts_for_hist.word_to_f[w] = max((n_star_z - self.d[n]), 0) * 1.0 / counts_for_hist.total_count def cal_bow(self): # Backoff weights are only necessary for ngrams which form a prefix of a longer ngram. @@ -293,18 +240,12 @@ class NgramCounts: sum_z1_f_z = 0 _ = a_[1:] _counts_for_hist = self.counts[len(_)][_] - for ( - u - ) in ( - a_counts_for_hist.word_to_count.keys() - ): # Should be careful here: what is Z1 + for u in a_counts_for_hist.word_to_count.keys(): # Should be careful here: what is Z1 sum_z1_f_z += _counts_for_hist.word_to_f[u] if sum_z1_f_z < 1: # assert sum_z1_f_a_z < 1 - counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / ( - 1.0 - sum_z1_f_z - ) + counts_for_hist.word_to_bow[w] = (1.0 - sum_z1_f_a_z) / (1.0 - sum_z1_f_z) else: counts_for_hist.word_to_bow[w] = None @@ -318,9 +259,7 @@ class NgramCounts: ngram = " ".join(hist) + " " + w ngram = ngram.strip(strip_chars) - res.append( - "{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w]) - ) + res.append("{0}\t{1}".format(ngram, counts_for_hist.word_to_count[w])) res.sort(reverse=True) for r in res: print(r) @@ -383,40 +322,27 @@ class NgramCounts: if bow is None: res.append("{1}\t{0}".format(ngram, math.log(f, 10))) else: - res.append( - "{1}\t{0}\t{2}".format( - ngram, math.log(f, 10), math.log(bow, 10) - ) - ) + res.append("{1}\t{0}\t{2}".format(ngram, math.log(f, 10), math.log(bow, 10))) res.sort(reverse=True) for r in res: print(r) - def print_as_arpa( - self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding="latin-1") - ): + def print_as_arpa(self, fout=io.TextIOWrapper(sys.stdout.buffer, encoding='latin-1')): # print as ARPA format. - print("\\data\\", file=fout) + print('\\data\\', file=fout) for hist_len in range(self.ngram_order): # print the number of n-grams. - print( - "ngram {0}={1}".format( - hist_len + 1, - sum( - [ - len(counts_for_hist.word_to_f) - for counts_for_hist in self.counts[hist_len].values() - ] - ), - ), - file=fout, + print('ngram {0}={1}'.format( + hist_len + 1, + sum([len(counts_for_hist.word_to_f) for counts_for_hist in self.counts[hist_len].values()])), + file=fout ) - print("", file=fout) + print('', file=fout) for hist_len in range(self.ngram_order): - print("\\{0}-grams:".format(hist_len + 1), file=fout) + print('\\{0}-grams:'.format(hist_len + 1), file=fout) this_order_counts = self.counts[hist_len] for hist, counts_for_hist in this_order_counts.items(): @@ -428,12 +354,12 @@ class NgramCounts: if prob == 0: # f() is always 0 prob = 1e-99 - line = "{0}\t{1}".format("%.7f" % math.log10(prob), " ".join(ngram)) + line = '{0}\t{1}'.format('%.7f' % math.log10(prob), ' '.join(ngram)) if bow is not None: - line += "\t{0}".format("%.7f" % math.log10(bow)) + line += '\t{0}'.format('%.7f' % math.log10(bow)) print(line, file=fout) - print("", file=fout) - print("\\end\\", file=fout) + print('', file=fout) + print('\\end\\', file=fout) if __name__ == "__main__": @@ -453,5 +379,5 @@ if __name__ == "__main__": if args.lm is None: ngram_counts.print_as_arpa() else: - with open(args.lm, "w", encoding=default_encoding) as f: + with open(args.lm, 'w', encoding=default_encoding) as f: ngram_counts.print_as_arpa(fout=f) diff --git a/icefall/utils.py b/icefall/utils.py index 785bd80f9..143c79497 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -130,7 +130,9 @@ def setup_logger( formatter = f"%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] ({rank}/{world_size}) %(message)s" # noqa log_filename = f"{log_filename}-{date_time}-{rank}" else: - formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) log_filename = f"{log_filename}-{date_time}" os.makedirs(os.path.dirname(log_filename), exist_ok=True) @@ -201,7 +203,7 @@ def encode_supervisions( supervisions["num_frames"], subsampling_factor, rounding_mode="floor", - ), + ) ), 1, ).to(torch.int32) @@ -286,9 +288,13 @@ def get_texts_with_timestamp( """ if isinstance(best_paths.aux_labels, k2.RaggedTensor): all_aux_shape = ( - best_paths.arcs.shape().remove_axis(1).compose(best_paths.aux_labels.shape) + best_paths.arcs.shape() + .remove_axis(1) + .compose(best_paths.aux_labels.shape) + ) + all_aux_labels = k2.RaggedTensor( + all_aux_shape, best_paths.aux_labels.values ) - all_aux_labels = k2.RaggedTensor(all_aux_shape, best_paths.aux_labels.values) # remove 0's and -1's. aux_labels = best_paths.aux_labels.remove_values_leq(0) # TODO: change arcs.shape() to arcs.shape @@ -357,7 +363,9 @@ def get_alignments(best_paths: k2.Fsa, kind: str) -> List[List[int]]: # arc.shape() has axes [fsa][state][arc], we remove "state"-axis here token_shape = best_paths.arcs.shape().remove_axis(1) # token_shape has axes [fsa][arc] - tokens = k2.RaggedTensor(token_shape, getattr(best_paths, kind).contiguous()) + tokens = k2.RaggedTensor( + token_shape, getattr(best_paths, kind).contiguous() + ) tokens = tokens.remove_values_eq(-1) return tokens.tolist() @@ -578,7 +586,9 @@ def write_error_stats( f"{cut_id}:\t" + " ".join( ( - ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" + ref_word + if ref_word == hyp_word + else f"({ref_word}->{hyp_word})" for ref_word, hyp_word in ali ) ), @@ -588,7 +598,9 @@ def write_error_stats( print("", file=f) print("SUBSTITUTIONS: count ref -> hyp", file=f) - for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): + for count, (ref, hyp) in sorted( + [(v, k) for k, v in subs.items()], reverse=True + ): print(f"{count} {ref} -> {hyp}", file=f) print("", file=f) @@ -602,7 +614,9 @@ def write_error_stats( print(f"{count} {hyp}", file=f) print("", file=f) - print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) + print( + "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f + ) for _, word, counts in sorted( [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True ): @@ -777,7 +791,9 @@ def write_error_stats_with_timestamps( f"{cut_id}:\t" + " ".join( ( - ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})" + ref_word + if ref_word == hyp_word + else f"({ref_word}->{hyp_word})" for ref_word, hyp_word in ali ) ), @@ -787,7 +803,9 @@ def write_error_stats_with_timestamps( print("", file=f) print("SUBSTITUTIONS: count ref -> hyp", file=f) - for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True): + for count, (ref, hyp) in sorted( + [(v, k) for k, v in subs.items()], reverse=True + ): print(f"{count} {ref} -> {hyp}", file=f) print("", file=f) @@ -801,7 +819,9 @@ def write_error_stats_with_timestamps( print(f"{count} {hyp}", file=f) print("", file=f) - print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f) + print( + "PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f + ) for _, word, counts in sorted( [(sum(v[1:]), k, v) for k, v in words.items()], reverse=True ): @@ -871,7 +891,9 @@ class MetricsTracker(collections.defaultdict): if k == "frames" or k == "utterances": continue norm_value = ( - float(v) / num_frames if "utt_" not in k else float(v) / num_utterances + float(v) / num_frames + if "utt_" not in k + else float(v) / num_utterances ) ans.append((k, norm_value)) return ans @@ -905,7 +927,9 @@ class MetricsTracker(collections.defaultdict): tb_writer.add_scalar(prefix + k, v, batch_idx) -def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTensor: +def concat( + ragged: k2.RaggedTensor, value: int, direction: str +) -> k2.RaggedTensor: """Prepend a value to the beginning of each sublist or append a value. to the end of each sublist. @@ -951,8 +975,8 @@ def concat(ragged: k2.RaggedTensor, value: int, direction: str) -> k2.RaggedTens ans = k2.ragged.cat([ragged, pad], axis=1) else: raise ValueError( - f'Unsupported direction: {direction}. " "Expect either "left"' - ' or "right"' + f'Unsupported direction: {direction}. " \ + "Expect either "left" or "right"' ) return ans @@ -1077,7 +1101,9 @@ def linf_norm(x): return torch.max(torch.abs(x)) -def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float]: +def measure_weight_norms( + model: nn.Module, norm: str = "l2" +) -> Dict[str, float]: """ Compute the norms of the model's parameters. @@ -1100,7 +1126,9 @@ def measure_weight_norms(model: nn.Module, norm: str = "l2") -> Dict[str, float] return norms -def measure_gradient_norms(model: nn.Module, norm: str = "l1") -> Dict[str, float]: +def measure_gradient_norms( + model: nn.Module, norm: str = "l1" +) -> Dict[str, float]: """ Compute the norms of the gradients for each of model's parameters. @@ -1385,7 +1413,9 @@ def parse_hyp_and_timestamp( use_word_table = True for i in range(N): - time = convert_timestamp(res.timestamps[i], subsampling_factor, frame_shift_ms) + time = convert_timestamp( + res.timestamps[i], subsampling_factor, frame_shift_ms + ) if use_word_table: words = [word_table[i] for i in res.hyps[i]] else: diff --git a/pyproject.toml b/pyproject.toml index 3183055d4..b4f8c3377 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ profile = "black" skip = ["icefall/__init__.py"] [tool.black] -line-length = 88 +line-length = 80 exclude = ''' /( \.git diff --git a/setup.py b/setup.py index ccd2503ff..6c720e121 100644 --- a/setup.py +++ b/setup.py @@ -1,8 +1,7 @@ #!/usr/bin/env python3 -from pathlib import Path - from setuptools import find_packages, setup +from pathlib import Path icefall_dir = Path(__file__).parent install_requires = (icefall_dir / "requirements.txt").read_text().splitlines() diff --git a/test/test_checkpoint.py b/test/test_checkpoint.py index 34e829642..511a11c23 100644 --- a/test/test_checkpoint.py +++ b/test/test_checkpoint.py @@ -20,7 +20,11 @@ import pytest import torch import torch.nn as nn -from icefall.checkpoint import average_checkpoints, load_checkpoint, save_checkpoint +from icefall.checkpoint import ( + average_checkpoints, + load_checkpoint, + save_checkpoint, +) @pytest.fixture diff --git a/test/test_decode.py b/test/test_decode.py index 4c2e192a7..97964ac67 100644 --- a/test/test_decode.py +++ b/test/test_decode.py @@ -23,7 +23,6 @@ You can run this file in one of the two ways: """ import k2 - from icefall.decode import Nbest diff --git a/test/test_graph_compiler.py b/test/test_graph_compiler.py index 10443cf22..ccfb57d49 100644 --- a/test/test_graph_compiler.py +++ b/test/test_graph_compiler.py @@ -154,7 +154,9 @@ class TestCtcTrainingGraphCompiler(object): fsas = k2.Fsa.from_fsas([fsa1, fsa2]) decoding_graph = k2.arc_sort(decoding_graph) - lattice = k2.intersect(decoding_graph, fsas, treat_epsilons_specially=False) + lattice = k2.intersect( + decoding_graph, fsas, treat_epsilons_specially=False + ) lattice = k2.connect(lattice) aux_labels0 = lattice[0].aux_labels[:-1] diff --git a/test/test_utils.py b/test/test_utils.py index 31f06bd51..6a9ce7853 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -50,7 +50,9 @@ def test_encode_supervisions(sup): assert torch.all( torch.eq( supervision_segments, - torch.tensor([[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]]), + torch.tensor( + [[1, 0, 30 // 4], [0, 0, 20 // 4], [2, 9 // 4, 10 // 4]] + ), ) ) assert texts == ["two", "one", "three"]