mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-09 01:52:41 +00:00
Removed unused `k2
` dependencies from the AT recipe (#1633)
This commit is contained in:
parent
0df406c5da
commit
1adf1e441d
@ -373,9 +373,11 @@ class AudioSetATDatamodule:
|
|||||||
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
|
||||||
logging.debug("About to create test dataset")
|
logging.debug("About to create test dataset")
|
||||||
test = AudioTaggingDataset(
|
test = AudioTaggingDataset(
|
||||||
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
input_strategy=(
|
||||||
|
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
|
||||||
if self.args.on_the_fly_feats
|
if self.args.on_the_fly_feats
|
||||||
else eval(self.args.input_strategy)(),
|
else eval(self.args.input_strategy)()
|
||||||
|
),
|
||||||
return_cuts=self.args.return_cuts,
|
return_cuts=self.args.return_cuts,
|
||||||
)
|
)
|
||||||
sampler = DynamicBucketingSampler(
|
sampler = DynamicBucketingSampler(
|
||||||
|
@ -29,27 +29,18 @@ export CUDA_VISIBLE_DEVICES="0"
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import csv
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
|
||||||
import os
|
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, List, Optional, Tuple
|
from typing import Dict
|
||||||
|
|
||||||
import k2
|
|
||||||
import numpy as np
|
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from at_datamodule import AudioSetATDatamodule
|
from at_datamodule import AudioSetATDatamodule
|
||||||
from lhotse import load_manifest
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from sklearn.metrics import average_precision_score
|
from sklearn.metrics import average_precision_score
|
||||||
except Exception as ex:
|
except:
|
||||||
raise RuntimeError(f"{ex}\nPlease run\n" "pip3 install -U scikit-learn")
|
raise ImportError(f"Please run\n" "pip3 install -U scikit-learn")
|
||||||
from train import add_model_arguments, get_model, get_params, str2multihot
|
from train import add_model_arguments, get_model, get_params, str2multihot
|
||||||
|
|
||||||
from icefall.checkpoint import (
|
from icefall.checkpoint import (
|
||||||
@ -58,15 +49,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.lexicon import Lexicon
|
from icefall.utils import AttributeDict, setup_logger, str2bool
|
||||||
from icefall.utils import (
|
|
||||||
AttributeDict,
|
|
||||||
make_pad_mask,
|
|
||||||
setup_logger,
|
|
||||||
store_transcripts,
|
|
||||||
str2bool,
|
|
||||||
write_error_stats,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -36,7 +36,6 @@ import logging
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import k2
|
|
||||||
import onnx
|
import onnx
|
||||||
import onnxoptimizer
|
import onnxoptimizer
|
||||||
import torch
|
import torch
|
||||||
@ -53,7 +52,7 @@ from icefall.checkpoint import (
|
|||||||
find_checkpoints,
|
find_checkpoints,
|
||||||
load_checkpoint,
|
load_checkpoint,
|
||||||
)
|
)
|
||||||
from icefall.utils import make_pad_mask, num_tokens, str2bool
|
from icefall.utils import make_pad_mask, str2bool
|
||||||
|
|
||||||
|
|
||||||
def get_parser():
|
def get_parser():
|
||||||
|
@ -50,7 +50,6 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import k2
|
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import torch
|
import torch
|
||||||
import torchaudio
|
import torchaudio
|
||||||
|
@ -14,17 +14,13 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
from typing import Tuple
|
||||||
import random
|
|
||||||
from typing import List, Optional, Tuple
|
|
||||||
|
|
||||||
import k2
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from encoder_interface import EncoderInterface
|
from encoder_interface import EncoderInterface
|
||||||
|
|
||||||
from icefall.utils import AttributeDict, make_pad_mask
|
from icefall.utils import make_pad_mask
|
||||||
|
|
||||||
|
|
||||||
class AudioTaggingModel(nn.Module):
|
class AudioTaggingModel(nn.Module):
|
||||||
|
@ -42,9 +42,8 @@ import argparse
|
|||||||
import csv
|
import csv
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from typing import List, Tuple
|
from typing import List
|
||||||
|
|
||||||
import k2
|
|
||||||
import kaldifeat
|
import kaldifeat
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
import torch
|
import torch
|
||||||
|
@ -41,7 +41,6 @@ from shutil import copyfile
|
|||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import optim
|
import optim
|
||||||
import sentencepiece as spm
|
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@ -632,7 +631,7 @@ def compute_loss(
|
|||||||
model:
|
model:
|
||||||
The model for training. It is an instance of Zipformer in our case.
|
The model for training. It is an instance of Zipformer in our case.
|
||||||
batch:
|
batch:
|
||||||
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
A batch of data. See `lhotse.dataset.AudioTaggingDataset()`
|
||||||
for the content in it.
|
for the content in it.
|
||||||
is_training:
|
is_training:
|
||||||
True for training. False for validation. When it is True, this
|
True for training. False for validation. When it is True, this
|
||||||
@ -1108,7 +1107,7 @@ def display_and_save_batch(
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
batch:
|
batch:
|
||||||
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
|
A batch of data. See `lhotse.dataset.AudioTaggingDataset()`
|
||||||
for the content in it.
|
for the content in it.
|
||||||
params:
|
params:
|
||||||
Parameters for training. See :func:`get_params`.
|
Parameters for training. See :func:`get_params`.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user