minor fix

marcoyang 2023-12-19 17:05:50 +08:00
parent a1aca34e24
commit bf58b63a6a
3 changed files with 1352 additions and 4 deletions

View File

@@ -17,18 +17,18 @@
 import argparse
 import inspect
 import logging
+import pickle
 from functools import lru_cache
 from pathlib import Path
-import pickle
 from typing import Any, Dict, Optional

 import torch
 from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy
 from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
+    AudioTaggingDataset,
     CutConcatenate,
     CutMix,
     DynamicBucketingSampler,
-    AudioTaggingDataset,
     PrecomputedFeatures,
     SimpleCutSampler,
     SpecAugment,
@@ -42,6 +42,7 @@ from torch.utils.data import DataLoader
 from icefall.utils import str2bool


 class _SeedWorkers:
     def __init__(self, seed: int):
         self.seed = seed
@@ -53,7 +54,7 @@ class _SeedWorkers:
 class AudioSetATDatamodule:
     """
     DataModule for k2 audio tagging (AT) experiments.
     It contains all the common data pipeline modules used in ASR
     experiments, e.g.:
@@ -65,6 +66,7 @@ class AudioSetATDatamodule:
     This class should be derived for specific corpora used in ASR tasks.
     """

     def __init__(self, args: argparse.Namespace):
         self.args = args
@@ -82,7 +84,7 @@ class AudioSetATDatamodule:
             "--audioset-subset",
             type=str,
             default="balanced",
-            choices=["balanced", "full"]
+            choices=["balanced", "full"],
         )
         group.add_argument(
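
The `--audioset-subset` change above is purely stylistic (a trailing comma for black-style formatting); `choices` already restricts the accepted values. A minimal self-contained sketch of how argparse enforces this (the parser and group names here are illustrative, not the recipe's actual code):

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="AudioSet options")
group.add_argument(
    "--audioset-subset",
    type=str,
    default="balanced",
    choices=["balanced", "full"],
)

args = parser.parse_args(["--audioset-subset", "full"])
print(args.audioset_subset)  # "full"
# parser.parse_args(["--audioset-subset", "eval"]) would exit with an error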

View File

@@ -0,0 +1,157 @@
# Copyright 2021-2023 Xiaomi Corp. (authors: Xiaoyu Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
from typing import List, Optional, Tuple

import k2
import torch
import torch.nn as nn
import torch.nn.functional as F
from encoder_interface import EncoderInterface

from icefall.utils import AttributeDict, make_pad_mask


class AudioTaggingModel(nn.Module):
    def __init__(
        self,
        encoder_embed: nn.Module,
        encoder: EncoderInterface,
        encoder_dim: int = 384,
        num_events: int = 527,
    ):
        """An audio tagging model

        Args:
          encoder_embed:
            It is a Convolutional 2D subsampling module. It converts
            an input of shape (N, T, idim) to an output of shape
            (N, T', odim), where T' = (T-3)//2 - 2 = (T-7)//2.
          encoder:
            It is the encoder of the model. It accepts two inputs:
            `x` of shape (N, T, encoder_dim) and `x_lens` of shape (N,).
            It returns two tensors: `encoder_out` of shape
            (N, T, encoder_dim) and `encoder_out_lens` of shape (N,).
          encoder_dim:
            Dimension of the encoder.
          num_events:
            The number of audio event classes.
        """
        super().__init__()

        assert isinstance(encoder, EncoderInterface), type(encoder)

        self.encoder_embed = encoder_embed
        self.encoder = encoder
        self.encoder_dim = encoder_dim

        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(encoder_dim, num_events),
        )

        # for multi-label classification
        self.criterion = torch.nn.BCEWithLogitsLoss(reduction="sum")

    def forward_encoder(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
        """
        x, x_lens = self.encoder_embed(x, x_lens)

        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)
        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

        return encoder_out, encoder_out_lens

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          target:
            The ground-truth labels of the audio events, a multi-hot tensor
            of shape (N, num_events).

        Returns:
          The binary cross-entropy loss.
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

        # Forward the audio tagging module
        logits = self.forward_audio_tagging(
            encoder_out=encoder_out, encoder_out_lens=encoder_out_lens
        )  # (N, num_events)

        loss = self.criterion(logits, target)

        return loss

    def forward_audio_tagging(
        self, encoder_out: torch.Tensor, encoder_out_lens: torch.Tensor
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            A 3-D tensor of shape (N, T, C).
          encoder_out_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in
            `encoder_out` before padding.

        Returns:
          A 2-D tensor of shape (N, num_events).
        """
        logits = self.classifier(encoder_out)  # (N, T, num_events)

        padding_mask = make_pad_mask(encoder_out_lens)
        logits[padding_mask] = 0  # zero out the padding frames
        logits = logits.sum(dim=1)  # sum over the time axis
        logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
            logits
        )  # average over the non-padded frames

        return logits
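
forward_audio_tagging implements masked average pooling: padded frames are zeroed, the frame-level logits are summed over time, and the sum is divided by the number of valid frames. A minimal self-contained sketch checking the vectorized form against a per-example loop (the local make_pad_mask below is a simplified stand-in for icefall's, assumed only for this example):

import torch

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # True at padded positions, False at valid frames
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

N, T, num_events = 2, 5, 4
frame_logits = torch.randn(N, T, num_events)  # stand-in for classifier output
lens = torch.tensor([5, 3])  # the second clip has 2 padded frames

# Vectorized masked mean, mirroring forward_audio_tagging
padding_mask = make_pad_mask(lens)  # (N, T)
pooled = frame_logits.masked_fill(padding_mask.unsqueeze(-1), 0.0).sum(dim=1)
pooled = pooled / (~padding_mask).sum(dim=1, keepdim=True)  # (N, num_events)

# Reference: average each clip over its valid frames only
ref = torch.stack([frame_logits[i, : lens[i]].mean(dim=0) for i in range(N)])
assert torch.allclose(pooled, ref, atol=1e-6)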

File diff suppressed because it is too large.
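
For reference on the loss the model above computes: each clip's target is multi-hot over the 527 AudioSet classes, and BCEWithLogitsLoss treats every class as an independent binary decision. A minimal sketch, with hypothetical pooled logits standing in for the output of forward_audio_tagging:

import torch

num_events = 527  # number of AudioSet classes
logits = torch.randn(2, num_events)  # hypothetical pooled logits

# Multi-hot targets: a clip can contain several events at once
target = torch.zeros(2, num_events)
target[0, [0, 5]] = 1.0  # clip 0: events 0 and 5
target[1, [100]] = 1.0   # clip 1: event 100 only

criterion = torch.nn.BCEWithLogitsLoss(reduction="sum")
loss = criterion(logits, target)  # scalar, summed over clips and classes
print(loss.item())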