format files with isort to meet style guidelines

This commit is contained in:
root 2024-05-02 10:02:02 +09:00
parent d61b73964b
commit 0925a0c300
4 changed files with 54 additions and 30 deletions

View File

@@ -54,36 +54,51 @@ def make_cutset_blueprints(
# Create test dataset # Create test dataset
logging.info("Creating test cuts.") logging.info("Creating test cuts.")
cut_sets.append(("test", CutSet.from_manifests( cut_sets.append(
recordings=RecordingSet.from_file( (
manifest_dir / "reazonspeech_recordings_test.jsonl.gz" "test",
), CutSet.from_manifests(
supervisions=SupervisionSet.from_file( recordings=RecordingSet.from_file(
manifest_dir / "reazonspeech_supervisions_test.jsonl.gz" manifest_dir / "reazonspeech_recordings_test.jsonl.gz"
), ),
))) supervisions=SupervisionSet.from_file(
manifest_dir / "reazonspeech_supervisions_test.jsonl.gz"
),
),
)
)
# Create valid dataset # Create valid dataset
logging.info("Creating valid cuts.") logging.info("Creating valid cuts.")
cut_sets.append(("valid", CutSet.from_manifests( cut_sets.append(
recordings=RecordingSet.from_file( (
manifest_dir / "reazonspeech_recordings_valid.jsonl.gz" "valid",
), CutSet.from_manifests(
supervisions=SupervisionSet.from_file( recordings=RecordingSet.from_file(
manifest_dir / "reazonspeech_supervisions_valid.jsonl.gz" manifest_dir / "reazonspeech_recordings_valid.jsonl.gz"
), ),
))) supervisions=SupervisionSet.from_file(
manifest_dir / "reazonspeech_supervisions_valid.jsonl.gz"
),
),
)
)
# Create train dataset # Create train dataset
logging.info("Creating train cuts.") logging.info("Creating train cuts.")
cut_sets.append(("train", CutSet.from_manifests( cut_sets.append(
recordings=RecordingSet.from_file( (
manifest_dir / "reazonspeech_recordings_train.jsonl.gz" "train",
), CutSet.from_manifests(
supervisions=SupervisionSet.from_file( recordings=RecordingSet.from_file(
manifest_dir / "reazonspeech_supervisions_train.jsonl.gz" manifest_dir / "reazonspeech_recordings_train.jsonl.gz"
), ),
))) supervisions=SupervisionSet.from_file(
manifest_dir / "reazonspeech_supervisions_train.jsonl.gz"
),
),
)
)
return cut_sets return cut_sets

View File

@@ -22,6 +22,7 @@ from pathlib import Path
from lhotse import CutSet from lhotse import CutSet
def get_args(): def get_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter, formatter_class=argparse.ArgumentDefaultsHelpFormatter,

View File

@@ -336,14 +336,20 @@ class ReazonSpeechAsrDataModule:
@lru_cache() @lru_cache()
def train_cuts(self) -> CutSet: def train_cuts(self) -> CutSet:
logging.info("About to get train cuts") logging.info("About to get train cuts")
return load_manifest_lazy(self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz") return load_manifest_lazy(
self.args.manifest_dir / "reazonspeech_cuts_train.jsonl.gz"
)
@lru_cache() @lru_cache()
def valid_cuts(self) -> CutSet: def valid_cuts(self) -> CutSet:
logging.info("About to get valid cuts") logging.info("About to get valid cuts")
return load_manifest_lazy(self.args.manifest_dir / "reazonspeech_cuts_valid.jsonl.gz") return load_manifest_lazy(
self.args.manifest_dir / "reazonspeech_cuts_valid.jsonl.gz"
)
@lru_cache() @lru_cache()
def test_cuts(self) -> List[CutSet]: def test_cuts(self) -> List[CutSet]:
logging.info("About to get test cuts") logging.info("About to get test cuts")
return load_manifest_lazy(self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz") return load_manifest_lazy(
self.args.manifest_dir / "reazonspeech_cuts_test.jsonl.gz"
)

View File

@@ -103,7 +103,6 @@ from pathlib import Path
from typing import Dict, List, Optional, Tuple from typing import Dict, List, Optional, Tuple
import k2 import k2
from tokenizer import Tokenizer
import torch import torch
import torch.nn as nn import torch.nn as nn
from asr_datamodule import ReazonSpeechAsrDataModule from asr_datamodule import ReazonSpeechAsrDataModule
@@ -121,6 +120,7 @@ from beam_search import (
modified_beam_search_lm_shallow_fusion, modified_beam_search_lm_shallow_fusion,
modified_beam_search_LODR, modified_beam_search_LODR,
) )
from tokenizer import Tokenizer
from train import add_model_arguments, get_model, get_params from train import add_model_arguments, get_model, get_params
from icefall import ContextGraph, LmScorer, NgramLm from icefall import ContextGraph, LmScorer, NgramLm
@@ -1039,7 +1039,9 @@ def main():
for subdir in ["valid"]: for subdir in ["valid"]:
results_dict = decode_dataset( results_dict = decode_dataset(
dl=reazonspeech_corpus.test_dataloaders(getattr(reazonspeech_corpus, f"{subdir}_cuts")()), dl=reazonspeech_corpus.test_dataloaders(
getattr(reazonspeech_corpus, f"{subdir}_cuts")()
),
params=params, params=params,
model=model, model=model,
sp=sp, sp=sp,
@@ -1065,7 +1067,7 @@
# if len(tot_err) == 1: # if len(tot_err) == 1:
# fout.write(f"{tot_err[0][1]}") # fout.write(f"{tot_err[0][1]}")
# else: # else:
# fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err) # fout.write("\n".join(f"{k}\t{v}") for k, v in tot_err)
logging.info("Done!") logging.info("Done!")