mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-09-19 05:54:20 +00:00
fix scale value in scaling.py
This commit is contained in:
parent
3cedbe3678
commit
8bd700cff2
@ -354,12 +354,12 @@ def decode_one_batch(
|
|||||||
supervisions = batch["supervisions"]
|
supervisions = batch["supervisions"]
|
||||||
feature_lens = supervisions["num_frames"].to(device)
|
feature_lens = supervisions["num_frames"].to(device)
|
||||||
|
|
||||||
feature_lens += params.left_context
|
# feature_lens += params.left_context
|
||||||
feature = torch.nn.functional.pad(
|
# feature = torch.nn.functional.pad(
|
||||||
feature,
|
# feature,
|
||||||
pad=(0, 0, 0, params.left_context),
|
# pad=(0, 0, 0, params.left_context),
|
||||||
value=LOG_EPS,
|
# value=LOG_EPS,
|
||||||
)
|
# )
|
||||||
|
|
||||||
encoder_out, encoder_out_lens = model.encoder(
|
encoder_out, encoder_out_lens = model.encoder(
|
||||||
x=feature, x_lens=feature_lens
|
x=feature, x_lens=feature_lens
|
||||||
@ -668,11 +668,6 @@ def main():
|
|||||||
params.unk_id = sp.piece_to_id("<unk>")
|
params.unk_id = sp.piece_to_id("<unk>")
|
||||||
params.vocab_size = sp.get_piece_size()
|
params.vocab_size = sp.get_piece_size()
|
||||||
|
|
||||||
if params.simulate_streaming:
|
|
||||||
assert (
|
|
||||||
params.causal_convolution
|
|
||||||
), "Decoding in streaming requires causal convolution"
|
|
||||||
|
|
||||||
logging.info(params)
|
logging.info(params)
|
||||||
|
|
||||||
logging.info("About to create model")
|
logging.info("About to create model")
|
||||||
|
@ -409,8 +409,7 @@ class ScaledLSTM(nn.LSTM):
|
|||||||
def _reset_parameters(self, initial_speed: float):
|
def _reset_parameters(self, initial_speed: float):
|
||||||
std = 0.1 / initial_speed
|
std = 0.1 / initial_speed
|
||||||
a = (3 ** 0.5) * std
|
a = (3 ** 0.5) * std
|
||||||
fan_in = self.input_size
|
scale = self.hidden_size ** -0.5
|
||||||
scale = fan_in ** -0.5
|
|
||||||
v = scale / std
|
v = scale / std
|
||||||
for idx, name in enumerate(self._flat_weights_names):
|
for idx, name in enumerate(self._flat_weights_names):
|
||||||
if "weight" in name:
|
if "weight" in name:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user