Merge remote-tracking branch 'dan/master' into rnnt-lstm-2022-04-21

This commit is contained in:
Fangjun Kuang 2022-04-28 11:06:55 +08:00
commit 026978e1c0
4 changed files with 22 additions and 9 deletions

View File

@@ -35,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
--world-size 4 \ --world-size 4 \
--num-epochs 30 \ --num-epochs 30 \
--start-epoch 0 \ --start-epoch 0 \
--use_fp16 1 \ --use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \ --exp-dir pruned_transducer_stateless2/exp \
--full-libri 1 \ --full-libri 1 \
--max-duration 550 --max-duration 550

View File

@@ -89,9 +89,9 @@ class Decoder(nn.Module):
- (h, c), containing the state information for LSTM layers. - (h, c), containing the state information for LSTM layers.
Both are of shape (num_layers, N, C) Both are of shape (num_layers, N, C)
""" """
embeding_out = self.embedding(y) embedding_out = self.embedding(y)
embeding_out = self.embedding_dropout(embeding_out) embedding_out = self.embedding_dropout(embedding_out)
rnn_out, (h, c) = self.rnn(embeding_out, states) rnn_out, (h, c) = self.rnn(embedding_out, states)
out = self.output_linear(rnn_out) out = self.output_linear(rnn_out)
return out, (h, c) return out, (h, c)

View File

@@ -84,9 +84,9 @@ class Decoder(nn.Module):
- (h, c), which contain the state information for RNN layers. - (h, c), which contain the state information for RNN layers.
Both are of shape (num_layers, N, C) Both are of shape (num_layers, N, C)
""" """
embeding_out = self.embedding(y) embedding_out = self.embedding(y)
embeding_out = self.embedding_dropout(embeding_out) embedding_out = self.embedding_dropout(embedding_out)
rnn_out, (h, c) = self.rnn(embeding_out, states) rnn_out, (h, c) = self.rnn(embedding_out, states)
out = self.output_linear(rnn_out) out = self.output_linear(rnn_out)
return out, (h, c) return out, (h, c)

View File

@@ -150,12 +150,25 @@ def average_checkpoints(
n = len(filenames) n = len(filenames)
avg = torch.load(filenames[0], map_location=device)["model"] avg = torch.load(filenames[0], map_location=device)["model"]
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
uniqued: Dict[int, str] = dict()
for k, v in avg.items():
v_data_ptr = v.data_ptr()
if v_data_ptr in uniqued:
continue
uniqued[v_data_ptr] = k
uniqued_names = list(uniqued.values())
for i in range(1, n): for i in range(1, n):
state_dict = torch.load(filenames[i], map_location=device)["model"] state_dict = torch.load(filenames[i], map_location=device)["model"]
for k in avg: for k in uniqued_names:
avg[k] += state_dict[k] avg[k] += state_dict[k]
for k in avg: for k in uniqued_names:
if avg[k].is_floating_point(): if avg[k].is_floating_point():
avg[k] /= n avg[k] /= n
else: else: