Mirror of https://github.com/k2-fsa/icefall.git (synced 2025-09-19 05:54:20 +00:00)

Commit 995f260f91: Merge branch 'master' into streaming5

@@ -120,10 +120,17 @@ class Conformer(EncoderInterface):
             layer_dropout,
             cnn_module_kernel,
         )
+        # aux_layers from 1/3
         self.encoder = ConformerEncoder(
             encoder_layer,
             num_encoder_layers,
-            aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)),
+            aux_layers=list(
+                range(
+                    num_encoder_layers // 3,
+                    num_encoder_layers - 1,
+                    aux_layer_period,
+                )
+            ),
         )
 
     def forward(
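With this change the auxiliary taps skip the first third of the encoder instead of starting at layer 0. A small sketch of the resulting indices, assuming a hypothetical 24-layer encoder with aux_layer_period = 3:

    # Hypothetical depth and period; the real values come from the model config.
    num_encoder_layers = 24
    aux_layer_period = 3

    old_aux = list(range(0, num_encoder_layers - 1, aux_layer_period))
    # -> [0, 3, 6, 9, 12, 15, 18, 21]
    new_aux = list(
        range(num_encoder_layers // 3, num_encoder_layers - 1, aux_layer_period)
    )
    # -> [8, 11, 14, 17, 20]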
@@ -362,10 +369,8 @@ class ConformerEncoder(nn.Module):
         assert num_layers - 1 not in aux_layers
         self.aux_layers = aux_layers + [num_layers - 1]
 
-        num_channels = encoder_layer.norm_final.num_channels
         self.combiner = RandomCombine(
             num_inputs=len(self.aux_layers),
-            num_channels=num_channels,
             final_weight=0.5,
             pure_prob=0.333,
             stddev=2.0,
@@ -1206,7 +1211,6 @@ class RandomCombine(nn.Module):
     def __init__(
         self,
         num_inputs: int,
-        num_channels: int,
         final_weight: float = 0.5,
         pure_prob: float = 0.5,
         stddev: float = 2.0,
@@ -1217,8 +1221,6 @@ class RandomCombine(nn.Module):
           The number of tensor inputs, which equals the number of layers'
           outputs that are fed into this module. E.g. in an 18-layer neural
           net if we output layers 16, 12, 18, num_inputs would be 3.
-        num_channels:
-          The number of channels on the input, e.g. 512.
         final_weight:
           The amount of weight or probability we assign to the
           final layer when randomly choosing layers or when choosing
@@ -1249,13 +1251,6 @@ class RandomCombine(nn.Module):
         assert 0 < final_weight < 1, final_weight
         assert num_inputs >= 1
 
-        self.linear = nn.ModuleList(
-            [
-                nn.Linear(num_channels, num_channels, bias=True)
-                for _ in range(num_inputs - 1)
-            ]
-        )
-
         self.num_inputs = num_inputs
         self.final_weight = final_weight
         self.pure_prob = pure_prob
@@ -1268,12 +1263,6 @@ class RandomCombine(nn.Module):
             .log()
             .item()
         )
-        self._reset_parameters()
-
-    def _reset_parameters(self):
-        for i in range(len(self.linear)):
-            nn.init.eye_(self.linear[i].weight)
-            nn.init.constant_(self.linear[i].bias, 0.0)
 
     def forward(self, inputs: List[Tensor]) -> Tensor:
         """Forward function.
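For context: the removed _reset_parameters set each projection's weight to the identity and its bias to zero, so at initialization the old combiner already passed layer outputs through unchanged; deleting the projections makes that behavior permanent. A quick illustration of that initialization, with a hypothetical channel count:

    import torch
    import torch.nn as nn

    lin = nn.Linear(512, 512, bias=True)  # 512 channels is an assumed example
    nn.init.eye_(lin.weight)
    nn.init.constant_(lin.bias, 0.0)

    x = torch.randn(4, 512)
    assert torch.allclose(lin(x), x)  # behaves as an identity map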
@@ -1294,28 +1283,9 @@ class RandomCombine(nn.Module):
         num_channels = inputs[0].shape[-1]
         num_frames = inputs[0].numel() // num_channels
 
-        mod_inputs = []
-
-        if False:
-            # It throws the following error for torch 1.6.0 when using
-            # torch script.
-            #
-            # Expected integer literal for index. ModuleList/Sequential
-            # indexing is only supported with integer literals. Enumeration is
-            # supported, e.g. 'for index, v in enumerate(self): ...':
-            #     for i in range(num_inputs - 1):
-            #         mod_inputs.append(self.linear[i](inputs[i]))
-            assert False
-        else:
-            for i, linear in enumerate(self.linear):
-                if i < num_inputs - 1:
-                    mod_inputs.append(linear(inputs[i]))
-
-        mod_inputs.append(inputs[num_inputs - 1])
-
         ndim = inputs[0].ndim
         # stacked_inputs: (num_frames, num_channels, num_inputs)
-        stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape(
+        stacked_inputs = torch.stack(inputs, dim=ndim).reshape(
             (num_frames, num_channels, num_inputs)
         )
 
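Taken together, these RandomCombine hunks remove the per-layer nn.Linear projections, so the module stacks the tapped layer outputs directly and no longer needs to know num_channels. A minimal sketch of the simplified combination step, assuming the per-frame weights are produced by the module's existing random-weight sampling (unchanged here):

    from typing import List

    import torch
    from torch import Tensor

    def combine(inputs: List[Tensor], weights: Tensor) -> Tensor:
        # inputs: tensors of identical shape, e.g. (seq_len, batch, num_channels)
        # weights: (num_frames, num_inputs, 1), each row summing to 1
        num_channels = inputs[0].shape[-1]
        num_frames = inputs[0].numel() // num_channels
        ndim = inputs[0].ndim
        # stacked: (num_frames, num_channels, num_inputs)
        stacked = torch.stack(inputs, dim=ndim).reshape(
            (num_frames, num_channels, len(inputs))
        )
        ans = torch.matmul(stacked, weights)  # (num_frames, num_channels, 1)
        return ans.reshape(inputs[0].shape)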
@@ -58,15 +58,19 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str):
     Return:
       The lexicon.txt file and the train.text in lang_dir.
     """
+    import gzip
+
     phones = set()
 
-    supervisions_train = Path(manifests_dir) / "supervisions_TRAIN.json"
+    supervisions_train = (
+        Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
+    )
     lexicon = Path(lang_dir) / "lexicon.txt"
 
     logging.info(f"Loading {supervisions_train}!")
-    with open(supervisions_train, "r") as load_f:
-        load_dicts = json.load(load_f)
-        for load_dict in load_dicts:
+    with gzip.open(supervisions_train, "r") as load_f:
+        for line in load_f.readlines():
+            load_dict = json.loads(line)
             text = load_dict["text"]
             # list the phone units and filter the empty item
             phones_list = list(filter(None, text.split()))
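The TIMIT lexicon preparation now reads a Lhotse-style gzipped JSONL supervisions manifest line by line rather than one JSON array. A standalone sketch of that reading pattern (the helper name is illustrative; the file name matches the diff):

    import gzip
    import json
    from pathlib import Path

    def collect_phones(manifests_dir: str) -> set:
        path = Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz"
        phones = set()
        # Each line of the .jsonl.gz file is a single JSON supervision record.
        with gzip.open(path, "rt") as f:
            for line in f:
                text = json.loads(line)["text"]
                # Split into phone tokens, dropping empty strings.
                phones.update(filter(None, text.split()))
        return phones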
@@ -192,13 +192,6 @@ class WenetSpeechAsrDataModule:
             "with training dataset. ",
         )
 
-        group.add_argument(
-            "--lazy-load",
-            type=str2bool,
-            default=True,
-            help="lazily open CutSets to avoid OOM (for L|XL subset)",
-        )
-
         group.add_argument(
             "--training-subset",
             type=str,
@@ -420,14 +413,7 @@ class WenetSpeechAsrDataModule:
     @lru_cache()
     def train_cuts(self) -> CutSet:
         logging.info("About to get train cuts")
-        if self.args.lazy_load:
-            logging.info("use lazy cuts")
-            cuts_train = CutSet.from_jsonl_lazy(
-                self.args.manifest_dir
-                / f"cuts_{self.args.training_subset}.jsonl.gz"
-            )
-        else:
-            cuts_train = CutSet.from_file(
+        cuts_train = load_manifest_lazy(
             self.args.manifest_dir
             / f"cuts_{self.args.training_subset}.jsonl.gz"
         )
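With lazy loading now unconditional, the --lazy-load flag removed above is redundant and train_cuts always opens the manifest lazily. A minimal usage sketch of Lhotse's load_manifest_lazy, with hypothetical paths (the recipe takes these from the parsed arguments):

    from pathlib import Path

    from lhotse import load_manifest_lazy

    manifest_dir = Path("data/fbank")   # assumed location
    training_subset = "L"               # assumed subset name

    # Returns a CutSet backed lazily by the .jsonl.gz file, so the whole
    # manifest is never held in memory at once.
    cuts_train = load_manifest_lazy(
        manifest_dir / f"cuts_{training_subset}.jsonl.gz"
    )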