flat weights after scaling

This commit is contained in:
yaozengwei 2022-07-17 20:35:29 +08:00
parent 539a9d75d4
commit 125eac8dee

View File

@ -17,12 +17,14 @@
import collections import collections
from itertools import repeat from itertools import repeat
from typing import Optional, Tuple from typing import List, Optional, Tuple
import torch import torch
import torch.nn as nn import torch.nn as nn
from torch import Tensor, _VF from torch import Tensor, _VF
import torch.backends.cudnn.rnn as rnn
def _ntuple(n): def _ntuple(n):
def parse(x): def parse(x):
@ -419,12 +421,72 @@ class ScaledLSTM(nn.LSTM):
elif "bias" in name: elif "bias" in name:
nn.init.constant_(self._flat_weights[idx], 0.0) nn.init.constant_(self._flat_weights[idx], 0.0)
def get_flat_weights(self): def _flatten_parameters(self, flat_weights) -> None:
"""Resets parameter data pointer so that they can use faster code paths.
Right now, this works only if the module is on the GPU and cuDNN is enabled.
Otherwise, it's a no-op.
This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
"""
# Short-circuits if _flat_weights is only partially instantiated
if len(flat_weights) != len(self._flat_weights_names):
return
for w in flat_weights:
if not isinstance(w, Tensor):
return
# Short-circuits if any tensor in flat_weights is not acceptable to cuDNN
# or the tensors in flat_weights are of different dtypes
first_fw = flat_weights[0]
dtype = first_fw.dtype
for fw in flat_weights:
if (
not isinstance(fw.data, Tensor)
or not (fw.data.dtype == dtype)
or not fw.data.is_cuda
or not torch.backends.cudnn.is_acceptable(fw.data)
):
return
# If any parameters alias, we fall back to the slower, copying code path. This is
# a sufficient check, because overlapping parameter buffers that don't completely
# alias would break the assumptions of the uniqueness check in
# Module.named_parameters().
unique_data_ptrs = set(p.data_ptr() for p in flat_weights)
if len(unique_data_ptrs) != len(flat_weights):
return
with torch.cuda.device_of(first_fw):
# Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
# an inplace operation on self._flat_weights
with torch.no_grad():
if torch._use_cudnn_rnn_flatten_weight():
num_weights = 4 if self.bias else 2
if self.proj_size > 0:
num_weights += 1
torch._cudnn_rnn_flatten_weight(
flat_weights,
num_weights,
self.input_size,
rnn.get_cudnn_mode(self.mode),
self.hidden_size,
self.proj_size,
self.num_layers,
self.batch_first,
bool(self.bidirectional),
)
def _get_flat_weights(self):
"""Get scaled weights, and resets their data pointer."""
flat_weights = [] flat_weights = []
for idx in range(len(self._flat_weights_names)): for idx in range(len(self._flat_weights_names)):
flat_weights.append( flat_weights.append(
self._flat_weights[idx] * self._scales[idx].exp() self._flat_weights[idx] * self._scales[idx].exp()
) )
self._flatten_parameters(flat_weights)
return flat_weights return flat_weights
def forward( def forward(
@ -432,7 +494,7 @@ class ScaledLSTM(nn.LSTM):
): ):
# This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py # noqa
# The change for calling `_VF.lstm()` is: # The change for calling `_VF.lstm()` is:
# self._flat_weights -> self.get_flat_weights() # self._flat_weights -> self._get_flat_weights()
if hx is None: if hx is None:
h_zeros = torch.zeros( h_zeros = torch.zeros(
1, 1,
@ -454,7 +516,7 @@ class ScaledLSTM(nn.LSTM):
result = _VF.lstm( result = _VF.lstm(
input, input,
hx, hx,
self.get_flat_weights(), self._get_flat_weights(),
self.bias, self.bias,
self.num_layers, self.num_layers,
self.dropout, self.dropout,