Fix code style

pkufool 2021-11-17 19:02:44 +08:00
parent cbc5557c87
commit ebf142cb98
5 changed files with 18 additions and 17 deletions

@@ -5,6 +5,7 @@ max-line-length = 80
 per-file-ignores =
     # line too long
     egs/librispeech/ASR/*/conformer.py: E501,
+    egs/aishell/ASR/*/conformer.py: E501,

 exclude =
     .git,
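For reference, after this change the per-file-ignores section of the flake8 configuration would read roughly as follows. This is a sketch assembled only from the lines visible in this hunk; the rest of the file is not shown here.

max-line-length = 80
per-file-ignores =
    # line too long
    egs/librispeech/ASR/*/conformer.py: E501,
    egs/aishell/ASR/*/conformer.py: E501,
exclude =
    .git,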

@@ -107,7 +107,7 @@ The following options are used quite often:
   It is the number of epochs to train. For instance,
   ``./conformer_ctc/train.py --num-epochs 30`` trains for 30 epochs
   and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
-  in the folder set with ``--exp-dir``.
+  in the folder set by ``--exp-dir``.

 - ``--start-epoch``
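As a rough illustration of the checkpoint naming described in this documentation hunk, a script like the one below could enumerate and load those files. The directory name conformer_ctc/exp is only a placeholder for whatever --exp-dir was set to, and treating each epoch-N.pt as an ordinary torch.save checkpoint is an assumption, not something stated in the docs.

from pathlib import Path

import torch

# Placeholder for the value passed to --exp-dir (assumption for illustration).
exp_dir = Path("conformer_ctc/exp")

if exp_dir.is_dir():
    for ckpt in sorted(exp_dir.glob("epoch-*.pt")):
        # Assumes each file is a plain torch.save checkpoint.
        state = torch.load(ckpt, map_location="cpu")
        print(ckpt.name, type(state))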

@@ -453,7 +453,6 @@ class RelPositionMultiheadAttention(nn.Module):
         self._reset_parameters()
-

     def _reset_parameters(self) -> None:
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.constant_(self.in_proj.bias, 0.0)
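The _reset_parameters method shown in the context above initializes the packed input projection with Xavier-uniform weights and a zero bias. A minimal standalone version of that initialization pattern; the layer and its dimensions here are invented, not taken from conformer.py:

import torch.nn as nn

embed_dim = 8
# Stand-in for the packed q/k/v input projection (3 * embed_dim output rows).
in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
nn.init.xavier_uniform_(in_proj.weight)
nn.init.constant_(in_proj.bias, 0.0)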
@@ -683,7 +682,6 @@ class RelPositionMultiheadAttention(nn.Module):
                     _b = _b[_start:]
                 v = nn.functional.linear(value, _w, _b)
-

         if attn_mask is not None:
             assert (
                 attn_mask.dtype == torch.float32
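The context around this hunk computes the value projection by slicing the last third of a packed in-projection weight and bias. A small self-contained sketch of that slicing pattern; the shapes and tensors are made up for illustration and are not taken from conformer.py:

import torch
import torch.nn as nn

embed_dim = 8
# Packed projection: rows [0, E) -> query, [E, 2E) -> key, [2E, 3E) -> value.
in_proj_weight = torch.randn(3 * embed_dim, embed_dim)
in_proj_bias = torch.zeros(3 * embed_dim)
value = torch.randn(4, embed_dim)  # (sequence length, embedding dim)

_start = embed_dim * 2
_w = in_proj_weight[_start:, :]  # last third of the packed weight
_b = in_proj_bias[_start:]       # matching slice of the packed bias
v = nn.functional.linear(value, _w, _b)
print(v.shape)  # torch.Size([4, 8])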

@@ -33,10 +33,9 @@ and generates the following files in the directory `lang_dir`:
 - tokens.txt
 """

-import argparse
 import re
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List

 import k2
 import torch
@@ -87,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state

         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id['<unk>'] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+        ]

         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
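The comprehension reformatted above maps each piece to its ID and falls back to <unk> for anything missing from the token table. A tiny standalone illustration of that lookup; the table and the pieces below are invented:

token2id = {"<blk>": 0, "<sos/eos>": 1, "<unk>": 2, "你": 3, "好": 4}

pieces = ["你", "好", "们"]  # "们" is not in the table
ids = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
print(ids)  # [3, 4, 2]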
@@ -136,7 +137,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
       otherwise False.
     """
     for tok in tokens:
-        if not tok in token_sym_table:
+        if tok not in token_sym_table:
             return True
     return False
@@ -178,18 +179,18 @@ def generate_tokens(text_file: str) -> Dict[str, int]:
       from 0 to len(keys) - 1.
     """
     tokens: Dict[str, int] = dict()
-    tokens['<blk>'] = 0
-    tokens['<sos/eos>'] = 1
-    tokens['<unk>'] = 2
+    tokens["<blk>"] = 0
+    tokens["<sos/eos>"] = 1
+    tokens["<unk>"] = 2
     whitespace = re.compile(r"([ \t\r\n]+)")
     with open(text_file, "r", encoding="utf-8") as f:
         for line in f:
             line = re.sub(whitespace, "", line)
             chars = list(line)
             for char in chars:
-                if not char in tokens:
+                if char not in tokens:
                     tokens[char] = len(tokens)
     return tokens


 def main():
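For reference, generate_tokens reserves IDs 0-2 for <blk>, <sos/eos> and <unk>, then numbers every new character it sees. A small self-contained sketch of the same idea that works on an in-memory list of lines instead of a file; the sample text is invented:

import re
from typing import Dict

tokens: Dict[str, int] = {"<blk>": 0, "<sos/eos>": 1, "<unk>": 2}
whitespace = re.compile(r"([ \t\r\n]+)")

sample_lines = ["你 好", "好 世 界"]  # stands in for the lines of text_file
for line in sample_lines:
    line = re.sub(whitespace, "", line)
    for char in line:
        if char not in tokens:
            tokens[char] = len(tokens)

print(tokens)
# {'<blk>': 0, '<sos/eos>': 1, '<unk>': 2, '你': 3, '好': 4, '世': 5, '界': 6}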

@@ -54,7 +54,6 @@ class CharCtcTrainingGraphCompiler(object):
         self.sos_id = self.token_table[sos_token]
         self.eos_id = self.token_table[eos_token]
-

     def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
         """Convert a list of texts to a list-of-list of token IDs.
@@ -71,12 +70,15 @@ class CharCtcTrainingGraphCompiler(object):
         whitespace = re.compile(r"([ \t])")
         for text in texts:
             text = re.sub(whitespace, "", text)
-            sub_ids = [self.token_table[txt] if txt in self.token_table \
-                else self.oov_id for txt in text]
+            sub_ids = [
+                self.token_table[txt]
+                if txt in self.token_table
+                else self.oov_id
+                for txt in text
+            ]
             ids.append(sub_ids)
         return ids
-

     def compile(
         self,
         token_ids: List[List[int]],
@@ -95,4 +97,3 @@ class CharCtcTrainingGraphCompiler(object):
           piece IDs.
         """
         return k2.ctc_graph(token_ids, modified=modified, device=self.device)
-
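The compile method above is a thin wrapper around k2.ctc_graph, which the diff itself calls with modified and device arguments. A minimal usage sketch with invented token IDs (requires k2 and torch to be installed):

import k2
import torch

# Each inner list is the token ID sequence of one utterance; the IDs are made up.
token_ids = [[3, 4, 2], [4, 4, 3]]

# Build one CTC graph per utterance; the result is expected to be a k2 FsaVec.
ctc_graphs = k2.ctc_graph(token_ids, modified=False, device=torch.device("cpu"))
print(ctc_graphs.shape)  # expected: (2, None, None)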