diff --git a/.flake8 b/.flake8
index b8f0e4715..19c3a9bd6 100644
--- a/.flake8
+++ b/.flake8
@@ -5,6 +5,7 @@ max-line-length = 80
 per-file-ignores =
     # line too long
     egs/librispeech/ASR/*/conformer.py: E501,
+    egs/aishell/ASR/*/conformer.py: E501,
 
 exclude =
     .git,
diff --git a/docs/source/recipes/aishell/conformer_ctc.rst b/docs/source/recipes/aishell/conformer_ctc.rst
index d225be9c6..59741833c 100644
--- a/docs/source/recipes/aishell/conformer_ctc.rst
+++ b/docs/source/recipes/aishell/conformer_ctc.rst
@@ -107,7 +107,7 @@ The following options are used quite often:
   It is the number of epochs to train. For instance,
   ``./conformer_ctc/train.py --num-epochs 30`` trains for 30 epochs
   and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
-  in the folder set with ``--exp-dir``.
+  in the folder set by ``--exp-dir``.
 
 - ``--start-epoch``
 
diff --git a/egs/aishell/ASR/conformer_ctc/conformer.py b/egs/aishell/ASR/conformer_ctc/conformer.py
index 5b136f40e..7bd0f95cf 100644
--- a/egs/aishell/ASR/conformer_ctc/conformer.py
+++ b/egs/aishell/ASR/conformer_ctc/conformer.py
@@ -453,7 +453,6 @@ class RelPositionMultiheadAttention(nn.Module):
 
         self._reset_parameters()
 
-
     def _reset_parameters(self) -> None:
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.constant_(self.in_proj.bias, 0.0)
@@ -683,7 +682,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)
 
-
         if attn_mask is not None:
             assert (
                 attn_mask.dtype == torch.float32
diff --git a/egs/aishell/ASR/local/prepare_char.py b/egs/aishell/ASR/local/prepare_char.py
index 5b4ec323d..d9e47d17a 100755
--- a/egs/aishell/ASR/local/prepare_char.py
+++ b/egs/aishell/ASR/local/prepare_char.py
@@ -33,10 +33,9 @@ and generates the following files in the directory `lang_dir`:
     - tokens.txt
 """
 
-import argparse
 import re
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List
 
 import k2
 import torch
@@ -87,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state
 
         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id['<unk>'] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+        ]
 
         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
@@ -136,7 +137,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
       otherwise False.
     """
     for tok in tokens:
-        if not tok in token_sym_table:
+        if tok not in token_sym_table:
             return True
     return False
 
@@ -178,18 +179,18 @@ def generate_tokens(text_file: str) -> Dict[str, int]:
       from 0 to len(keys) - 1.
     """
     tokens: Dict[str, int] = dict()
-    tokens['<blk>'] = 0
-    tokens['<sos/eos>'] = 1
-    tokens['<unk>'] = 2
+    tokens["<blk>"] = 0
+    tokens["<sos/eos>"] = 1
+    tokens["<unk>"] = 2
     whitespace = re.compile(r"([ \t\r\n]+)")
     with open(text_file, "r", encoding="utf-8") as f:
         for line in f:
             line = re.sub(whitespace, "", line)
             chars = list(line)
             for char in chars:
-                if not char in tokens:
+                if char not in tokens:
                     tokens[char] = len(tokens)
-        return tokens
+    return tokens
 
 
 def main():
diff --git a/icefall/char_graph_compiler.py b/icefall/char_graph_compiler.py
index 9ce948f7c..4a79a300a 100644
--- a/icefall/char_graph_compiler.py
+++ b/icefall/char_graph_compiler.py
@@ -54,7 +54,6 @@ class CharCtcTrainingGraphCompiler(object):
         self.sos_id = self.token_table[sos_token]
         self.eos_id = self.token_table[eos_token]
 
-
     def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
         """Convert a list of texts to a list-of-list of token IDs.
 
@@ -71,12 +70,15 @@ class CharCtcTrainingGraphCompiler(object):
         whitespace = re.compile(r"([ \t])")
         for text in texts:
             text = re.sub(whitespace, "", text)
-            sub_ids = [self.token_table[txt] if txt in self.token_table \
-                else self.oov_id for txt in text]
+            sub_ids = [
+                self.token_table[txt]
+                if txt in self.token_table
+                else self.oov_id
+                for txt in text
+            ]
             ids.append(sub_ids)
         return ids
 
-
     def compile(
         self,
         token_ids: List[List[int]],
@@ -95,4 +97,3 @@ class CharCtcTrainingGraphCompiler(object):
           piece IDs.
         """
         return k2.ctc_graph(token_ids, modified=modified, device=self.device)
-