diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py index 771658273..a0f37f148 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless5/conformer.py @@ -120,10 +120,17 @@ class Conformer(EncoderInterface): layer_dropout, cnn_module_kernel, ) + # aux_layers from 1/3 self.encoder = ConformerEncoder( encoder_layer, num_encoder_layers, - aux_layers=list(range(0, num_encoder_layers - 1, aux_layer_period)), + aux_layers=list( + range( + num_encoder_layers // 3, + num_encoder_layers - 1, + aux_layer_period, + ) + ), ) def forward( @@ -362,10 +369,8 @@ class ConformerEncoder(nn.Module): assert num_layers - 1 not in aux_layers self.aux_layers = aux_layers + [num_layers - 1] - num_channels = encoder_layer.norm_final.num_channels self.combiner = RandomCombine( num_inputs=len(self.aux_layers), - num_channels=num_channels, final_weight=0.5, pure_prob=0.333, stddev=2.0, @@ -1206,7 +1211,6 @@ class RandomCombine(nn.Module): def __init__( self, num_inputs: int, - num_channels: int, final_weight: float = 0.5, pure_prob: float = 0.5, stddev: float = 2.0, @@ -1217,8 +1221,6 @@ class RandomCombine(nn.Module): The number of tensor inputs, which equals the number of layers' outputs that are fed into this module. E.g. in an 18-layer neural net if we output layers 16, 12, 18, num_inputs would be 3. - num_channels: - The number of channels on the input, e.g. 512. final_weight: The amount of weight or probability we assign to the final layer when randomly choosing layers or when choosing @@ -1249,13 +1251,6 @@ class RandomCombine(nn.Module): assert 0 < final_weight < 1, final_weight assert num_inputs >= 1 - self.linear = nn.ModuleList( - [ - nn.Linear(num_channels, num_channels, bias=True) - for _ in range(num_inputs - 1) - ] - ) - self.num_inputs = num_inputs self.final_weight = final_weight self.pure_prob = pure_prob @@ -1268,12 +1263,6 @@ class RandomCombine(nn.Module): .log() .item() ) - self._reset_parameters() - - def _reset_parameters(self): - for i in range(len(self.linear)): - nn.init.eye_(self.linear[i].weight) - nn.init.constant_(self.linear[i].bias, 0.0) def forward(self, inputs: List[Tensor]) -> Tensor: """Forward function. @@ -1294,28 +1283,9 @@ class RandomCombine(nn.Module): num_channels = inputs[0].shape[-1] num_frames = inputs[0].numel() // num_channels - mod_inputs = [] - - if False: - # It throws the following error for torch 1.6.0 when using - # torch script. - # - # Expected integer literal for index. ModuleList/Sequential - # indexing is only supported with integer literals. Enumeration is - # supported, e.g. 'for index, v in enumerate(self): ...': - # for i in range(num_inputs - 1): - # mod_inputs.append(self.linear[i](inputs[i])) - assert False - else: - for i, linear in enumerate(self.linear): - if i < num_inputs - 1: - mod_inputs.append(linear(inputs[i])) - - mod_inputs.append(inputs[num_inputs - 1]) - ndim = inputs[0].ndim # stacked_inputs: (num_frames, num_channels, num_inputs) - stacked_inputs = torch.stack(mod_inputs, dim=ndim).reshape( + stacked_inputs = torch.stack(inputs, dim=ndim).reshape( (num_frames, num_channels, num_inputs) ) diff --git a/egs/timit/ASR/local/prepare_lexicon.py b/egs/timit/ASR/local/prepare_lexicon.py index f0168ebd6..04023a9ab 100644 --- a/egs/timit/ASR/local/prepare_lexicon.py +++ b/egs/timit/ASR/local/prepare_lexicon.py @@ -58,15 +58,19 @@ def prepare_lexicon(manifests_dir: str, lang_dir: str): Return: The lexicon.txt file and the train.text in lang_dir. """ + import gzip + phones = set() - supervisions_train = Path(manifests_dir) / "supervisions_TRAIN.json" + supervisions_train = ( + Path(manifests_dir) / "timit_supervisions_TRAIN.jsonl.gz" + ) lexicon = Path(lang_dir) / "lexicon.txt" logging.info(f"Loading {supervisions_train}!") - with open(supervisions_train, "r") as load_f: - load_dicts = json.load(load_f) - for load_dict in load_dicts: + with gzip.open(supervisions_train, "r") as load_f: + for line in load_f.readlines(): + load_dict = json.loads(line) text = load_dict["text"] # list the phone units and filter the empty item phones_list = list(filter(None, text.split())) diff --git a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py index 200a694d6..10c953e3b 100644 --- a/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py +++ b/egs/wenetspeech/ASR/pruned_transducer_stateless2/asr_datamodule.py @@ -192,13 +192,6 @@ class WenetSpeechAsrDataModule: "with training dataset. ", ) - group.add_argument( - "--lazy-load", - type=str2bool, - default=True, - help="lazily open CutSets to avoid OOM (for L|XL subset)", - ) - group.add_argument( "--training-subset", type=str, @@ -420,17 +413,10 @@ class WenetSpeechAsrDataModule: @lru_cache() def train_cuts(self) -> CutSet: logging.info("About to get train cuts") - if self.args.lazy_load: - logging.info("use lazy cuts") - cuts_train = CutSet.from_jsonl_lazy( - self.args.manifest_dir - / f"cuts_{self.args.training_subset}.jsonl.gz" - ) - else: - cuts_train = CutSet.from_file( - self.args.manifest_dir - / f"cuts_{self.args.training_subset}.jsonl.gz" - ) + cuts_train = load_manifest_lazy( + self.args.manifest_dir + / f"cuts_{self.args.training_subset}.jsonl.gz" + ) return cuts_train @lru_cache()