Merge remote-tracking branch 'dan/master' into rnnt-lstm-2022-04-21

Fangjun Kuang 2022-04-28 11:06:55 +08:00
commit 026978e1c0
4 changed files with 22 additions and 9 deletions


@@ -35,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
   --world-size 4 \
   --num-epochs 30 \
   --start-epoch 0 \
-  --use_fp16 1 \
+  --use-fp16 1 \
   --exp-dir pruned_transducer_stateless2/exp \
   --full-libri 1 \
   --max-duration 550
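
Note on the flag rename above: argparse only accepts the exact option string a script registers, and it maps dashes to underscores only in the resulting attribute name, never on the command line. A minimal sketch, assuming the training script registers the option with argparse (type=int here is illustrative only; the real script may use a different converter):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use-fp16", type=int, default=0)

args = parser.parse_args(["--use-fp16", "1"])
print(args.use_fp16)  # -> 1; dashes become underscores in the attribute name

# Passing the old spelling would abort with an error like:
#   error: unrecognized arguments: --use_fp16 1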


@@ -89,9 +89,9 @@ class Decoder(nn.Module):
           - (h, c), containing the state information for LSTM layers.
             Both are of shape (num_layers, N, C)
         """
-        embeding_out = self.embedding(y)
-        embeding_out = self.embedding_dropout(embeding_out)
-        rnn_out, (h, c) = self.rnn(embeding_out, states)
+        embedding_out = self.embedding(y)
+        embedding_out = self.embedding_dropout(embedding_out)
+        rnn_out, (h, c) = self.rnn(embedding_out, states)
         out = self.output_linear(rnn_out)
         return out, (h, c)
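
The rewritten lines above (the diff only fixes the "embeding" typo) are the embed -> dropout -> LSTM -> project pattern of this decoder. A minimal self-contained sketch of that pattern with hypothetical sizes (the real Decoder in this repo takes more constructor arguments):

import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    # Hypothetical sizes; illustrates the pattern only.
    def __init__(self, vocab_size=500, embed_dim=512, hidden_dim=512, num_layers=2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.embedding_dropout = nn.Dropout(0.1)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True)
        self.output_linear = nn.Linear(hidden_dim, vocab_size)

    def forward(self, y, states=None):
        # y: (N, U) token IDs; states: optional (h, c), each (num_layers, N, C).
        embedding_out = self.embedding(y)
        embedding_out = self.embedding_dropout(embedding_out)
        rnn_out, (h, c) = self.rnn(embedding_out, states)
        out = self.output_linear(rnn_out)  # (N, U, vocab_size)
        return out, (h, c)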


@@ -84,9 +84,9 @@ class Decoder(nn.Module):
           - (h, c), which contain the state information for RNN layers.
             Both are of shape (num_layers, N, C)
         """
-        embeding_out = self.embedding(y)
-        embeding_out = self.embedding_dropout(embeding_out)
-        rnn_out, (h, c) = self.rnn(embeding_out, states)
+        embedding_out = self.embedding(y)
+        embedding_out = self.embedding_dropout(embedding_out)
+        rnn_out, (h, c) = self.rnn(embedding_out, states)
         out = self.output_linear(rnn_out)
         return out, (h, c)
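
The docstring above says h and c both have shape (num_layers, N, C); that is the contract of nn.LSTM itself and can be checked directly. A small sketch with hypothetical sizes:

import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=512, hidden_size=512, num_layers=2, batch_first=True)
x = torch.randn(4, 10, 512)               # (N, U, C)
out, (h, c) = rnn(x)                      # initial state defaults to zeros
assert h.shape == c.shape == (2, 4, 512)  # (num_layers, N, C)

# Carrying (h, c) into the next call continues from the same state, which is
# what passing `states` back into forward() achieves during decoding.
out2, (h2, c2) = rnn(x[:, -1:, :], (h, c))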


@@ -150,12 +150,25 @@ def average_checkpoints(
     n = len(filenames)
     avg = torch.load(filenames[0], map_location=device)["model"]
+
+    # Identify shared parameters. Two parameters are said to be shared
+    # if they have the same data_ptr.
+    uniqued: Dict[int, str] = dict()
+    for k, v in avg.items():
+        v_data_ptr = v.data_ptr()
+        if v_data_ptr in uniqued:
+            continue
+        uniqued[v_data_ptr] = k
+
+    uniqued_names = list(uniqued.values())
+
     for i in range(1, n):
         state_dict = torch.load(filenames[i], map_location=device)["model"]
-        for k in avg:
+        for k in uniqued_names:
             avg[k] += state_dict[k]
-    for k in avg:
+
+    for k in uniqued_names:
         if avg[k].is_floating_point():
             avg[k] /= n
         else:
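
The data_ptr bookkeeping introduced above exists because tied parameters appear under several state-dict keys while sharing a single storage; an in-place += over every key would then accumulate the shared tensor more than once per checkpoint. A minimal sketch with a hypothetical weight-tied model:

import torch
import torch.nn as nn

class Tied(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(10, 4)
        self.output = nn.Linear(4, 10, bias=False)
        self.output.weight = self.embedding.weight  # tie: one storage, two keys

sd = Tied().state_dict()
print(sd["embedding.weight"].data_ptr() == sd["output.weight"].data_ptr())  # True

# Summing over every key hits the shared storage twice:
ones = {k: torch.ones_like(v) for k, v in sd.items()}
for k in sd:
    sd[k] += ones[k]
# embedding.weight has now been incremented by 2, not 1 -- the double
# counting that iterating over uniqued_names avoids.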