flat weights after scaling

yaozengwei 2022-07-17 20:35:29 +08:00
parent 539a9d75d4
commit 125eac8dee
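
The change in a nutshell: `ScaledLSTM` keeps a per-weight scale stored in log space (`self._scales`), and the tensors actually passed to `_VF.lstm()` are the raw flat weights multiplied by `exp(scale)`. The diff below renames `get_flat_weights()` to `_get_flat_weights()` and additionally re-flattens the scaled copies so cuDNN's fused LSTM kernel can use them. A minimal sketch of the scaling idea (illustrative names only, not code from this repository):

import torch

# A weight tensor w and its log-scale s; the value used in the forward
# pass is w * exp(s), so gradients flow to both parameters.
w = torch.nn.Parameter(torch.randn(4, 4))
s = torch.nn.Parameter(torch.zeros(()))

w_used = w * s.exp()
w_used.sum().backward()
print(w.grad.shape, s.grad)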


@@ -17,12 +17,14 @@
 import collections
 from itertools import repeat
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 import torch
 import torch.nn as nn
 from torch import Tensor, _VF
+import torch.backends.cudnn.rnn as rnn
 def _ntuple(n):
     def parse(x):
@@ -419,12 +421,72 @@ class ScaledLSTM(nn.LSTM):
             elif "bias" in name:
                 nn.init.constant_(self._flat_weights[idx], 0.0)
-    def get_flat_weights(self):
+    def _flatten_parameters(self, flat_weights) -> None:
+        """Resets parameter data pointer so that they can use faster code paths.
+        Right now, this works only if the module is on the GPU and cuDNN is enabled.
+        Otherwise, it's a no-op.
+        This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py  # noqa
+        """
+        # Short-circuits if _flat_weights is only partially instantiated
+        if len(flat_weights) != len(self._flat_weights_names):
+            return
+        for w in flat_weights:
+            if not isinstance(w, Tensor):
+                return
+        # Short-circuits if any tensor in flat_weights is not acceptable to cuDNN
+        # or the tensors in flat_weights are of different dtypes
+        first_fw = flat_weights[0]
+        dtype = first_fw.dtype
+        for fw in flat_weights:
+            if (
+                not isinstance(fw.data, Tensor)
+                or not (fw.data.dtype == dtype)
+                or not fw.data.is_cuda
+                or not torch.backends.cudnn.is_acceptable(fw.data)
+            ):
+                return
+        # If any parameters alias, we fall back to the slower, copying code path. This is
+        # a sufficient check, because overlapping parameter buffers that don't completely
+        # alias would break the assumptions of the uniqueness check in
+        # Module.named_parameters().
+        unique_data_ptrs = set(p.data_ptr() for p in flat_weights)
+        if len(unique_data_ptrs) != len(flat_weights):
+            return
+        with torch.cuda.device_of(first_fw):
+            # Note: no_grad() is necessary since _cudnn_rnn_flatten_weight is
+            # an inplace operation on self._flat_weights
+            with torch.no_grad():
+                if torch._use_cudnn_rnn_flatten_weight():
+                    num_weights = 4 if self.bias else 2
+                    if self.proj_size > 0:
+                        num_weights += 1
+                    torch._cudnn_rnn_flatten_weight(
+                        flat_weights,
+                        num_weights,
+                        self.input_size,
+                        rnn.get_cudnn_mode(self.mode),
+                        self.hidden_size,
+                        self.proj_size,
+                        self.num_layers,
+                        self.batch_first,
+                        bool(self.bidirectional),
+                    )
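
The short-circuit checks in `_flatten_parameters` above amount to one predicate: every flat weight must be a CUDA tensor of a single cuDNN-acceptable dtype, and no two weights may share storage; otherwise the slower, copy-based path is used. A standalone restatement of that predicate, for illustration only (not part of this commit):

import torch
import torch.backends.cudnn

def cudnn_flatten_ok(weights) -> bool:
    # Same dtype everywhere, all on GPU, acceptable to cuDNN,
    # and no aliasing between the weight buffers.
    if not weights or not all(isinstance(w, torch.Tensor) for w in weights):
        return False
    dtype = weights[0].dtype
    if not all(w.dtype == dtype and w.is_cuda for w in weights):
        return False
    if not all(torch.backends.cudnn.is_acceptable(w) for w in weights):
        return False
    return len({w.data_ptr() for w in weights}) == len(weights)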
+    def _get_flat_weights(self):
+        """Get scaled weights, and resets their data pointer."""
         flat_weights = []
         for idx in range(len(self._flat_weights_names)):
             flat_weights.append(
                 self._flat_weights[idx] * self._scales[idx].exp()
             )
+        self._flatten_parameters(flat_weights)
         return flat_weights
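
Note that the products `self._flat_weights[idx] * self._scales[idx].exp()` computed in `_get_flat_weights()` are fresh tensors with their own storage on every call, which is why `_flatten_parameters` is re-run on each forward pass instead of once at construction. A small standalone illustration (not code from this repository):

import torch

w = torch.nn.Parameter(torch.randn(8, 8))
scaled = w * torch.tensor(0.5).exp()
# The scaled copy lives in its own storage; any pointer-based setup done
# on `w` does not carry over to `scaled`.
print(w.data_ptr() != scaled.data_ptr())  # True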
     def forward(
@@ -432,7 +494,7 @@ class ScaledLSTM(nn.LSTM):
     ):
         # This function is modified from https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/rnn.py  # noqa
         # The change for calling `_VF.lstm()` is:
-        # self._flat_weights -> self.get_flat_weights()
+        # self._flat_weights -> self._get_flat_weights()
         if hx is None:
             h_zeros = torch.zeros(
                 1,
@@ -454,7 +516,7 @@ class ScaledLSTM(nn.LSTM):
         result = _VF.lstm(
             input,
             hx,
-            self.get_flat_weights(),
+            self._get_flat_weights(),
             self.bias,
             self.num_layers,
             self.dropout,