k2-fsa/icefall (mirror of https://github.com/k2-fsa/icefall.git)

Commit a8ca0295b7 (parent 8b234b371a)

fix the comments; wrap the classifier for jit script
.pre-commit-config.yaml

@@ -26,7 +26,7 @@ repos:
         # E121,E123,E126,E226,E24,E704,W503,W504

   - repo: https://github.com/pycqa/isort
-    rev: 5.10.1
+    rev: 5.12.0
     hooks:
       - id: isort
         args: ["--profile=black"]
zipformer/export.py

@@ -32,7 +32,6 @@ dataset, you should change the argument values according to your dataset.

 ./zipformer/export.py \
   --exp-dir ./zipformer/exp \
-  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9 \
   --jit 1

@@ -51,7 +50,6 @@ for how to use the exported models outside of icefall.

 ./zipformer/export.py \
   --exp-dir ./zipformer/exp \
-  --tokens data/lang_bpe_500/tokens.txt \
   --epoch 30 \
   --avg 9
@@ -151,13 +149,6 @@ def get_parser():
         """,
     )

-    parser.add_argument(
-        "--tokens",
-        type=str,
-        default="data/lang_bpe_500/tokens.txt",
-        help="Path to the tokens.txt",
-    )
-
     parser.add_argument(
         "--jit",
         type=str2bool,
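For context, the --jit flag is parsed with a str2bool helper from icefall.utils. A minimal sketch of such a helper (an illustrative re-implementation, not necessarily the repo's exact code):

import argparse

def str2bool(v: str) -> bool:
    # Accept the usual spellings of true/false so that `--jit 1`,
    # `--jit true`, etc. all work on the command line.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("Boolean value expected.")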
@@ -200,6 +191,33 @@ class EncoderModel(nn.Module):
         return encoder_out, encoder_out_lens


+class Classifier(nn.Module):
+    """A wrapper for the audio tagging classifier."""
+
+    def __init__(self, classifier: nn.Module) -> None:
+        super().__init__()
+        self.classifier = classifier
+
+    def forward(self, encoder_out: Tensor, encoder_out_lens: Tensor):
+        """
+        Args:
+          encoder_out:
+            A 3-D tensor of shape (N, T, C).
+          encoder_out_lens:
+            A 1-D tensor of shape (N,). It contains the number of frames in
+            `encoder_out` before padding.
+        """
+        logits = self.classifier(encoder_out)  # (N, T, num_classes)
+        padding_mask = make_pad_mask(encoder_out_lens)
+        logits[padding_mask] = 0  # mask the padding frames
+        logits = logits.sum(dim=1)  # sum over the frame axis
+        logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
+            logits
+        )  # average over the non-padded frames
+
+        return logits
+
+
 @torch.no_grad()
 def main():
     args = get_parser().parse_args()
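The new Classifier wrapper averages the per-frame logits while ignoring padding. A self-contained sketch of that computation (make_pad_mask is re-implemented here for illustration; icefall provides it in icefall.utils, returning True at padding positions):

import torch

def make_pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    # (N,) -> (N, T); True where t >= lengths[n], i.e. at padded frames.
    max_len = int(lengths.max())
    steps = torch.arange(max_len, device=lengths.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)

N, T, num_classes = 2, 5, 3
logits = torch.randn(N, T, num_classes)
lens = torch.tensor([5, 3])

mask = make_pad_mask(lens)            # (N, T)
logits[mask] = 0                      # zero out the padded frames
pooled = logits.sum(dim=1)            # (N, num_classes)
pooled = pooled / lens.unsqueeze(-1)  # divide by the true frame counts
print(pooled.shape)                   # torch.Size([2, 3])

Dividing the masked sum by the number of non-padded frames is equivalent to a per-utterance mean over real frames only, so short recordings in a batch are not diluted by their padding.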
@@ -302,6 +320,7 @@ def main():
         model.__class__.forward = torch.jit.ignore(model.__class__.forward)

         model.encoder = EncoderModel(model.encoder, model.encoder_embed)
+        model.classifier = Classifier(model.classifier)
         filename = "jit_script.pt"

         logging.info("Using torch.jit.script")
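Once exported with --jit 1, the scripted file can be loaded without any icefall code. A hedged sketch of how the wrapped submodules might be called (the feature shapes below are assumptions, not taken from the recipe):

import torch

model = torch.jit.load("jit_script.pt")
model.eval()

# Assumed 80-dim fbank features; check the recipe for the real pipeline.
features = torch.randn(1, 100, 80)   # (N, T, C)
feature_lens = torch.tensor([100])   # (N,)

encoder_out, encoder_out_lens = model.encoder(features, feature_lens)
logits = model.classifier(encoder_out, encoder_out_lens)  # (N, num_classes)

Wrapping both the encoder and the classifier in plain nn.Module shells is what makes this possible: torch.jit.script only has to compile the simple wrapper forward methods instead of the full training-time forward.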