Fix code style

This commit is contained in:
pkufool 2021-11-17 19:02:44 +08:00
parent cbc5557c87
commit ebf142cb98
5 changed files with 18 additions and 17 deletions

View File

@ -5,6 +5,7 @@ max-line-length = 80
per-file-ignores =
# line too long
egs/librispeech/ASR/*/conformer.py: E501,
egs/aishell/ASR/*/conformer.py: E501,
exclude =
.git,

View File

@ -107,7 +107,7 @@ The following options are used quite often:
It is the number of epochs to train. For instance,
``./conformer_ctc/train.py --num-epochs 30`` trains for 30 epochs
and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
in the folder set with ``--exp-dir``.
in the folder set by ``--exp-dir``.
- ``--start-epoch``

View File

@ -453,7 +453,6 @@ class RelPositionMultiheadAttention(nn.Module):
self._reset_parameters()
def _reset_parameters(self) -> None:
nn.init.xavier_uniform_(self.in_proj.weight)
nn.init.constant_(self.in_proj.bias, 0.0)
@ -683,7 +682,6 @@ class RelPositionMultiheadAttention(nn.Module):
_b = _b[_start:]
v = nn.functional.linear(value, _w, _b)
if attn_mask is not None:
assert (
attn_mask.dtype == torch.float32

View File

@ -33,10 +33,9 @@ and generates the following files in the directory `lang_dir`:
- tokens.txt
"""
import argparse
import re
from pathlib import Path
from typing import Dict, List, Tuple
from typing import Dict, List
import k2
import torch
@ -87,7 +86,9 @@ def lexicon_to_fst_no_sil(
cur_state = loop_state
word = word2id[word]
pieces = [token2id[i] if i in token2id else token2id['<unk>'] for i in pieces]
pieces = [
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
]
for i in range(len(pieces) - 1):
w = word if i == 0 else eps
@ -136,7 +137,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
otherwise False.
"""
for tok in tokens:
if not tok in token_sym_table:
if tok not in token_sym_table:
return True
return False
@ -178,18 +179,18 @@ def generate_tokens(text_file: str) -> Dict[str, int]:
from 0 to len(keys) - 1.
"""
tokens: Dict[str, int] = dict()
tokens['<blk>'] = 0
tokens['<sos/eos>'] = 1
tokens['<unk>'] = 2
tokens["<blk>"] = 0
tokens["<sos/eos>"] = 1
tokens["<unk>"] = 2
whitespace = re.compile(r"([ \t\r\n]+)")
with open(text_file, "r", encoding="utf-8") as f:
for line in f:
line = re.sub(whitespace, "", line)
chars = list(line)
for char in chars:
if not char in tokens:
if char not in tokens:
tokens[char] = len(tokens)
return tokens
return tokens
def main():

View File

@ -54,7 +54,6 @@ class CharCtcTrainingGraphCompiler(object):
self.sos_id = self.token_table[sos_token]
self.eos_id = self.token_table[eos_token]
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
"""Convert a list of texts to a list-of-list of token IDs.
@ -71,12 +70,15 @@ class CharCtcTrainingGraphCompiler(object):
whitespace = re.compile(r"([ \t])")
for text in texts:
text = re.sub(whitespace, "", text)
sub_ids = [self.token_table[txt] if txt in self.token_table \
else self.oov_id for txt in text]
sub_ids = [
self.token_table[txt]
if txt in self.token_table
else self.oov_id
for txt in text
]
ids.append(sub_ids)
return ids
def compile(
self,
token_ids: List[List[int]],
@ -95,4 +97,3 @@ class CharCtcTrainingGraphCompiler(object):
piece IDs.
"""
return k2.ctc_graph(token_ids, modified=modified, device=self.device)