Fix style

This commit is contained in:
pkufool 2022-01-27 17:11:36 +08:00
parent 88ea4532c0
commit 3b6d416c4f
5 changed files with 19 additions and 9 deletions

View File

@ -143,7 +143,7 @@ def get_params() -> AttributeDict:
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"embedding_dim" : 512,
"embedding_dim": 512,
"env_info": get_env_info(),
}
)

View File

@ -28,7 +28,8 @@ Usage:
It will generate a file exp_dir/pretrained.pt
To use the generated file with `pruned_transducer_stateless/decode.py`, you can do:
To use the generated file with `pruned_transducer_stateless/decode.py`,
you can do:
cd /path/to/exp_dir
ln -s pretrained.pt epoch-9999.pt
@ -48,6 +49,7 @@ from pathlib import Path
import sentencepiece as spm
import torch
import torch.nn as nn
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner

View File

@ -136,11 +136,14 @@ class Transducer(nn.Module):
lm_only_scale=self.lm_scale,
am_only_scale=self.am_scale,
boundary=boundary,
return_grad=True
return_grad=True,
)
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=self.prune_range
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=self.prune_range,
)
am_pruned, lm_pruned = k2.do_rnnt_pruning(
@ -150,7 +153,11 @@ class Transducer(nn.Module):
logits = self.joiner(am_pruned, lm_pruned)
pruned_loss = k2.rnnt_loss_pruned(
joint=logits, symbols=y_padded, ranges=ranges, termination_symbol=blank_id, boundary=boundary
joint=logits,
symbols=y_padded,
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
)
return (-torch.sum(simple_loss), -torch.sum(pruned_loss))

View File

@ -49,6 +49,7 @@ from typing import List
import kaldifeat
import sentencepiece as spm
import torch
import torch.nn as nn
import torchaudio
from beam_search import beam_search, greedy_search
from conformer import Conformer
@ -142,7 +143,7 @@ def get_params() -> AttributeDict:
"num_encoder_layers": 12,
"vgg_frontend": False,
# parameters for decoder
"embedding_dim" : 512,
"embedding_dim": 512,
"env_info": get_env_info(),
}
)

View File

@ -116,7 +116,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="transducer_stateless/exp",
default="pruned_transducer_stateless/exp",
help="""The experiment dir.
It specifies the directory where all training related
files, e.g., checkpoints, log, etc, are saved
@ -271,7 +271,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
return decoder
def get_joiner_model(params: AttributeDict) -> nn.Module :
def get_joiner_model(params: AttributeDict) -> nn.Module:
joiner = Joiner(
input_dim=params.vocab_size,
inner_dim=params.embedding_dim,
@ -280,7 +280,7 @@ def get_joiner_model(params: AttributeDict) -> nn.Module :
return joiner
def get_transducer_model(params: AttributeDict) ->nn.Module:
def get_transducer_model(params: AttributeDict) -> nn.Module:
encoder = get_encoder_model(params)
decoder = get_decoder_model(params)
joiner = get_joiner_model(params)