Fix style

pkufool 2022-01-27 17:11:36 +08:00
parent 88ea4532c0
commit 3b6d416c4f
5 changed files with 19 additions and 9 deletions


@@ -143,7 +143,7 @@ def get_params() -> AttributeDict:
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             # parameters for decoder
-            "embedding_dim" : 512,
+            "embedding_dim": 512,
             "env_info": get_env_info(),
         }
     )
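For context, get_params() returns an AttributeDict, i.e. a dict whose entries can also be read as attributes, so the entry fixed above is later consumed as params.embedding_dim. A minimal sketch of that pattern (not the repository's exact class, which lives in icefall/utils.py):

# Minimal AttributeDict sketch: a dict that also supports attribute access.
class AttributeDict(dict):
    def __getattr__(self, key):
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key)

    def __setattr__(self, key, value):
        self[key] = value


params = AttributeDict({"embedding_dim": 512, "vgg_frontend": False})
assert params.embedding_dim == params["embedding_dim"] == 512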


@@ -28,7 +28,8 @@ Usage:
 
 It will generate a file exp_dir/pretrained.pt
 
-To use the generated file with `pruned_transducer_stateless/decode.py`, you can do:
+To use the generated file with `pruned_transducer_stateless/decode.py`,
+you can do:
 
 cd /path/to/exp_dir
 ln -s pretrained.pt epoch-9999.pt
@@ -48,6 +49,7 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner
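The usage text above relies on decode.py locating checkpoints by epoch number, which is why the exported pretrained.pt is symlinked as epoch-9999.pt and then selected with something like --epoch 9999 --avg 1. A rough sketch of that loading convention (the path and the "model" key are assumptions for illustration, not a fixed API):

# Sketch: decode.py-style scripts resolve checkpoints as exp_dir/epoch-<N>.pt,
# so the symlink makes the exported model look like epoch 9999.
from pathlib import Path

import torch

exp_dir = Path("pruned_transducer_stateless/exp")  # assumed exp dir
epoch = 9999  # matches the symlink created in the usage example
checkpoint = torch.load(exp_dir / f"epoch-{epoch}.pt", map_location="cpu")
state_dict = checkpoint["model"]  # assumption: weights stored under "model"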


@@ -136,11 +136,14 @@ class Transducer(nn.Module):
             lm_only_scale=self.lm_scale,
             am_only_scale=self.am_scale,
             boundary=boundary,
-            return_grad=True
+            return_grad=True,
         )
 
         ranges = k2.get_rnnt_prune_ranges(
-            px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=self.prune_range
+            px_grad=px_grad,
+            py_grad=py_grad,
+            boundary=boundary,
+            s_range=self.prune_range,
         )
 
         am_pruned, lm_pruned = k2.do_rnnt_pruning(
@@ -150,7 +153,11 @@
         logits = self.joiner(am_pruned, lm_pruned)
 
         pruned_loss = k2.rnnt_loss_pruned(
-            joint=logits, symbols=y_padded, ranges=ranges, termination_symbol=blank_id, boundary=boundary
+            joint=logits,
+            symbols=y_padded,
+            ranges=ranges,
+            termination_symbol=blank_id,
+            boundary=boundary,
+        )
 
         return (-torch.sum(simple_loss), -torch.sum(pruned_loss))
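The two hunks above only reformat the k2 calls in Transducer.forward; behavior is unchanged. For orientation, here is a condensed, hedged sketch of how the pruned RNN-T loss pipeline fits together. Only the four k2 functions and the keyword arguments visible in the diff are taken from the source; the remaining argument names, the am/lm inputs (per-frame and per-token logits over the vocabulary), and the scale handling are assumptions based on k2's documented API:

# Condensed sketch of the pruned RNN-T loss flow used in Transducer.forward.
import k2
import torch


def pruned_rnnt_loss(am, lm, y_padded, boundary, blank_id, prune_range, joiner):
    # 1) Full "simple" loss; return_grad=True also returns occupation
    #    gradients that indicate which lattice region matters.
    #    (lm_only_scale/am_only_scale omitted here; the real code passes
    #    self.lm_scale / self.am_scale.)
    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=lm,
        am=am,
        symbols=y_padded,
        termination_symbol=blank_id,
        boundary=boundary,
        return_grad=True,
    )

    # 2) From those gradients, keep only `prune_range` symbol positions
    #    per frame.
    ranges = k2.get_rnnt_prune_ranges(
        px_grad=px_grad,
        py_grad=py_grad,
        boundary=boundary,
        s_range=prune_range,
    )

    # 3) Gather the pruned am/lm slices and run the joiner only on them.
    am_pruned, lm_pruned = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
    logits = joiner(am_pruned, lm_pruned)

    # 4) RNN-T loss restricted to the pruned region.
    pruned_loss = k2.rnnt_loss_pruned(
        joint=logits,
        symbols=y_padded,
        ranges=ranges,
        termination_symbol=blank_id,
        boundary=boundary,
    )
    return -torch.sum(simple_loss), -torch.sum(pruned_loss)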


@@ -49,6 +49,7 @@ from typing import List
 import kaldifeat
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 import torchaudio
 from beam_search import beam_search, greedy_search
 from conformer import Conformer
@@ -142,7 +143,7 @@ def get_params() -> AttributeDict:
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             # parameters for decoder
-            "embedding_dim" : 512,
+            "embedding_dim": 512,
             "env_info": get_env_info(),
         }
     )


@@ -116,7 +116,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/exp",
+        default="pruned_transducer_stateless/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved
@@ -271,7 +271,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
     return decoder
 
 
-def get_joiner_model(params: AttributeDict) -> nn.Module :
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.vocab_size,
         inner_dim=params.embedding_dim,
@@ -280,7 +280,7 @@ def get_joiner_model(params: AttributeDict) -> nn.Module :
     return joiner
 
 
-def get_transducer_model(params: AttributeDict) ->nn.Module:
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)
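In the hunks above, get_joiner_model() builds a joiner whose inputs are vocabulary-sized activations passed through an inner layer of size embedding_dim, and get_transducer_model() merely wires encoder, decoder and joiner together. A toy, self-contained illustration of that joiner shape flow (an assumption about the design; the recipe's actual implementation is in joiner.py, and output_dim is not shown in this diff):

# Toy joiner mirroring the input_dim/inner_dim keywords above; the real
# Joiner in joiner.py may differ in details.
import torch
import torch.nn as nn


class ToyJoiner(nn.Module):
    def __init__(self, input_dim: int, inner_dim: int, output_dim: int):
        super().__init__()
        self.inner_linear = nn.Linear(input_dim, inner_dim)
        self.output_linear = nn.Linear(inner_dim, output_dim)

    def forward(self, am: torch.Tensor, lm: torch.Tensor) -> torch.Tensor:
        # Combine acoustic (am) and label (lm) streams, then project back
        # to the vocabulary size.
        logit = torch.tanh(self.inner_linear(am + lm))
        return self.output_linear(logit)


vocab_size, embedding_dim = 500, 512
joiner = ToyJoiner(vocab_size, embedding_dim, vocab_size)
am = torch.randn(2, 10, 5, vocab_size)  # (batch, frames, s_range, vocab)
lm = torch.randn(2, 10, 5, vocab_size)
print(joiner(am, lm).shape)  # torch.Size([2, 10, 5, 500])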