Removed unused `k2` dependencies from the AT recipe (#1633)

This commit is contained in:
zr_jin 2024-05-21 18:22:19 +08:00 committed by GitHub
parent 0df406c5da
commit 1adf1e441d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 15 additions and 38 deletions

View File

@ -373,9 +373,11 @@ class AudioSetATDatamodule:
def test_dataloaders(self, cuts: CutSet) -> DataLoader:
logging.debug("About to create test dataset")
test = AudioTaggingDataset(
input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)(),
input_strategy=(
OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
if self.args.on_the_fly_feats
else eval(self.args.input_strategy)()
),
return_cuts=self.args.return_cuts,
)
sampler = DynamicBucketingSampler(

View File

@ -29,27 +29,18 @@ export CUDA_VISIBLE_DEVICES="0"
"""
import argparse
import csv
import logging
import math
import os
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from typing import Dict
import k2
import numpy as np
import sentencepiece as spm
import torch
import torch.nn as nn
import torch.nn.functional as F
from at_datamodule import AudioSetATDatamodule
from lhotse import load_manifest
try:
from sklearn.metrics import average_precision_score
except Exception as ex:
raise RuntimeError(f"{ex}\nPlease run\n" "pip3 install -U scikit-learn")
except:
raise ImportError(f"Please run\n" "pip3 install -U scikit-learn")
from train import add_model_arguments, get_model, get_params, str2multihot
from icefall.checkpoint import (
@ -58,15 +49,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.lexicon import Lexicon
from icefall.utils import (
AttributeDict,
make_pad_mask,
setup_logger,
store_transcripts,
str2bool,
write_error_stats,
)
from icefall.utils import AttributeDict, setup_logger, str2bool
def get_parser():

View File

@ -36,7 +36,6 @@ import logging
from pathlib import Path
from typing import Dict
import k2
import onnx
import onnxoptimizer
import torch
@ -53,7 +52,7 @@ from icefall.checkpoint import (
find_checkpoints,
load_checkpoint,
)
from icefall.utils import make_pad_mask, num_tokens, str2bool
from icefall.utils import make_pad_mask, str2bool
def get_parser():

View File

@ -50,7 +50,6 @@ import logging
import math
from typing import List
import k2
import kaldifeat
import torch
import torchaudio

View File

@ -14,17 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
from typing import List, Optional, Tuple
from typing import Tuple
import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from encoder_interface import EncoderInterface
from icefall.utils import AttributeDict, make_pad_mask
from icefall.utils import make_pad_mask
class AudioTaggingModel(nn.Module):

View File

@ -42,9 +42,8 @@ import argparse
import csv
import logging
import math
from typing import List, Tuple
from typing import List
import k2
import kaldifeat
import onnxruntime as ort
import torch

View File

@ -41,7 +41,6 @@ from shutil import copyfile
from typing import Any, Dict, List, Optional, Tuple, Union
import optim
import sentencepiece as spm
import torch
import torch.multiprocessing as mp
import torch.nn as nn
@ -632,7 +631,7 @@ def compute_loss(
model:
The model for training. It is an instance of Zipformer in our case.
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
A batch of data. See `lhotse.dataset.AudioTaggingDataset()`
for the content in it.
is_training:
True for training. False for validation. When it is True, this
@ -1108,7 +1107,7 @@ def display_and_save_batch(
Args:
batch:
A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()`
A batch of data. See `lhotse.dataset.AudioTaggingDataset()`
for the content in it.
params:
Parameters for training. See :func:`get_params`.