mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-08-26 18:24:18 +00:00
Fix style
This commit is contained in:
parent
88ea4532c0
commit
3b6d416c4f
@ -143,7 +143,7 @@ def get_params() -> AttributeDict:
|
|||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
# parameters for decoder
|
# parameters for decoder
|
||||||
"embedding_dim" : 512,
|
"embedding_dim": 512,
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -28,7 +28,8 @@ Usage:
|
|||||||
|
|
||||||
It will generate a file exp_dir/pretrained.pt
|
It will generate a file exp_dir/pretrained.pt
|
||||||
|
|
||||||
To use the generated file with `pruned_transducer_stateless/decode.py`, you can do:
|
To use the generated file with `pruned_transducer_stateless/decode.py`,
|
||||||
|
you can do:
|
||||||
|
|
||||||
cd /path/to/exp_dir
|
cd /path/to/exp_dir
|
||||||
ln -s pretrained.pt epoch-9999.pt
|
ln -s pretrained.pt epoch-9999.pt
|
||||||
@ -48,6 +49,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
from decoder import Decoder
|
from decoder import Decoder
|
||||||
from joiner import Joiner
|
from joiner import Joiner
|
||||||
|
@ -136,11 +136,14 @@ class Transducer(nn.Module):
|
|||||||
lm_only_scale=self.lm_scale,
|
lm_only_scale=self.lm_scale,
|
||||||
am_only_scale=self.am_scale,
|
am_only_scale=self.am_scale,
|
||||||
boundary=boundary,
|
boundary=boundary,
|
||||||
return_grad=True
|
return_grad=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
ranges = k2.get_rnnt_prune_ranges(
|
ranges = k2.get_rnnt_prune_ranges(
|
||||||
px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=self.prune_range
|
px_grad=px_grad,
|
||||||
|
py_grad=py_grad,
|
||||||
|
boundary=boundary,
|
||||||
|
s_range=self.prune_range,
|
||||||
)
|
)
|
||||||
|
|
||||||
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
am_pruned, lm_pruned = k2.do_rnnt_pruning(
|
||||||
@ -150,7 +153,11 @@ class Transducer(nn.Module):
|
|||||||
logits = self.joiner(am_pruned, lm_pruned)
|
logits = self.joiner(am_pruned, lm_pruned)
|
||||||
|
|
||||||
pruned_loss = k2.rnnt_loss_pruned(
|
pruned_loss = k2.rnnt_loss_pruned(
|
||||||
joint=logits, symbols=y_padded, ranges=ranges, termination_symbol=blank_id, boundary=boundary
|
joint=logits,
|
||||||
|
symbols=y_padded,
|
||||||
|
ranges=ranges,
|
||||||
|
termination_symbol=blank_id,
|
||||||
|
boundary=boundary,
|
||||||
)
|
)
|
||||||
|
|
||||||
return (-torch.sum(simple_loss), -torch.sum(pruned_loss))
|
return (-torch.sum(simple_loss), -torch.sum(pruned_loss))
|
||||||
|
@ -49,6 +49,7 @@ from typing import List
|
|||||||
import kaldifeat
|
import kaldifeat
|
||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
import torchaudio
|
import torchaudio
|
||||||
from beam_search import beam_search, greedy_search
|
from beam_search import beam_search, greedy_search
|
||||||
from conformer import Conformer
|
from conformer import Conformer
|
||||||
@ -142,7 +143,7 @@ def get_params() -> AttributeDict:
|
|||||||
"num_encoder_layers": 12,
|
"num_encoder_layers": 12,
|
||||||
"vgg_frontend": False,
|
"vgg_frontend": False,
|
||||||
# parameters for decoder
|
# parameters for decoder
|
||||||
"embedding_dim" : 512,
|
"embedding_dim": 512,
|
||||||
"env_info": get_env_info(),
|
"env_info": get_env_info(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
@ -116,7 +116,7 @@ def get_parser():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--exp-dir",
|
"--exp-dir",
|
||||||
type=str,
|
type=str,
|
||||||
default="transducer_stateless/exp",
|
default="pruned_transducer_stateless/exp",
|
||||||
help="""The experiment dir.
|
help="""The experiment dir.
|
||||||
It specifies the directory where all training related
|
It specifies the directory where all training related
|
||||||
files, e.g., checkpoints, log, etc, are saved
|
files, e.g., checkpoints, log, etc, are saved
|
||||||
@ -271,7 +271,7 @@ def get_decoder_model(params: AttributeDict) -> nn.Module:
|
|||||||
return decoder
|
return decoder
|
||||||
|
|
||||||
|
|
||||||
def get_joiner_model(params: AttributeDict) -> nn.Module :
|
def get_joiner_model(params: AttributeDict) -> nn.Module:
|
||||||
joiner = Joiner(
|
joiner = Joiner(
|
||||||
input_dim=params.vocab_size,
|
input_dim=params.vocab_size,
|
||||||
inner_dim=params.embedding_dim,
|
inner_dim=params.embedding_dim,
|
||||||
@ -280,7 +280,7 @@ def get_joiner_model(params: AttributeDict) -> nn.Module :
|
|||||||
return joiner
|
return joiner
|
||||||
|
|
||||||
|
|
||||||
def get_transducer_model(params: AttributeDict) ->nn.Module:
|
def get_transducer_model(params: AttributeDict) -> nn.Module:
|
||||||
encoder = get_encoder_model(params)
|
encoder = get_encoder_model(params)
|
||||||
decoder = get_decoder_model(params)
|
decoder = get_decoder_model(params)
|
||||||
joiner = get_joiner_model(params)
|
joiner = get_joiner_model(params)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user