fix the comments; wrap the classifier for jit script

This commit is contained in:
marcoyang 2024-03-29 17:07:24 +08:00
parent 8b234b371a
commit a8ca0295b7
2 changed files with 29 additions and 10 deletions

View File

@ -26,7 +26,7 @@ repos:
# E121,E123,E126,E226,E24,E704,W503,W504
- repo: https://github.com/pycqa/isort
rev: 5.10.1
rev: 5.12.0
hooks:
- id: isort
args: ["--profile=black"]

View File

@ -32,7 +32,6 @@ dataset, you should change the argument values according to your dataset.
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9 \
--jit 1
@ -51,7 +50,6 @@ for how to use the exported models outside of icefall.
./zipformer/export.py \
--exp-dir ./zipformer/exp \
--tokens data/lang_bpe_500/tokens.txt \
--epoch 30 \
--avg 9
@ -151,13 +149,6 @@ def get_parser():
""",
)
parser.add_argument(
"--tokens",
type=str,
default="data/lang_bpe_500/tokens.txt",
help="Path to the tokens.txt",
)
parser.add_argument(
"--jit",
type=str2bool,
@ -200,6 +191,33 @@ class EncoderModel(nn.Module):
return encoder_out, encoder_out_lens
class Classifier(nn.Module):
"""A wrapper for audio tagging classifier"""
def __init__(self, classifier: nn.Module) -> None:
super().__init__()
self.classifier = classifier
def forward(self, encoder_out: Tensor, encoder_out_lens: Tensor):
"""
Args:
encoder_out:
A 3-D tensor of shape (N, T, C).
encoder_out_lens:
A 1-D tensor of shape (N,). It contains the number of frames in `x`
before padding.
"""
logits = self.classifier(encoder_out) # (N, T, num_classes)
padding_mask = make_pad_mask(encoder_out_lens)
logits[padding_mask] = 0
logits = logits.sum(dim=1) # mask the padding frames
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
logits
) # normalize the logits
return logits
@torch.no_grad()
def main():
args = get_parser().parse_args()
@ -302,6 +320,7 @@ def main():
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
model.classifier = Classifier(model.classifier)
filename = "jit_script.pt"
logging.info("Using torch.jit.script")