mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Fix code style
This commit is contained in:
parent
cbc5557c87
commit
ebf142cb98
1
.flake8
1
.flake8
@ -5,6 +5,7 @@ max-line-length = 80
|
||||
per-file-ignores =
|
||||
# line too long
|
||||
egs/librispeech/ASR/*/conformer.py: E501,
|
||||
egs/aishell/ASR/*/conformer.py: E501,
|
||||
|
||||
exclude =
|
||||
.git,
|
||||
|
@ -107,7 +107,7 @@ The following options are used quite often:
|
||||
It is the number of epochs to train. For instance,
|
||||
``./conformer_ctc/train.py --num-epochs 30`` trains for 30 epochs
|
||||
and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
|
||||
in the folder set with ``--exp-dir``.
|
||||
in the folder set by ``--exp-dir``.
|
||||
|
||||
- ``--start-epoch``
|
||||
|
||||
|
@ -453,7 +453,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
|
||||
self._reset_parameters()
|
||||
|
||||
|
||||
def _reset_parameters(self) -> None:
|
||||
nn.init.xavier_uniform_(self.in_proj.weight)
|
||||
nn.init.constant_(self.in_proj.bias, 0.0)
|
||||
@ -683,7 +682,6 @@ class RelPositionMultiheadAttention(nn.Module):
|
||||
_b = _b[_start:]
|
||||
v = nn.functional.linear(value, _w, _b)
|
||||
|
||||
|
||||
if attn_mask is not None:
|
||||
assert (
|
||||
attn_mask.dtype == torch.float32
|
||||
|
@ -33,10 +33,9 @@ and generates the following files in the directory `lang_dir`:
|
||||
- tokens.txt
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List
|
||||
|
||||
import k2
|
||||
import torch
|
||||
@ -87,7 +86,9 @@ def lexicon_to_fst_no_sil(
|
||||
cur_state = loop_state
|
||||
|
||||
word = word2id[word]
|
||||
pieces = [token2id[i] if i in token2id else token2id['<unk>'] for i in pieces]
|
||||
pieces = [
|
||||
token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
|
||||
]
|
||||
|
||||
for i in range(len(pieces) - 1):
|
||||
w = word if i == 0 else eps
|
||||
@ -136,7 +137,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
|
||||
otherwise False.
|
||||
"""
|
||||
for tok in tokens:
|
||||
if not tok in token_sym_table:
|
||||
if tok not in token_sym_table:
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -178,18 +179,18 @@ def generate_tokens(text_file: str) -> Dict[str, int]:
|
||||
from 0 to len(keys) - 1.
|
||||
"""
|
||||
tokens: Dict[str, int] = dict()
|
||||
tokens['<blk>'] = 0
|
||||
tokens['<sos/eos>'] = 1
|
||||
tokens['<unk>'] = 2
|
||||
tokens["<blk>"] = 0
|
||||
tokens["<sos/eos>"] = 1
|
||||
tokens["<unk>"] = 2
|
||||
whitespace = re.compile(r"([ \t\r\n]+)")
|
||||
with open(text_file, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = re.sub(whitespace, "", line)
|
||||
chars = list(line)
|
||||
for char in chars:
|
||||
if not char in tokens:
|
||||
if char not in tokens:
|
||||
tokens[char] = len(tokens)
|
||||
return tokens
|
||||
return tokens
|
||||
|
||||
|
||||
def main():
|
||||
|
@ -54,7 +54,6 @@ class CharCtcTrainingGraphCompiler(object):
|
||||
self.sos_id = self.token_table[sos_token]
|
||||
self.eos_id = self.token_table[eos_token]
|
||||
|
||||
|
||||
def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
|
||||
"""Convert a list of texts to a list-of-list of token IDs.
|
||||
|
||||
@ -71,12 +70,15 @@ class CharCtcTrainingGraphCompiler(object):
|
||||
whitespace = re.compile(r"([ \t])")
|
||||
for text in texts:
|
||||
text = re.sub(whitespace, "", text)
|
||||
sub_ids = [self.token_table[txt] if txt in self.token_table \
|
||||
else self.oov_id for txt in text]
|
||||
sub_ids = [
|
||||
self.token_table[txt]
|
||||
if txt in self.token_table
|
||||
else self.oov_id
|
||||
for txt in text
|
||||
]
|
||||
ids.append(sub_ids)
|
||||
return ids
|
||||
|
||||
|
||||
def compile(
|
||||
self,
|
||||
token_ids: List[List[int]],
|
||||
@ -95,4 +97,3 @@ class CharCtcTrainingGraphCompiler(object):
|
||||
piece IDs.
|
||||
"""
|
||||
return k2.ctc_graph(token_ids, modified=modified, device=self.device)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user