mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
fix the comments; wrap the classifier for jit script
This commit is contained in:
parent
8b234b371a
commit
a8ca0295b7
@ -26,7 +26,7 @@ repos:
|
|||||||
# E121,E123,E126,E226,E24,E704,W503,W504
|
# E121,E123,E126,E226,E24,E704,W503,W504
|
||||||
|
|
||||||
- repo: https://github.com/pycqa/isort
|
- repo: https://github.com/pycqa/isort
|
||||||
rev: 5.10.1
|
rev: 5.12.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
args: ["--profile=black"]
|
args: ["--profile=black"]
|
||||||
|
@ -32,7 +32,6 @@ dataset, you should change the argument values according to your dataset.
|
|||||||
|
|
||||||
./zipformer/export.py \
|
./zipformer/export.py \
|
||||||
--exp-dir ./zipformer/exp \
|
--exp-dir ./zipformer/exp \
|
||||||
--tokens data/lang_bpe_500/tokens.txt \
|
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 9 \
|
--avg 9 \
|
||||||
--jit 1
|
--jit 1
|
||||||
@ -51,7 +50,6 @@ for how to use the exported models outside of icefall.
|
|||||||
|
|
||||||
./zipformer/export.py \
|
./zipformer/export.py \
|
||||||
--exp-dir ./zipformer/exp \
|
--exp-dir ./zipformer/exp \
|
||||||
--tokens data/lang_bpe_500/tokens.txt \
|
|
||||||
--epoch 30 \
|
--epoch 30 \
|
||||||
--avg 9
|
--avg 9
|
||||||
|
|
||||||
@ -151,13 +149,6 @@ def get_parser():
|
|||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
|
||||||
"--tokens",
|
|
||||||
type=str,
|
|
||||||
default="data/lang_bpe_500/tokens.txt",
|
|
||||||
help="Path to the tokens.txt",
|
|
||||||
)
|
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--jit",
|
"--jit",
|
||||||
type=str2bool,
|
type=str2bool,
|
||||||
@ -200,6 +191,33 @@ class EncoderModel(nn.Module):
|
|||||||
return encoder_out, encoder_out_lens
|
return encoder_out, encoder_out_lens
|
||||||
|
|
||||||
|
|
||||||
|
class Classifier(nn.Module):
|
||||||
|
"""A wrapper for audio tagging classifier"""
|
||||||
|
|
||||||
|
def __init__(self, classifier: nn.Module) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.classifier = classifier
|
||||||
|
|
||||||
|
def forward(self, encoder_out: Tensor, encoder_out_lens: Tensor):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
encoder_out:
|
||||||
|
A 3-D tensor of shape (N, T, C).
|
||||||
|
encoder_out_lens:
|
||||||
|
A 1-D tensor of shape (N,). It contains the number of frames in `x`
|
||||||
|
before padding.
|
||||||
|
"""
|
||||||
|
logits = self.classifier(encoder_out) # (N, T, num_classes)
|
||||||
|
padding_mask = make_pad_mask(encoder_out_lens)
|
||||||
|
logits[padding_mask] = 0
|
||||||
|
logits = logits.sum(dim=1) # mask the padding frames
|
||||||
|
logits = logits / (~padding_mask).sum(dim=1).unsqueeze(-1).expand_as(
|
||||||
|
logits
|
||||||
|
) # normalize the logits
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def main():
|
def main():
|
||||||
args = get_parser().parse_args()
|
args = get_parser().parse_args()
|
||||||
@ -302,6 +320,7 @@ def main():
|
|||||||
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
|
||||||
|
|
||||||
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
|
model.encoder = EncoderModel(model.encoder, model.encoder_embed)
|
||||||
|
model.classifier = Classifier(model.classifier)
|
||||||
filename = "jit_script.pt"
|
filename = "jit_script.pt"
|
||||||
|
|
||||||
logging.info("Using torch.jit.script")
|
logging.info("Using torch.jit.script")
|
||||||
|
Loading…
x
Reference in New Issue
Block a user