flat weights after scaling
This commit is contained in:
parent
539a9d75d4
commit
125eac8dee
@ -17,12 +17,14 @@
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
from itertools import repeat
|
from itertools import repeat
|
||||||
from typing import Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch import Tensor, _VF
|
from torch import Tensor, _VF
|
||||||
|
|
||||||
|
import torch.backends.cudnn.rnn as rnn
|
||||||
|
|
||||||
|
|
||||||
def _ntuple(n):
|
def _ntuple(n):
|
||||||
def parse(x):
|
def parse(x):
|
||||||
@ -419,12 +421,72 @@ class ScaledLSTM(nn.LSTM):
|
|||||||
elif "bias" in name:
|
elif "bias" in name:
|
||||||
nn.init.constant_(self._flat_weights[idx], 0.0)
|
nn.init.constant_(self._flat_weights[idx], 0.0)
|
||||||
|
|
||||||
def get_flat_weights(self):
|
def _flatten_parameters(self, flat_weights) -> None:
|
||||||
|
"""Resets parameter data pointer so that they can use faster code paths.
|
||||||
|
|
||||||
|
Right now, this works only if the module is on the GPU and cuDNN is enabled.
|
||||||
|
Otherwise, it's a no-op.
|
||||||
|
|
||||||
|
This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
|
||||||
|
"""
|
||||||
|
# Short-circuits if _flat_weights is only partially instantiated
|
||||||
|
if len(flat_weights) != len(self._flat_weights_names):
|
||||||
|
return
|
||||||
|
|
||||||
|
for w in flat_weights:
|
||||||
|
if not isinstance(w, Tensor):
|
||||||
|
return
|
||||||
|
# Short-circuits if any tensor in flat_weights is not acceptable to cuDNN
|
||||||
|
# or the tensors in flat_weights are of different dtypes
|
||||||
|
|
||||||
|
first_fw = flat_weights[0]
|
||||||
|
dtype = first_fw.dtype
|
||||||
|
for fw in flat_weights:
|
||||||
|
if (
|
||||||
|
not isinstance(fw.data, Tensor)
|
||||||
|
or not (fw.data.dtype == dtype)
|
||||||
|
or not fw.data.is_cuda
|
||||||
|
or not torch.backends.cudnn.is_acceptable(fw.data)
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# If any parameters alias, we fall back to the slower, copying code path. This is
|
||||||
|
# a sufficient check, because overlapping parameter buffers that don't completely
|
||||||
|
# alias would break the assumptions of the uniqueness check in
|
||||||
|
# Module.named_parameters().
|
||||||
|
unique_data_ptrs = set(p.data_ptr() for p in flat_weights)
|
||||||
|
if len(unique_data_ptrs) != len(flat_weights):
|
||||||
|
return
|
||||||
|
|
||||||
|
with torch.cuda.device_of(first_fw):
|
||||||
|
|
||||||
|
# Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
|
||||||
|
# an inplace operation on self._flat_weights
|
||||||
|
with torch.no_grad():
|
||||||
|
if torch._use_cudnn_rnn_flatten_weight():
|
||||||
|
num_weights = 4 if self.bias else 2
|
||||||
|
if self.proj_size > 0:
|
||||||
|
num_weights += 1
|
||||||
|
torch._cudnn_rnn_flatten_weight(
|
||||||
|
flat_weights,
|
||||||
|
num_weights,
|
||||||
|
self.input_size,
|
||||||
|
rnn.get_cudnn_mode(self.mode),
|
||||||
|
self.hidden_size,
|
||||||
|
self.proj_size,
|
||||||
|
self.num_layers,
|
||||||
|
self.batch_first,
|
||||||
|
bool(self.bidirectional),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_flat_weights(self):
|
||||||
|
"""Get scaled weights, and resets their data pointer."""
|
||||||
flat_weights = []
|
flat_weights = []
|
||||||
for idx in range(len(self._flat_weights_names)):
|
for idx in range(len(self._flat_weights_names)):
|
||||||
flat_weights.append(
|
flat_weights.append(
|
||||||
self._flat_weights[idx] * self._scales[idx].exp()
|
self._flat_weights[idx] * self._scales[idx].exp()
|
||||||
)
|
)
|
||||||
|
self._flatten_parameters(flat_weights)
|
||||||
return flat_weights
|
return flat_weights
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
@ -432,7 +494,7 @@ class ScaledLSTM(nn.LSTM):
|
|||||||
):
|
):
|
||||||
# This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
|
# This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
|
||||||
# The change for calling `_VF.lstm()` is:
|
# The change for calling `_VF.lstm()` is:
|
||||||
# self._flat_weights -> self.get_flat_weights()
|
# self._flat_weights -> self._get_flat_weights()
|
||||||
if hx is None:
|
if hx is None:
|
||||||
h_zeros = torch.zeros(
|
h_zeros = torch.zeros(
|
||||||
1,
|
1,
|
||||||
@ -454,7 +516,7 @@ class ScaledLSTM(nn.LSTM):
|
|||||||
result = _VF.lstm(
|
result = _VF.lstm(
|
||||||
input,
|
input,
|
||||||
hx,
|
hx,
|
||||||
self.get_flat_weights(),
|
self._get_flat_weights(),
|
||||||
self.bias,
|
self.bias,
|
||||||
self.num_layers,
|
self.num_layers,
|
||||||
self.dropout,
|
self.dropout,
|
||||||
|
|||||||
Reference in New Issue
Block a user