fix scale value in scaling.py

This commit is contained in:
yaozengwei 2022-07-18 10:33:35 +08:00
parent 3cedbe3678
commit 8bd700cff2
2 changed files with 7 additions and 13 deletions

View File

@ -354,12 +354,12 @@ def decode_one_batch(
supervisions = batch["supervisions"] supervisions = batch["supervisions"]
feature_lens = supervisions["num_frames"].to(device) feature_lens = supervisions["num_frames"].to(device)
feature_lens += params.left_context # feature_lens += params.left_context
feature = torch.nn.functional.pad( # feature = torch.nn.functional.pad(
feature, # feature,
pad=(0, 0, 0, params.left_context), # pad=(0, 0, 0, params.left_context),
value=LOG_EPS, # value=LOG_EPS,
) # )
encoder_out, encoder_out_lens = model.encoder( encoder_out, encoder_out_lens = model.encoder(
x=feature, x_lens=feature_lens x=feature, x_lens=feature_lens
@ -668,11 +668,6 @@ def main():
params.unk_id = sp.piece_to_id("<unk>") params.unk_id = sp.piece_to_id("<unk>")
params.vocab_size = sp.get_piece_size() params.vocab_size = sp.get_piece_size()
if params.simulate_streaming:
assert (
params.causal_convolution
), "Decoding in streaming requires causal convolution"
logging.info(params) logging.info(params)
logging.info("About to create model") logging.info("About to create model")

View File

@ -409,8 +409,7 @@ class ScaledLSTM(nn.LSTM):
def _reset_parameters(self, initial_speed: float): def _reset_parameters(self, initial_speed: float):
std = 0.1 / initial_speed std = 0.1 / initial_speed
a = (3 ** 0.5) * std a = (3 ** 0.5) * std
fan_in = self.input_size scale = self.hidden_size ** -0.5
scale = fan_in ** -0.5
v = scale / std v = scale / std
for idx, name in enumerate(self._flat_weights_names): for idx, name in enumerate(self._flat_weights_names):
if "weight" in name: if "weight" in name: