mirror of https://github.com/k2-fsa/icefall.git (synced 2025-08-26 18:24:18 +00:00)

commit 3b6d416c4f
parent 88ea4532c0

Fix style

@@ -143,7 +143,7 @@ def get_params() -> AttributeDict:
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             # parameters for decoder
-            "embedding_dim" : 512,
+            "embedding_dim": 512,
             "env_info": get_env_info(),
         }
     )

@@ -28,7 +28,8 @@ Usage:
 
 It will generate a file exp_dir/pretrained.pt
 
-To use the generated file with `pruned_transducer_stateless/decode.py`, you can do:
+To use the generated file with `pruned_transducer_stateless/decode.py`,
+you can do:
 
     cd /path/to/exp_dir
     ln -s pretrained.pt epoch-9999.pt

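The symlink above is what lets the decode script treat pretrained.pt as an
ordinary training checkpoint. As a hedged usage sketch, assuming icefall's
usual decode-script flags (--epoch/--avg) and recipe layout, neither of which
is shown in this diff:

    cd /path/to/exp_dir
    ln -s pretrained.pt epoch-9999.pt
    cd /path/to/icefall/egs/librispeech/ASR   # recipe dir; an assumption
    ./pruned_transducer_stateless/decode.py \
        --exp-dir /path/to/exp_dir \
        --epoch 9999 \
        --avg 1

With --epoch 9999 --avg 1, the checkpoint loader resolves epoch-9999.pt and
therefore picks up exactly the exported pretrained.pt.
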
@@ -48,6 +49,7 @@ from pathlib import Path
 
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 from conformer import Conformer
 from decoder import Decoder
 from joiner import Joiner

@@ -136,11 +136,14 @@ class Transducer(nn.Module):
             lm_only_scale=self.lm_scale,
             am_only_scale=self.am_scale,
             boundary=boundary,
-            return_grad=True
+            return_grad=True,
         )
 
         ranges = k2.get_rnnt_prune_ranges(
-            px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=self.prune_range
+            px_grad=px_grad,
+            py_grad=py_grad,
+            boundary=boundary,
+            s_range=self.prune_range,
         )
 
         am_pruned, lm_pruned = k2.do_rnnt_pruning(

@@ -150,7 +153,11 @@ class Transducer(nn.Module):
         logits = self.joiner(am_pruned, lm_pruned)
 
         pruned_loss = k2.rnnt_loss_pruned(
-            joint=logits, symbols=y_padded, ranges=ranges, termination_symbol=blank_id, boundary=boundary
+            joint=logits,
+            symbols=y_padded,
+            ranges=ranges,
+            termination_symbol=blank_id,
+            boundary=boundary,
         )
 
         return (-torch.sum(simple_loss), -torch.sum(pruned_loss))

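The two hunks above only re-wrap arguments, but for readers without the
surrounding file, here is a minimal self-contained sketch of the pruned RNN-T
loss pipeline those calls belong to. It is a sketch under assumptions: the
first (unshown) call is taken to be k2.rnnt_loss_smoothed, keyword names follow
the k2 version this commit targets (newer k2 releases spell rnnt_loss_pruned's
first argument logits rather than joint), and toy random tensors plus a plain
sum stand in for real model outputs and the Joiner module.

    import k2
    import torch

    B, T, S, C = 2, 10, 5, 100          # batch, frames, symbols, vocab (toy sizes)
    blank_id, prune_range = 0, 5
    am = torch.randn(B, T, C)           # encoder ("acoustic") projection
    lm = torch.randn(B, S + 1, C)       # decoder ("language") projection
    y_padded = torch.randint(1, C, (B, S))
    # boundary rows are [begin_symbol, begin_frame, end_symbol, end_frame]
    boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

    # Full-lattice smoothed loss; return_grad=True additionally returns the
    # occupation gradients (px_grad, py_grad) that drive the pruning.
    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=lm,
        am=am,
        symbols=y_padded,
        termination_symbol=blank_id,
        lm_only_scale=0.25,
        am_only_scale=0.0,
        boundary=boundary,
        return_grad=True,
    )

    # Pick, for every frame, a window of prune_range symbol positions.
    ranges = k2.get_rnnt_prune_ranges(
        px_grad=px_grad,
        py_grad=py_grad,
        boundary=boundary,
        s_range=prune_range,
    )

    # Gather pruned (B, T, prune_range, C) activations; the real model feeds
    # these through its Joiner instead of the plain sum used here.
    am_pruned, lm_pruned = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
    logits = am_pruned + lm_pruned

    pruned_loss = k2.rnnt_loss_pruned(
        joint=logits,
        symbols=y_padded,
        ranges=ranges,
        termination_symbol=blank_id,
        boundary=boundary,
    )
    print(-torch.sum(simple_loss), -torch.sum(pruned_loss))

The pruning step is the point of the recipe: instead of evaluating the joiner
over the full T x (S+1) lattice, only prune_range symbol positions per frame
survive, which is what keeps the memory cost of the joint network manageable.
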
@@ -49,6 +49,7 @@ from typing import List
 import kaldifeat
 import sentencepiece as spm
 import torch
+import torch.nn as nn
 import torchaudio
 from beam_search import beam_search, greedy_search
 from conformer import Conformer

@@ -142,7 +143,7 @@ def get_params() -> AttributeDict:
             "num_encoder_layers": 12,
             "vgg_frontend": False,
             # parameters for decoder
-            "embedding_dim" : 512,
+            "embedding_dim": 512,
             "env_info": get_env_info(),
         }
     )

@@ -116,7 +116,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="transducer_stateless/exp",
+        default="pruned_transducer_stateless/exp",
         help="""The experiment dir.
         It specifies the directory where all training related
         files, e.g., checkpoints, log, etc, are saved

@@ -271,7 +271,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
     return decoder
 
 
-def get_joiner_model(params: AttributeDict) -> nn.Module :
+def get_joiner_model(params: AttributeDict) -> nn.Module:
     joiner = Joiner(
         input_dim=params.vocab_size,
         inner_dim=params.embedding_dim,

@@ -280,7 +280,7 @@ def get_joiner_model(params: AttributeDict) -> nn.Module :
     return joiner
 
 
-def get_transducer_model(params: AttributeDict) ->nn.Module:
+def get_transducer_model(params: AttributeDict) -> nn.Module:
     encoder = get_encoder_model(params)
     decoder = get_decoder_model(params)
     joiner = get_joiner_model(params)

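For orientation, the function whose signature the last hunk fixes typically
just wires the three sub-models together. A hedged sketch of that assembly;
the Transducer keyword names are assumptions based on the class touched
earlier in this diff, and the hunk above does not show the return statement:

    def get_transducer_model(params: AttributeDict) -> nn.Module:
        encoder = get_encoder_model(params)   # Conformer acoustic encoder
        decoder = get_decoder_model(params)   # stateless label-history decoder
        joiner = get_joiner_model(params)     # combines encoder/decoder outputs
        return Transducer(encoder=encoder, decoder=decoder, joiner=joiner)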