flat weights after scaling
This commit is contained in:
parent
539a9d75d4
commit
125eac8dee
@ -17,12 +17,14 @@
|
||||
|
||||
import collections
|
||||
from itertools import repeat
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor, _VF
|
||||
|
||||
import torch.backends.cudnn.rnn as rnn
|
||||
|
||||
|
||||
def _ntuple(n):
|
||||
def parse(x):
|
||||
@ -419,12 +421,72 @@ class ScaledLSTM(nn.LSTM):
|
||||
elif "bias" in name:
|
||||
nn.init.constant_(self._flat_weights[idx], 0.0)
|
||||
|
||||
def get_flat_weights(self):
|
||||
def _flatten_parameters(self, flat_weights) -> None:
|
||||
"""Resets parameter data pointer so that they can use faster code paths.
|
||||
|
||||
Right now, this works only if the module is on the GPU and cuDNN is enabled.
|
||||
Otherwise, it's a no-op.
|
||||
|
||||
This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
|
||||
"""
|
||||
# Short-circuits if _flat_weights is only partially instantiated
|
||||
if len(flat_weights) != len(self._flat_weights_names):
|
||||
return
|
||||
|
||||
for w in flat_weights:
|
||||
if not isinstance(w, Tensor):
|
||||
return
|
||||
# Short-circuits if any tensor in flat_weights is not acceptable to cuDNN
|
||||
# or the tensors in flat_weights are of different dtypes
|
||||
|
||||
first_fw = flat_weights[0]
|
||||
dtype = first_fw.dtype
|
||||
for fw in flat_weights:
|
||||
if (
|
||||
not isinstance(fw.data, Tensor)
|
||||
or not (fw.data.dtype == dtype)
|
||||
or not fw.data.is_cuda
|
||||
or not torch.backends.cudnn.is_acceptable(fw.data)
|
||||
):
|
||||
return
|
||||
|
||||
# If any parameters alias, we fall back to the slower, copying code path. This is
|
||||
# a sufficient check, because overlapping parameter buffers that don't completely
|
||||
# alias would break the assumptions of the uniqueness check in
|
||||
# Module.named_parameters().
|
||||
unique_data_ptrs = set(p.data_ptr() for p in flat_weights)
|
||||
if len(unique_data_ptrs) != len(flat_weights):
|
||||
return
|
||||
|
||||
with torch.cuda.device_of(first_fw):
|
||||
|
||||
# Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
|
||||
# an inplace operation on self._flat_weights
|
||||
with torch.no_grad():
|
||||
if torch._use_cudnn_rnn_flatten_weight():
|
||||
num_weights = 4 if self.bias else 2
|
||||
if self.proj_size > 0:
|
||||
num_weights += 1
|
||||
torch._cudnn_rnn_flatten_weight(
|
||||
flat_weights,
|
||||
num_weights,
|
||||
self.input_size,
|
||||
rnn.get_cudnn_mode(self.mode),
|
||||
self.hidden_size,
|
||||
self.proj_size,
|
||||
self.num_layers,
|
||||
self.batch_first,
|
||||
bool(self.bidirectional),
|
||||
)
|
||||
|
||||
def _get_flat_weights(self):
|
||||
"""Get scaled weights, and resets their data pointer."""
|
||||
flat_weights = []
|
||||
for idx in range(len(self._flat_weights_names)):
|
||||
flat_weights.append(
|
||||
self._flat_weights[idx] * self._scales[idx].exp()
|
||||
)
|
||||
self._flatten_parameters(flat_weights)
|
||||
return flat_weights
|
||||
|
||||
def forward(
|
||||
@ -432,7 +494,7 @@ class ScaledLSTM(nn.LSTM):
|
||||
):
|
||||
# This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
|
||||
# The change for calling `_VF.lstm()` is:
|
||||
# self._flat_weights -> self.get_flat_weights()
|
||||
# self._flat_weights -> self._get_flat_weights()
|
||||
if hx is None:
|
||||
h_zeros = torch.zeros(
|
||||
1,
|
||||
@ -454,7 +516,7 @@ class ScaledLSTM(nn.LSTM):
|
||||
result = _VF.lstm(
|
||||
input,
|
||||
hx,
|
||||
self.get_flat_weights(),
|
||||
self._get_flat_weights(),
|
||||
self.bias,
|
||||
self.num_layers,
|
||||
self.dropout,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user