From a8ca0295b72b2b02a5ec95494317bdbf9a4574c0 Mon Sep 17 00:00:00 2001
From: marcoyang
Date: Fri, 29 Mar 2024 17:07:24 +0800
Subject: [PATCH] fix the comments; wrap the classifier for jit script

---
 .pre-commit-config.yaml             |  2 +-
 egs/audioset/AT/zipformer/export.py | 37 ++++++++++++++++++++++-------
 2 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 5cb213327..70068f9cf 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -26,7 +26,7 @@ repos:
         # E121,E123,E126,E226,E24,E704,W503,W504
 
   - repo: https://github.com/pycqa/isort
-    rev: 5.10.1
+    rev: 5.12.0
     hooks:
       - id: isort
         args: ["--profile=black"]
diff --git a/egs/audioset/AT/zipformer/export.py b/egs/audioset/AT/zipformer/export.py
index 1c0e702d6..61e2f9ab7 100755
--- a/egs/audioset/AT/zipformer/export.py
+++ b/egs/audioset/AT/zipformer/export.py
@@ -32,7 +32,6 @@ dataset, you should change the argument values according to your dataset.
 
 ./zipformer/export.py \
   --exp-dir ./zipformer/exp \
-  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9 \
   --jit 1
@@ -51,7 +50,6 @@ for how to use the exported models outside of icefall.
 
 ./zipformer/export.py \
   --exp-dir ./zipformer/exp \
-  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9
 
@@ -151,13 +149,6 @@ def get_parser():
         """,
     )
 
-    parser.add_argument(
-        "--tokens",
-        type=str,
-        default="data/lang_bpe_500/tokens.txt",
-        help="Path to the tokens.txt",
-    )
-
     parser.add_argument(
         "--jit",
         type=str2bool,
@@ -200,6 +191,33 @@ class EncoderModel(nn.Module):
         return encoder_out, encoder_out_lens
 
 
+class Classifier(nn.Module):
+    """A wrapper for the audio tagging classifier"""
+
+    def __init__(self, classifier: nn.Module) -> None:
+        super().__init__()
+        self.classifier = classifier
+
+    def forward(self, encoder_out: Tensor, encoder_out_lens: Tensor):
+        """
+        Args:
+          encoder_out:
+            A 3-D tensor of shape (N, T, C).
+          encoder_out_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in
+            `encoder_out` before padding.
+        """
+        logits = self.classifier(encoder_out)  # (N, T, num_classes)
+        padding_mask = make_pad_mask(encoder_out_lens)
+        logits[padding_mask] = 0  # mask the padding frames
+        logits = logits.sum(dim=1)  # sum over the frame axis
+        logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
+            logits
+        )  # average over the non-padded frames
+
+        return logits
+
+
 @torch.no_grad()
 def main():
     args = get_parser().parse_args()
@@ -302,6 +320,7 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)
 
         model.encoder = EncoderModel(model.encoder, model.encoder_embed)
+        model.classifier = Classifier(model.classifier)
 
         filename = "jit_script.pt"
         logging.info("Using torch.jit.script")
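
The masked averaging that the new Classifier wrapper performs can be checked in isolation. The sketch below is illustrative only and not part of the commit: it substitutes a local pad-mask helper for icefall's make_pad_mask and a randomly initialized nn.Linear for the real audio tagging head; the 384-dim encoder output, the tensor sizes, and the helper name are assumptions made for the example.

    import torch
    import torch.nn as nn

    def make_pad_mask_local(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
        # True at padded positions (same spirit as icefall's make_pad_mask).
        steps = torch.arange(max_len, device=lengths.device)
        return steps.unsqueeze(0) >= lengths.unsqueeze(1)

    # Stand-in for the real audio tagging head; AudioSet has 527 event classes.
    classifier = nn.Linear(384, 527)

    N, T, C = 2, 10, 384                      # batch, frames, encoder dim (illustrative)
    encoder_out = torch.randn(N, T, C)
    encoder_out_lens = torch.tensor([10, 6])  # second item has 4 padded frames

    logits = classifier(encoder_out)          # (N, T, 527)
    padding_mask = make_pad_mask_local(encoder_out_lens, T)
    logits[padding_mask] = 0                  # zero out padded frames
    logits = logits.sum(dim=1)                # (N, 527)
    logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1)  # mean over valid frames

    print(logits.shape)  # torch.Size([2, 527])

Dividing by the per-utterance count of valid frames, rather than by T, keeps utterances of different lengths comparable, which is why the wrapper zeroes the padded frames before pooling.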