mirror of
https://github.com/k2-fsa/icefall.git
synced 2025-12-11 06:55:27 +00:00
Merge remote-tracking branch 'dan/master' into rnnt-lstm-2022-04-21
This commit is contained in:
commit
026978e1c0
@ -35,7 +35,7 @@ export CUDA_VISIBLE_DEVICES="0,1,2,3"
|
||||
--world-size 4 \
|
||||
--num-epochs 30 \
|
||||
--start-epoch 0 \
|
||||
--use_fp16 1 \
|
||||
--use-fp16 1 \
|
||||
--exp-dir pruned_transducer_stateless2/exp \
|
||||
--full-libri 1 \
|
||||
--max-duration 550
|
||||
|
||||
@ -89,9 +89,9 @@ class Decoder(nn.Module):
|
||||
- (h, c), containing the state information for LSTM layers.
|
||||
Both are of shape (num_layers, N, C)
|
||||
"""
|
||||
embeding_out = self.embedding(y)
|
||||
embeding_out = self.embedding_dropout(embeding_out)
|
||||
rnn_out, (h, c) = self.rnn(embeding_out, states)
|
||||
embedding_out = self.embedding(y)
|
||||
embedding_out = self.embedding_dropout(embedding_out)
|
||||
rnn_out, (h, c) = self.rnn(embedding_out, states)
|
||||
out = self.output_linear(rnn_out)
|
||||
|
||||
return out, (h, c)
|
||||
|
||||
@ -84,9 +84,9 @@ class Decoder(nn.Module):
|
||||
- (h, c), which contain the state information for RNN layers.
|
||||
Both are of shape (num_layers, N, C)
|
||||
"""
|
||||
embeding_out = self.embedding(y)
|
||||
embeding_out = self.embedding_dropout(embeding_out)
|
||||
rnn_out, (h, c) = self.rnn(embeding_out, states)
|
||||
embedding_out = self.embedding(y)
|
||||
embedding_out = self.embedding_dropout(embedding_out)
|
||||
rnn_out, (h, c) = self.rnn(embedding_out, states)
|
||||
out = self.output_linear(rnn_out)
|
||||
|
||||
return out, (h, c)
|
||||
|
||||
@ -150,12 +150,25 @@ def average_checkpoints(
|
||||
n = len(filenames)
|
||||
|
||||
avg = torch.load(filenames[0], map_location=device)["model"]
|
||||
|
||||
# Identify shared parameters. Two parameters are said to be shared
|
||||
# if they have the same data_ptr
|
||||
uniqued: Dict[int, str] = dict()
|
||||
|
||||
for k, v in avg.items():
|
||||
v_data_ptr = v.data_ptr()
|
||||
if v_data_ptr in uniqued:
|
||||
continue
|
||||
uniqued[v_data_ptr] = k
|
||||
|
||||
uniqued_names = list(uniqued.values())
|
||||
|
||||
for i in range(1, n):
|
||||
state_dict = torch.load(filenames[i], map_location=device)["model"]
|
||||
for k in avg:
|
||||
for k in uniqued_names:
|
||||
avg[k] += state_dict[k]
|
||||
|
||||
for k in avg:
|
||||
for k in uniqued_names:
|
||||
if avg[k].is_floating_point():
|
||||
avg[k] /= n
|
||||
else:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user