mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)
Fix code style

commit ebf142cb98 (parent cbc5557c87)
.flake8
@@ -5,6 +5,7 @@ max-line-length = 80
 per-file-ignores =
     # line too long
     egs/librispeech/ASR/*/conformer.py: E501,
+    egs/aishell/ASR/*/conformer.py: E501,

 exclude =
     .git,
@@ -107,7 +107,7 @@ The following options are used quite often:
   It is the number of epochs to train. For instance,
   ``./conformer_ctc/train.py --num-epochs 30`` trains for 30 epochs
   and generates ``epoch-0.pt``, ``epoch-1.pt``, ..., ``epoch-29.pt``
-  in the folder set with ``--exp-dir``.
+  in the folder set by ``--exp-dir``.

 - ``--start-epoch``

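Note (illustrative, not part of the commit): the paragraph above describes the checkpoint naming scheme. A minimal sketch of inspecting those files, assuming an ``--exp-dir`` of ``conformer_ctc/exp`` and that each ``epoch-*.pt`` is a plain dict written by ``torch.save``; the key layout is an assumption, not confirmed by this commit:

    from pathlib import Path

    import torch

    exp_dir = Path("conformer_ctc/exp")  # assumed value of --exp-dir
    for ckpt in sorted(exp_dir.glob("epoch-*.pt")):
        state = torch.load(ckpt, map_location="cpu")
        # The checkpoint keys are an assumption about icefall's layout.
        print(ckpt.name, sorted(state.keys())[:5])
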
@@ -453,7 +453,6 @@ class RelPositionMultiheadAttention(nn.Module):

         self._reset_parameters()

-
     def _reset_parameters(self) -> None:
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.constant_(self.in_proj.bias, 0.0)
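Note: a self-contained sketch of the initialization pattern shown in this hunk, Xavier-uniform weights and zero bias on a packed q/k/v projection; the layer size is invented for illustration:

    import torch
    import torch.nn as nn

    embed_dim = 4
    # A packed q/k/v projection analogous to the attention module's in_proj.
    in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
    nn.init.xavier_uniform_(in_proj.weight)
    nn.init.constant_(in_proj.bias, 0.0)
    assert torch.all(in_proj.bias == 0.0)
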
@@ -683,7 +682,6 @@ class RelPositionMultiheadAttention(nn.Module):
                 _b = _b[_start:]
             v = nn.functional.linear(value, _w, _b)

-
         if attn_mask is not None:
             assert (
                 attn_mask.dtype == torch.float32
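Note: for orientation, a sketch of what this hunk's context is doing: the value projection is taken from a slice of the packed ``in_proj`` parameters and applied with ``nn.functional.linear``. The slice boundary below assumes the usual q/k/v packing and is not taken from the commit:

    import torch
    import torch.nn as nn

    embed_dim = 4
    in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True)
    value = torch.randn(2, 3, embed_dim)  # (time, batch, embed) toy input

    _start = 2 * embed_dim  # assumed: value weights occupy the last third
    _w = in_proj.weight[_start:, :]
    _b = in_proj.bias[_start:]
    v = nn.functional.linear(value, _w, _b)
    assert v.shape == value.shape
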
@@ -33,10 +33,9 @@ and generates the following files in the directory `lang_dir`:
 - tokens.txt
 """

-import argparse
 import re
 from pathlib import Path
-from typing import Dict, List, Tuple
+from typing import Dict, List

 import k2
 import torch
@@ -87,7 +86,9 @@ def lexicon_to_fst_no_sil(
         cur_state = loop_state

         word = word2id[word]
-        pieces = [token2id[i] if i in token2id else token2id['<unk>'] for i in pieces]
+        pieces = [
+            token2id[i] if i in token2id else token2id["<unk>"] for i in pieces
+        ]

         for i in range(len(pieces) - 1):
             w = word if i == 0 else eps
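Note: the reformatted comprehension maps any piece missing from ``token2id`` to the ``<unk>`` ID. A toy run, using the reserved 0/1/2 IDs from ``generate_tokens`` below (the characters themselves are invented):

    token2id = {"<blk>": 0, "<sos/eos>": 1, "<unk>": 2, "你": 3, "好": 4}
    pieces = ["你", "好", "吗"]  # "吗" is out of vocabulary
    pieces = [token2id[i] if i in token2id else token2id["<unk>"] for i in pieces]
    print(pieces)  # [3, 4, 2]
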
@@ -136,7 +137,7 @@ def contain_oov(token_sym_table: Dict[str, int], tokens: List[str]) -> bool:
       otherwise False.
     """
     for tok in tokens:
-        if not tok in token_sym_table:
+        if tok not in token_sym_table:
             return True
     return False

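Note: this hunk fixes flake8's E713 ("test for membership should be 'not in'"). The two spellings are equivalent; ``not in`` is the idiomatic one:

    token_sym_table = {"a": 0}
    tok = "b"
    # Both evaluate identically; flake8 flags the first form as E713.
    assert (not tok in token_sym_table) == (tok not in token_sym_table)
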
@@ -178,18 +179,18 @@ def generate_tokens(text_file: str) -> Dict[str, int]:
       from 0 to len(keys) - 1.
     """
     tokens: Dict[str, int] = dict()
-    tokens['<blk>'] = 0
-    tokens['<sos/eos>'] = 1
-    tokens['<unk>'] = 2
+    tokens["<blk>"] = 0
+    tokens["<sos/eos>"] = 1
+    tokens["<unk>"] = 2
     whitespace = re.compile(r"([ \t\r\n]+)")
     with open(text_file, "r", encoding="utf-8") as f:
         for line in f:
             line = re.sub(whitespace, "", line)
             chars = list(line)
             for char in chars:
-                if not char in tokens:
+                if char not in tokens:
                     tokens[char] = len(tokens)
     return tokens


 def main():
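Note: a condensed, hedged re-implementation of ``generate_tokens`` as shown in this hunk, run on a throwaway file to demonstrate the reserved IDs and first-appearance ordering:

    import re
    import tempfile
    from typing import Dict


    def generate_tokens(text_file: str) -> Dict[str, int]:
        # Reserved IDs match the hunk above: <blk>=0, <sos/eos>=1, <unk>=2.
        tokens: Dict[str, int] = {"<blk>": 0, "<sos/eos>": 1, "<unk>": 2}
        whitespace = re.compile(r"([ \t\r\n]+)")
        with open(text_file, "r", encoding="utf-8") as f:
            for line in f:
                for char in re.sub(whitespace, "", line):
                    if char not in tokens:
                        tokens[char] = len(tokens)
        return tokens


    with tempfile.NamedTemporaryFile("w", delete=False, encoding="utf-8") as f:
        f.write("你好\n好的\n")
    print(generate_tokens(f.name))
    # {'<blk>': 0, '<sos/eos>': 1, '<unk>': 2, '你': 3, '好': 4, '的': 5}
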
@@ -54,7 +54,6 @@ class CharCtcTrainingGraphCompiler(object):
         self.sos_id = self.token_table[sos_token]
         self.eos_id = self.token_table[eos_token]

-
     def texts_to_ids(self, texts: List[str]) -> List[List[int]]:
         """Convert a list of texts to a list-of-list of token IDs.

@@ -71,12 +70,15 @@ class CharCtcTrainingGraphCompiler(object):
         whitespace = re.compile(r"([ \t])")
         for text in texts:
             text = re.sub(whitespace, "", text)
-            sub_ids = [self.token_table[txt] if txt in self.token_table \
-                else self.oov_id for txt in text]
+            sub_ids = [
+                self.token_table[txt]
+                if txt in self.token_table
+                else self.oov_id
+                for txt in text
+            ]
             ids.append(sub_ids)
         return ids

-
     def compile(
         self,
         token_ids: List[List[int]],
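Note: a toy call exercising the reformatted comprehension in ``texts_to_ids``; the token table values are invented, and ``oov_id`` is assumed to point at ``<unk>``:

    import re
    from typing import Dict

    token_table: Dict[str, int] = {"你": 3, "好": 4}
    oov_id = 2  # assumed to be the <unk> ID

    ids = []
    whitespace = re.compile(r"([ \t])")
    for text in ["你 好", "好吗"]:
        text = re.sub(whitespace, "", text)
        sub_ids = [
            token_table[txt] if txt in token_table else oov_id for txt in text
        ]
        ids.append(sub_ids)
    print(ids)  # [[3, 4], [4, 2]]
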
@@ -95,4 +97,3 @@ class CharCtcTrainingGraphCompiler(object):
             piece IDs.
         """
         return k2.ctc_graph(token_ids, modified=modified, device=self.device)
-
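Note: a usage sketch for the ``compile`` method's one-liner; ``k2.ctc_graph`` builds one CTC topology graph per utterance from the token IDs, and ``modified=False`` is its default. The IDs reuse the toy values above, and the printed shape is illustrative:

    import k2

    token_ids = [[3, 4], [4, 2]]  # e.g., the output of texts_to_ids above
    graphs = k2.ctc_graph(token_ids, modified=False, device="cpu")
    print(graphs.shape)  # an FsaVec holding one CTC graph per utterance
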