diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..e6d4c464d --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +apex.egg-info +dist +build +docs/build \ No newline at end of file diff --git a/.nojekyll b/.nojekyll new file mode 100644 index 000000000..e69de29bb diff --git a/LICENSE b/LICENSE new file mode 100644 index 000000000..3d1e9454f --- /dev/null +++ b/LICENSE @@ -0,0 +1,11 @@ +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 000000000..51098cf2d --- /dev/null +++ b/README.md @@ -0,0 +1,40 @@ +# Introduction + +This is a repo is designed to hold PyTorch modules and utilities that are under active development and experimental. This repo is not designed as a long term solution or a production solution. Things placed in here are intended to be eventually moved to upstream PyTorch. + +# Requirements + +Python 3 +PyTorch 0.3 or newer +CUDA 9 + +# [Full Documentation](https://nvidia.github.io/apex) + +# Quick Start + +To build the extension run the following command in the root directory of this project +``` +python setup.py install +``` + +To use the extension simply run +``` +import apex +``` +and optionally (if required for your use) +``` +import apex._C as apex_backend +``` + +# What's included + +Current version of apex contains: +1. Mixed precision utilities can be found [here](https://nvidia.github.io/apex/fp16_utils) examples of using mixed precision utilities can be found for the [PyTorch imagenet example](https://github.com/csarofeen/examples/tree/apex/imagenet) and the [PyTorch word language model example](https://github.com/csarofeen/examples/tree/apex/word_language_model). +2. Parallel utilities can be found [here](https://nvidia.github.io/apex/parallel) and an example/walkthrough can be found [here](https://github.com/csarofeen/examples/tree/apex/distributed) + - apex/parallel/distributed.py contains a simplified implementation of PyTorch's DistributedDataParallel that's optimized for use with NCCL in single gpu / process mode + - apex/parallel/multiproc.py is a simple multi-process launcher that can be used on a single node/computer with multiple GPU's +3. Reparameterization function that allows you to recursively apply reparameterization to an entire module (including children modules). +4. An experimental and in development flexible RNN API. + + + diff --git a/apex/RNN/RNNBackend.py b/apex/RNN/RNNBackend.py new file mode 100644 index 000000000..d0a4eb6e5 --- /dev/null +++ b/apex/RNN/RNNBackend.py @@ -0,0 +1,365 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable + +import torch.nn.functional as F + +import math + + +def is_iterable(maybe_iterable): + return isinstance(maybe_iterable, list) or isinstance(maybe_iterable, tuple) + + +def flatten_list(tens_list): + """ + flatten_list + """ + if not is_iterable(tens_list): + return tens_list + + return torch.cat(tens_list, dim=0).view(len(tens_list), *tens_list[0].size() ) + + +#These modules always assumes batch_first +class bidirectionalRNN(nn.Module): + """ + bidirectionalRNN + """ + def __init__(self, inputRNN, num_layers=1, dropout = 0): + super(bidirectionalRNN, self).__init__() + self.dropout = dropout + self.fwd = stackedRNN(inputRNN, num_layers=num_layers, dropout = dropout) + self.bckwrd = stackedRNN(inputRNN.new_like(), num_layers=num_layers, dropout = dropout) + self.rnns = nn.ModuleList([self.fwd, self.bckwrd]) + + #collect hidden option will return all hidden/cell states from entire RNN + def forward(self, input, collect_hidden=False): + """ + forward() + """ + seq_len = input.size(0) + bsz = input.size(1) + + fwd_out, fwd_hiddens = list(self.fwd(input, collect_hidden = collect_hidden)) + bckwrd_out, bckwrd_hiddens = list(self.bckwrd(input, reverse=True, collect_hidden = collect_hidden)) + + output = torch.cat( [fwd_out, bckwrd_out], -1 ) + hiddens = tuple( torch.cat(hidden, -1) for hidden in zip( fwd_hiddens, bckwrd_hiddens) ) + + return output, hiddens + + def reset_parameters(self): + """ + reset_parameters() + """ + for rnn in self.rnns: + rnn.reset_parameters() + + def init_hidden(self, bsz): + """ + init_hidden() + """ + for rnn in self.rnns: + rnn.init_hidden(bsz) + + def detach_hidden(self): + """ + detach_hidden() + """ + for rnn in self.rnns: + rnn.detachHidden() + + def reset_hidden(self, bsz): + """ + reset_hidden() + """ + for rnn in self.rnns: + rnn.reset_hidden(bsz) + + def init_inference(self, bsz): + """ + init_inference() + """ + for rnn in self.rnns: + rnn.init_inference(bsz) + + +#assumes hidden_state[0] of inputRNN is output hidden state +#constructor either takes an RNNCell or list of RNN layers +class stackedRNN(nn.Module): + """ + stackedRNN + """ + def __init__(self, inputRNN, num_layers=1, dropout=0): + super(stackedRNN, self).__init__() + + self.dropout = dropout + + if isinstance(inputRNN, RNNCell): + self.rnns = [inputRNN] + for i in range(num_layers-1): + self.rnns.append(inputRNN.new_like(inputRNN.output_size)) + elif isinstance(inputRNN, list): + assert len(inputRNN) == num_layers, "RNN list length must be equal to num_layers" + self.rnns=inputRNN + else: + raise RuntimeError() + + self.nLayers = len(self.rnns) + + self.rnns = nn.ModuleList(self.rnns) + + + ''' + Returns output as hidden_state[0] Tensor([sequence steps][batch size][features]) + If collect hidden will also return Tuple( + [n_hidden_states][sequence steps] Tensor([layer][batch size][features]) + ) + If not collect hidden will also return Tuple( + [n_hidden_states] Tensor([layer][batch size][features]) + ''' + def forward(self, input, collect_hidden=False, reverse=False): + """ + forward() + """ + seq_len = input.size(0) + bsz = input.size(1) + inp_iter = reversed(range(seq_len)) if reverse else range(seq_len) + + hidden_states = [[] for i in range(self.nLayers)] + outputs = [] + + for seq in inp_iter: + for layer in range(self.nLayers): + + if layer == 0: + prev_out = input[seq] + + outs = self.rnns[layer](prev_out) + + if collect_hidden: + hidden_states[layer].append(outs) + elif seq == seq_len-1: + hidden_states[layer].append(outs) + + prev_out = outs[0] + + outputs.append(prev_out) + + if reverse: + outputs = list(reversed(outputs)) + ''' + At this point outputs is in format: + list( [seq_length] x Tensor([bsz][features]) ) + need to convert it to: + list( Tensor([seq_length][bsz][features]) ) + ''' + output = flatten_list(outputs) + + ''' + hidden_states at this point is in format: + list( [layer][seq_length][hidden_states] x Tensor([bsz][features]) ) + need to convert it to: + For not collect hidden: + list( [hidden_states] x Tensor([layer][bsz][features]) ) + For collect hidden: + list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) ) + ''' + if not collect_hidden: + seq_len = 1 + n_hid = self.rnns[0].n_hidden_states + new_hidden = [ [ [ None for k in range(self.nLayers)] for j in range(seq_len) ] for i in range(n_hid) ] + + + for i in range(n_hid): + for j in range(seq_len): + for k in range(self.nLayers): + new_hidden[i][j][k] = hidden_states[k][j][i] + + hidden_states = new_hidden + #Now in format list( [hidden_states][seq_length][layer] x Tensor([bsz][features]) ) + #Reverse seq_length if reverse + if reverse: + hidden_states = list( list(reversed(list(entry))) for entry in hidden_states) + + #flatten layer dimension into tensor + hiddens = list( list( + flatten_list(seq) for seq in hidden ) + for hidden in hidden_states ) + + #Now in format list( [hidden_states][seq_length] x Tensor([layer][bsz][features]) ) + #Remove seq_length dimension if not collect_hidden + if not collect_hidden: + hidden_states = list( entry[0] for entry in hidden_states) + return output, hidden_states + + def reset_parameters(self): + """ + reset_parameters() + """ + for rnn in self.rnns: + rnn.reset_parameters() + + def init_hidden(self, bsz): + """ + init_hidden() + """ + for rnn in self.rnns: + rnn.init_hidden(bsz) + + def detach_hidden(self): + """ + detach_hidden() + """ + for rnn in self.rnns: + rnn.detach_hidden() + + def reset_hidden(self, bsz): + """ + reset_hidden() + """ + for rnn in self.rnns: + rnn.reset_hidden(bsz) + + def init_inference(self, bsz): + """ + init_inference() + """ + for rnn in self.rnns: + rnn.init_inference(bsz) + +class RNNCell(nn.Module): + """ + RNNCell + gate_multiplier is related to the architecture you're working with + For LSTM-like it will be 4 and GRU-like will be 3. + Always assumes input is NOT batch_first. + Output size that's not hidden size will use output projection + Hidden_states is number of hidden states that are needed for cell + if one will go directly to cell as tensor, if more will go as list + """ + def __init__(self, gate_multiplier, input_size, hidden_size, cell, n_hidden_states = 2, bias = False, output_size = None): + super(RNNCell, self).__init__() + + self.gate_multiplier = gate_multiplier + self.input_size = input_size + self.hidden_size = hidden_size + self.cell = cell + self.bias = bias + self.output_size = output_size + if output_size is None: + self.output_size = hidden_size + + self.gate_size = gate_multiplier * self.hidden_size + self.n_hidden_states = n_hidden_states + + self.w_ih = nn.Parameter(torch.Tensor(self.gate_size, self.input_size)) + self.w_hh = nn.Parameter(torch.Tensor(self.gate_size, self.output_size)) + + #Check if there's recurrent projection + if(self.output_size != self.hidden_size): + self.w_ho = nn.Parameter(torch.Tensor(self.output_size, self.hidden_size)) + + self.b_ih = self.b_hh = None + if self.bias: + self.b_ih = nn.Parameter(torch.Tensor(self.gate_size)) + self.b_hh = nn.Parameter(torch.Tensor(self.gate_size)) + + #hidden states for forward + self.hidden = [ None for states in range(self.n_hidden_states)] + + self.reset_parameters() + + def new_like(self, new_input_size=None): + """ + new_like() + """ + if new_input_size is None: + new_input_size = self.input_size + + return type(self)(self.gate_multiplier, + new_input_size, + self.hidden_size, + self.cell, + self.n_hidden_states, + self.bias, + self.output_size) + + + #Use xavier where we can (weights), otherwise use uniform (bias) + def reset_parameters(self, gain=1): + """ + reset_parameters() + """ + stdev = 1.0 / math.sqrt(self.hidden_size) + for param in self.parameters(): + param.data.uniform_(-stdev, stdev) + ''' + Xavier reset: + def reset_parameters(self, gain=1): + stdv = 1.0 / math.sqrt(self.gate_size) + + for param in self.parameters(): + if (param.dim() > 1): + torch.nn.init.xavier_normal(param, gain) + else: + param.data.uniform_(-stdv, stdv) + ''' + def init_hidden(self, bsz): + """ + init_hidden() + """ + for param in self.parameters(): + if param is not None: + a_param = param + break + + for i, _ in enumerate(self.hidden): + if(self.hidden[i] is None or self.hidden[i].data.size()[0] != bsz): + + if i==0: + hidden_size = self.output_size + else: + hidden_size = self.hidden_size + + tens = a_param.data.new(bsz, hidden_size).zero_() + self.hidden[i] = Variable(tens, requires_grad=False) + + + def reset_hidden(self, bsz): + """ + reset_hidden() + """ + for i, _ in enumerate(self.hidden): + self.hidden[i] = None + self.init_hidden(bsz) + + def detach_hidden(self): + """ + detach_hidden() + """ + for i, _ in enumerate(self.hidden): + if self.hidden[i] is None: + raise RuntimeError("Must inialize hidden state before you can detach it") + for i, _ in enumerate(self.hidden): + self.hidden[i] = self.hidden[i].detach() + + def forward(self, input): + """ + forward() + if not inited or bsz has changed this will create hidden states + """ + self.init_hidden(input.size()[0]) + + hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden + self.hidden = self.cell(input, hidden_state, self.w_ih, self.w_hh, b_ih=self.b_ih, b_hh=self.b_hh) + if(self.n_hidden_states > 1): + self.hidden = list(self.hidden) + else: + self.hidden=[self.hidden] + + if self.output_size != self.hidden_size: + self.hidden[0] = F.linear(self.hidden[0], self.w_ho) + + return tuple(self.hidden) diff --git a/apex/RNN/__init__.py b/apex/RNN/__init__.py new file mode 100644 index 000000000..d70674666 --- /dev/null +++ b/apex/RNN/__init__.py @@ -0,0 +1,3 @@ +from .models import LSTM, GRU, ReLU, Tanh, mLSTM + +__all__ = ['models'] diff --git a/apex/RNN/cells.py b/apex/RNN/cells.py new file mode 100644 index 000000000..bf7bffdd6 --- /dev/null +++ b/apex/RNN/cells.py @@ -0,0 +1,84 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .RNNBackend import RNNCell + +from torch.nn._functions.thnn import rnnFusedPointwise as fusedBackend + +import math + + +class mLSTMRNNCell(RNNCell): + """ + mLSTMRNNCell + """ + + def __init__(self, input_size, hidden_size, bias = False, output_size = None): + gate_multiplier = 4 + super(mLSTMRNNCell, self).__init__(gate_multiplier, input_size, hidden_size, mLSTMCell, n_hidden_states = 2, bias = bias, output_size = output_size) + + self.w_mih = nn.Parameter(torch.Tensor(self.output_size, self.input_size)) + self.w_mhh = nn.Parameter(torch.Tensor(self.output_size, self.output_size)) + + self.reset_parameters() + + def forward(self, input): + """ + mLSTMRNNCell.forward() + """ + #if not inited or bsz has changed this will create hidden states + self.init_hidden(input.size()[0]) + + hidden_state = self.hidden[0] if self.n_hidden_states == 1 else self.hidden + + self.hidden = list( + self.cell(input, hidden_state, self.w_ih, self.w_hh, self.w_mih, self.w_mhh, + b_ih=self.b_ih, b_hh=self.b_hh) + ) + + if self.output_size != self.hidden_size: + self.hidden[0] = F.linear(self.hidden[0], self.w_ho) + return tuple(self.hidden) + + + def new_like(self, new_input_size=None): + if new_input_size is None: + new_input_size = self.input_size + + return type(self)( + new_input_size, + self.hidden_size, + self.bias, + self.output_size) + +def mLSTMCell(input, hidden, w_ih, w_hh, w_mih, w_mhh, b_ih=None, b_hh=None): + """ + mLSTMCell + """ + + if input.is_cuda: + igates = F.linear(input, w_ih) + m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh) + hgates = F.linear(m, w_hh) + + state = fusedBackend.LSTMFused.apply + return state(igates, hgates, hidden[1], b_ih, b_hh) + + hx, cx = hidden + + m = F.linear(input, w_mih) * F.linear(hidden[0], w_mhh) + igates = F.linear(input, w_ih, b_ih) + F.linear(m, w_hh, b_hh) + + ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) + + ingate = F.sigmoid(ingate) + forgetgate = F.sigmoid(forgetgate) + cellgate = F.tanh(cellgate) + outgate = F.sigmoid(outgate) + + cy = (forgetgate * cx) + (ingate * cellgate) + hy = outgate * F.tanh(cy) + + return hy, cy + diff --git a/apex/RNN/models.py b/apex/RNN/models.py new file mode 100644 index 000000000..dd7adce04 --- /dev/null +++ b/apex/RNN/models.py @@ -0,0 +1,54 @@ +import torch + +from torch.nn._functions.rnn import LSTMCell, RNNReLUCell, RNNTanhCell, GRUCell + +from .RNNBackend import bidirectionalRNN, stackedRNN, RNNCell +from .cells import mLSTMRNNCell, mLSTMCell + +def toRNNBackend(inputRNN, num_layers, bidirectional=False, dropout = 0): + """ + :class:`toRNNBackend` + """ + + if bidirectional: + return bidirectionalRNN(inputRNN, num_layers, dropout = dropout) + else: + return stackedRNN(inputRNN, num_layers, dropout = dropout) + + +def LSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): + """ + :class:`LSTM` + """ + inputRNN = RNNCell(4, input_size, hidden_size, LSTMCell, 2, bias, output_size) + return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) + +def GRU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): + """ + :class:`GRU` + """ + inputRNN = RNNCell(3, input_size, hidden_size, GRUCell, 1, bias, output_size) + return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) + +def ReLU(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): + """ + :class:`ReLU` + """ + inputRNN = RNNCell(1, input_size, hidden_size, RNNReLUCell, 1, bias, output_size) + return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) + +def Tanh(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): + """ + :class:`Tanh` + """ + inputRNN = RNNCell(1, input_size, hidden_size, RNNTanhCell, 1, bias, output_size) + return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) + +def mLSTM(input_size, hidden_size, num_layers, bias=True, batch_first=False, dropout=0, bidirectional=False, output_size = None): + """ + :class:`mLSTM` + """ + inputRNN = mLSTMRNNCell(input_size, hidden_size, bias=bias, output_size=output_size) + return toRNNBackend(inputRNN, num_layers, bidirectional, dropout=dropout) + + diff --git a/apex/__init__.py b/apex/__init__.py new file mode 100644 index 000000000..7e46d0b09 --- /dev/null +++ b/apex/__init__.py @@ -0,0 +1,4 @@ +from . import RNN +from . import reparameterization +from . import fp16_utils +from . import parallel diff --git a/apex/fp16_utils/__init__.py b/apex/fp16_utils/__init__.py new file mode 100644 index 000000000..d62b935a0 --- /dev/null +++ b/apex/fp16_utils/__init__.py @@ -0,0 +1,17 @@ +from .fp16util import ( + BN_convert_float, + network_to_half, + prep_param_lists, + model_grads_to_master_grads, + master_params_to_model_params, + tofp16, +) + + +from .fused_weight_norm import Fused_Weight_Norm + + +from .fp16_optimizer import fp32_to_fp16, fp16_to_fp32, FP16_Module, FP16_Optimizer + + +from .loss_scaler import LossScaler, DynamicLossScaler diff --git a/apex/fp16_utils/fp16_optimizer.py b/apex/fp16_utils/fp16_optimizer.py new file mode 100755 index 000000000..f4e5e2d29 --- /dev/null +++ b/apex/fp16_utils/fp16_optimizer.py @@ -0,0 +1,552 @@ +import torch +from torch import nn +from torch.autograd import Variable +from torch.nn.parameter import Parameter +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +from .loss_scaler import DynamicLossScaler, LossScaler +from .fp16util import model_grads_to_master_grads, master_params_to_model_params + +FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor) +HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor) + +def conversion_helper(val, conversion): + """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure.""" + if not isinstance(val, (tuple, list)): + return conversion(val) + rtn = [conversion_helper(v, conversion) for v in val] + if isinstance(val, tuple): + rtn = tuple(rtn) + return rtn + +def fp32_to_fp16(val): + """Convert fp32 `val` to fp16""" + def half_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, FLOAT_TYPES): + val = val.half() + return val + return conversion_helper(val, half_conversion) + +def fp16_to_fp32(val): + """Convert fp16 `val` to fp32""" + def float_conversion(val): + val_typecheck = val + if isinstance(val_typecheck, (Parameter, Variable)): + val_typecheck = val.data + if isinstance(val_typecheck, HALF_TYPES): + val = val.float() + return val + return conversion_helper(val, float_conversion) + +class FP16_Module(nn.Module): + def __init__(self, module): + super(FP16_Module, self).__init__() + self.add_module('module', module.half()) + + def forward(self, *inputs, **kwargs): + return fp16_to_fp32(self.module(*(fp32_to_fp16(inputs)), **kwargs)) + +class FP16_Optimizer(object): + """ + :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer, + and manage (dynamic) loss scaling and master weights in a manner transparent to the user. + For standard use, only two lines must be changed: creating the :class:`FP16_Optimizer` instance, + and changing the call to ``backward``. + + Example:: + + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + # Name the FP16_Optimizer instance to replace the existing optimizer + # (recommended but not required): + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + # loss.backward() becomes: + optimizer.backward(loss) + ... + + Example with dynamic loss scaling:: + + ... + optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True, + dynamic_loss_args={'scale_window' : 500}) + # dynamic_loss_args is optional. + + Args: + init_optimizer (torch.optim.optimizer): Existing optimizer containing initialized fp16 parameters. Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters with new fp32 parameters copied from the original ones. :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy after each step. + static_loss_scale (float, optional, default=1.0): Loss scale used internally to scale fp16 gradients computed by the model. Scaled gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate. + dynamic_loss_scale (bool, optional, default=False): Use dynamic loss scaling. If True, this will override any ``static_loss_scale`` option. + dynamic_loss_args (dict, optional, default=None): Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor. Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor. If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used. + + ``init_optimizer`` is expected to have been constructed in the ordinary way. + It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be + named to replace ``init_optimizer``, for two reasons: + First, it means that references to the same name + later in the file will not have to change. + Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to + modify ``init_optimizer``. If you do choose a unique name for the new + :class:`FP16_Optimizer` instance, you should only work with this new instance, + because the preexisting optimizer might no longer behave as expected. + + ``init_optimizer`` may be any Pytorch optimizer. + It may contain a mixture of fp16 and fp32 parameters organized into any number of + ``param_groups`` with different hyperparameters. The :class:`FP16_Optimizer` constructor will + ingest these ``param_groups`` and remember them. + + Calls to :: + + loss.backward() + + must be replaced with :: + + optimizer.backward(loss) + + because :class:`FP16_Optimizer` requires ownership of the backward pass to implement + loss scaling and copies to master gradients. + + .. note:: + Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients + are downscaled before being applied. This means that adjusting the loss scale, or using + dynamic loss scaling, should not require retuning the learning rate or any other + hyperparameters. + + + **Advanced options** + + **Closures**: :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure. + See docstring for :attr:`step`. + + **Gradient clipping**: Use :attr:`clip_master_grads`. + + **Multiple losses**: If your model accumulates gradients from multiple losses, + this can be made more efficient by supplying ``update_master_grads=False`` + to :attr:`backward`. See docstring for :attr:`backward`. + + **Manually adjusting loss scale**: The current loss scale can be retrieved or set via :: + + print(optimizer.loss_scale) + optimizer.loss_scale = new_loss_scale + + For static loss scaling, manually adjusting the loss scale over time is a reasonable + thing to do. During later epochs, gradients may become smaller, and a + higher loss scale may be required, analogous to scheduling the learning rate. Dynamic loss + scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting + the loss scale is not recommended. + + **Multi_GPU training**: If the wrapped ``init_optimizer`` was created from a model wrapped in + Pytorch DataParallel or DistributedDataParallel, :class:`FP16_Optimizer` should still work as + intended. + """ + + def __init__(self, + init_optimizer, + static_loss_scale=1.0, + dynamic_loss_scale=False, + dynamic_loss_args=None): + if not torch.cuda.is_available: + raise SystemError("Cannot use fp16 without CUDA.") + + self.fp16_groups = [] + self.fp32_from_fp16_groups = [] + self.fp32_from_fp32_groups = [] + for i, param_group in enumerate(init_optimizer.param_groups): + print("FP16_Optimizer processing param group {}:".format(i)) + fp16_params_this_group = [] + fp32_params_this_group = [] + for param in param_group['params']: + if param.requires_grad: + if param.type() == 'torch.cuda.HalfTensor': + print("FP16_Optimizer received torch.cuda.HalfTensor with {}" + .format(param.size())) + fp16_params_this_group.append(param) + elif param.type() == 'torch.cuda.FloatTensor': + print("FP16_Optimizer received torch.cuda.FloatTensor with {}" + .format(param.size())) + fp32_params_this_group.append(param) + else: + raise TypeError("Wrapped parameters must be either " + "torch.cuda.FloatTensor or torch.cuda.HalfTensor. " + "Received {}".format(param.type())) + + fp32_from_fp16_params_this_group = [param.detach().clone().float() + for param in fp16_params_this_group] + for param in fp32_from_fp16_params_this_group: + param.requires_grad = True + + param_group['params'] = fp32_from_fp16_params_this_group + fp32_params_this_group + + self.fp16_groups.append(fp16_params_this_group) + self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group) + self.fp32_from_fp32_groups.append(fp32_params_this_group) + + self.optimizer = init_optimizer.__class__(init_optimizer.param_groups) + + if dynamic_loss_scale: + self.dynamic_loss_scale = True + if dynamic_loss_args is not None: + self.loss_scaler = DynamicLossScaler(**dynamic_loss_args) + else: + self.loss_scaler = DynamicLossScaler() + else: + self.dynamic_loss_scale = False + self.loss_scaler = LossScaler(static_loss_scale) + + self.overflow = False + self.first_closure_call_this_step = True + + # Promote optimizer.state, and optimizer.param_groups, to accommodate user code that + # directly manipulates "optimizer.param_groups" (for example, to adjust the learning rate). + def __getattribute__(self, name): + # I could condense the two cases by saying + # if name in ['state', 'param_groups']: + # return self.optimizer.__dict__[name], + # but this would bypass self.optimizer's custom getters and setters, if it chose to define any. + # I could also use properties, as for loss_scale, but I have no idea if properties bypass + # self.optimizer's custom getters and setters. + if name == 'state': + return self.optimizer.state + elif name == 'param_groups': + return self.optimizer.param_groups + else: + return object.__getattribute__(self, name) + + def __setattr__(self, name, value): + if name == 'state': + self.optimizer.state = value + elif name == 'param_groups': + self.optimizer.param_groups = value + else: + object.__setattr__(self, name, value) + + def zero_grad(self): + """ + Zero fp32 and fp16 parameter grads. + """ + self.optimizer.zero_grad() + for fp16_group in self.fp16_groups: + for param in fp16_group: + if param.grad is not None: + param.grad.detach_() # as in torch.optim.optimizer.zero_grad() + param.grad.zero_() + + def _check_overflow(self): + params = [] + for group in self.fp16_groups: + for param in group: + params.append(param) + for group in self.fp32_from_fp32_groups: + for param in group: + params.append(param) + self.overflow = self.loss_scaler.has_overflow(params) + + def _update_scale(self, has_overflow=False): + self.loss_scaler.update_scale(has_overflow) + + # TODO: Register a hook on each variable to do the overflow check, gradient copy + downscale, + # fp32 allreduce for distributed in a different stream. Debatable which ops should be + # treated that way, but it'll be fun to play with. + def _model_grads_to_master_grads(self): + for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): + model_grads_to_master_grads(fp16_group, fp32_from_fp16_group) + + def _downscale_master(self): + if self.loss_scale != 1.0: + # print("downscaling fp32 gradients") + for group in self.optimizer.param_groups: + for param in group['params']: + param.grad.data.mul_(1./self.loss_scale) + + def clip_master_grads(self, max_norm, norm_type=2): + """ + Clips fp32 master gradients via torch.nn.utils.clip_grad_norm. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the current fp32 gradients (viewed as a single vector). + + .. warning:: + Returns -1 if the most recently computed fp16 gradients overflowed (that is, if self.overflow is True). + """ + if not self.overflow: + fp32_params = [] + for param_group in self.optimizer.param_groups: + for param in param_group['params']: + fp32_params.append(param) + return torch.nn.utils.clip_grad_norm(fp32_params, max_norm, norm_type) + else: + return -1 + + def _master_params_to_model_params(self): + for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups): + master_params_to_model_params(fp16_group, fp32_from_fp16_group) + + def state_dict(self): + """ + Returns a dict containing the current state of this :class:`FP16_Optimizer` instance. + This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict + of the contained Pytorch optimizer. + Example:: + + checkpoint = {} + checkpoint['model'] = model.state_dict() + checkpoint['optimizer'] = optimizer.state_dict() + torch.save(checkpoint, "saved.pth") + """ + state_dict = {} + state_dict['loss_scaler'] = self.loss_scaler + state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale + state_dict['overflow'] = self.overflow + state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step + state_dict['optimizer_state_dict'] = self.optimizer.state_dict() + state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups + return state_dict + + def load_state_dict(self, state_dict): + """ + Loads a state_dict created by an earlier call to state_dict(). + If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``, + whose parameters in turn came from ``model``, it is expected that the user + will call ``model.load_state_dict()`` before + ``fp16_optimizer_instance.load_state_dict()`` is called. + + Example:: + + model = torch.nn.Linear(D_in, D_out).cuda().half() + optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) + optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0) + ... + checkpoint = torch.load("saved.pth") + model.load_state_dict(checkpoint['model']) + optimizer.load_state_dict(checkpoint['optimizer']) + """ + # I think it should actually be ok to reload the optimizer before the model. + self.loss_scaler = state_dict['loss_scaler'] + self.dynamic_loss_scale = state_dict['dynamic_loss_scale'] + self.overflow = state_dict['overflow'] + self.first_closure_call_this_step = state_dict['first_closure_call_this_step'] + self.optimizer.load_state_dict(state_dict['optimizer_state_dict']) + # At this point, the optimizer's references to the model's fp32 parameters are up to date. + # The optimizer's hyperparameters and internal buffers are also up to date. + # However, the fp32 master copies of the model's fp16 params stored by the optimizer are now + # out of date. There are two options. + # 1: Refresh the master params from the model's fp16 params. + # This requires less storage but incurs precision loss. + # 2: Save and restore the fp32 master copies separately. + # We choose option 2. + # + # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device + # of # their associated parameters, because it's possible those buffers might not exist yet in + # the current optimizer instance. In our case, as long as the current FP16_Optimizer has been + # constructed in the same way as the one whose state_dict we are loading, the same master params + # are guaranteed to exist, so we can just copy_() from the saved master params. + for current_group, saved_group in zip(self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']): + for current, saved in zip(current_group, saved_group): + current.data.copy_(saved.data) + + def step(self, closure=None): # could add clip option. + """ + If no closure is supplied, step should be called after ``fp16_optimizer_obj.backward(loss)``. + step updates the fp32 master copy of parameters using the optimizer supplied to + :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params + originally referenced by Fp16_Optimizer's constructor, so the user may immediately run + another forward pass using their model. + + If a closure is supplied, :attr:`step` may be called without a prior call to + :attr:`backward(loss)`. + This control flow is identical to `ordinary Pytorch optimizer use`_ with closures. + However, the user should take care that any ``loss.backward()`` call within the closure + has been replaced by ``fp16_optimizer_obj.backward(loss)``. + + Args: + closure (optional): Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor. closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss. + + Example with closure:: + + # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an + # existing pytorch optimizer. + for input, target in dataset: + def closure(): + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + optimizer.backward(loss) + return loss + optimizer.step(closure) + + .. warning:: + Currently, calling step with a closure is not compatible with dynamic loss scaling. + + .. _`ordinary Pytorch optimizer use`: + http://pytorch.org/docs/master/optim.html#optimizer-step-closure + """ + if closure is not None and isinstance(self.loss_scaler, DynamicLossScaler): + raise TypeError("Using step with a closure is currently not " + "compatible with dynamic loss scaling.") + + scale = self.loss_scaler.loss_scale + self._update_scale(self.overflow) + + if self.overflow: + print("OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}" + .format(scale, self.loss_scale)) + return + + if closure is not None: + self._step_with_closure(closure) + else: + self.optimizer.step() + + self._master_params_to_model_params() + + return + + def _step_with_closure(self, closure): + def wrapped_closure(): + if self.first_closure_call_this_step: + # We expect that the fp16 params are initially fresh on entering self.step(), + # so _master_params_to_model_params() is unnecessary the first time wrapped_closure() + # is called within self.optimizer.step(). + self.first_closure_call_this_step = False + else: + # If self.optimizer.step() internally calls wrapped_closure more than once, + # it may update the fp32 params after each call. However, self.optimizer + # doesn't know about the fp16 params at all. If the fp32 params get updated, + # we can't rely on self.optimizer to refresh the fp16 params. We need + # to handle that manually: + self._master_params_to_model_params() + # Our API expects the user to give us ownership of the backward() call by + # replacing all calls to loss.backward() with optimizer.backward(loss). + # This requirement holds whether or not the call to backward() is made within + # a closure. + # If the user is properly calling optimizer.backward(loss) within "closure," + # calling closure() here will give the fp32 master params fresh gradients + # for the optimizer to play with, + # so all wrapped_closure needs to do is call closure() and return the loss. + temp_loss = closure() + return temp_loss + + self.optimizer.step(wrapped_closure) + + self.first_closure_call_this_step = True + + def backward(self, loss, update_master_grads=True): + """ + :attr:`backward` performs the following conceptual operations: + + fp32_loss = loss.float() (see first Note below) + + scaled_loss = fp32_loss*loss_scale + + scaled_loss.backward(), which accumulates scaled gradients into the .grad attributes of the + model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined). + + fp16 grads are then copied to the master params' .grad attributes (see second Note), which + are guaranteed to be fp32. + + Finally, master grads are divided by loss_scale. + + In this way, after :attr:`backward`, the master params have fresh gradients, + and :attr:`step` may be called. + + .. note:: + :attr:`backward` internally converts the loss to fp32 before applying the loss scale. + This provides some additional safety against overflow if the user has supplied an + fp16 loss value. + However, for maximum overflow safety, the user should + compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to + :attr:`backward`. + + .. note:: + The gradients found in a model's leaves after the call to + `backward` should not be regarded as valid in general, + because it's possible + they have been scaled (and in the case of dynamic loss scaling, + the scale factor may change over time). + If the user wants to inspect gradients after a call to `backward`, + only the master gradients should be regarded as valid. These can be retrieved via + :attr:`inspect_master_grad_data()`. + + + Args: + loss: The loss output by the user's model. loss may be either float or half (but see first Note above). + update_master_grads (bool, optional, default=True): Option to copy fp16 grads to fp32 grads on this call. By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if fp16_optimizer_obj.backward is being called on multiple losses in one iteration. If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`. + + Example:: + + # Ordinary operation: + optimizer.backward(loss) + + # Naive operation with multiple losses (technically valid, but less efficient): + # fp32 grads will be correct after the second call, but + # the first call incurs an unnecessary fp16->fp32 grad copy. + optimizer.backward(loss1) + optimizer.backward(loss2) + + # More efficient way to handle multiple losses: + # The fp16->fp32 grad copy is delayed until fp16 grads from all + # losses have been accumulated. + optimizer.backward(loss1, update_master_grads=False) + optimizer.backward(loss2, update_master_grads=False) + optimizer.update_master_grads() + """ + # To think about: try multiple backward passes using retain_grad=True to find + # a loss scale that works. After you find a loss scale that works, do a final dummy + # backward pass with retain_graph=False to tear down the graph. + # Doing this would avoid discarding the iteration, but probably wouldn't + # improve overall efficiency. + self.loss_scaler.backward(loss.float()) + if update_master_grads: + self.update_master_grads() + + def update_master_grads(self): + """ + Copy the ``.grad`` attribute from stored references to fp16 parameters to + the ``.grad`` attribute of the fp32 master parameters that are directly + updated by the optimizer. :attr:`update_master_grads` only needs to be called if + ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``. + """ + if self.dynamic_loss_scale: + self._check_overflow() + if self.overflow: return + self._model_grads_to_master_grads() + self._downscale_master() + + def inspect_master_grad_data(self): + """ + When running with :class:`FP16_Optimizer`, + ``.grad`` attributes of a model's fp16 leaves should not be + regarded as truthful, because they might be scaled. + After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered, + the fp32 master params' ``.grad`` + attributes will contain valid gradients properly divided by the loss scale. However, + because :class:`FP16_Optimizer` flattens some parameters, accessing them may be + nonintuitive. :attr:`inspect_master_grad_data` + allows those gradients to be viewed with shapes corresponding to their associated model leaves. + + Returns: + List of lists (one list for each parameter group). The list for each parameter group + is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group. + """ + raise NotImplementedError("Currently not implemented, working on it...") + fp32_grads_each_group = [] + if self.overflow: + print("Warning: calling FP16_Optimizer.inspect_master_grad_data while in an overflow state. " + "Gradients are currently invalid (may be inf, nan, or stale). Returning None.") + return None + else: + return None + + def _get_loss_scale(self): + return self.loss_scaler.loss_scale + + def _set_loss_scale(self, value): + self.loss_scaler.cur_scale = value + + loss_scale = property(_get_loss_scale, _set_loss_scale) diff --git a/apex/fp16_utils/fp16util.py b/apex/fp16_utils/fp16util.py new file mode 100644 index 000000000..ae7cfb184 --- /dev/null +++ b/apex/fp16_utils/fp16util.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors + +class tofp16(nn.Module): + """ + Model wrapper that implements:: + + def forward(self, input): + return input.half() + """ + + def __init__(self): + super(tofp16, self).__init__() + + def forward(self, input): + return input.half() + + +def BN_convert_float(module): + ''' + Designed to work with network_to_half. + BatchNorm layers need parameters in single precision. + Find all layers and convert them back to float. This can't + be done with built in .apply as that function will apply + fn to all modules, parameters, and buffers. Thus we wouldn't + be able to guard the float conversion based on the module type. + ''' + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + module.float() + for child in module.children(): + BN_convert_float(child) + return module + + +def network_to_half(network): + """ + Convert model to half precision in a batchnorm-safe way. + """ + return nn.Sequential(tofp16(), BN_convert_float(network.half())) + + +def backwards_debug_hook(grad): + raise RuntimeError("master_params recieved a gradient in the backward pass!") + +def prep_param_lists(model, flat_master=False): + r""" + Creates a list of FP32 master parameters for a given model, as in + `Training Neural Networks with Mixed Precision: Real Examples`_. + + Args: + model (torch.nn.Module): Existing Pytorch model + flat_master (bool, optional, default=False): Flatten the master parameters into a single tensor, as a performance optimization. + Returns: + A tuple (``model_params``, ``master_params``). ``model_params`` is a list of the model's parameters for later use with :func:`model_grads_to_master_grads` and :func:`master_params_to_model_params`. ``master_params`` is a list of FP32 master gradients. If ``flat_master=True``, ``master_params`` will be a list with one element. + + Example:: + + model_params, master_params = prep_param_lists(model) + + .. warning:: + Currently, if ``flat_master=True``, all the model's parameters must be the same type. If the model has parameters of different types, use ``flat_master=False``, or use :class:`FP16_Optimizer`. + + .. _`Training Neural Networks with Mixed Precision: Real Examples`: + http://on-demand.gputechconf.com/gtc/2018/video/S81012/ + """ + model_params = [param for param in model.parameters() if param.requires_grad] + + if flat_master: + # flatten_dense_tensors returns a contiguous flat array. + # http://pytorch.org/docs/master/_modules/torch/_utils.html + master_params = _flatten_dense_tensors([param.data for param in model_params]).float() + master_params = torch.nn.Parameter(master_params) + master_params.requires_grad = True + # master_params.register_hook(backwards_debug_hook) + if master_params.grad is None: + master_params.grad = master_params.new(*master_params.size()) + return model_params, [master_params] + else: + master_params = [param.detach().clone().float() for param in model_params] + for param in master_params: + param.requires_grad = True + return model_params, master_params + + +def model_grads_to_master_grads(model_params, master_params, flat_master=False): + """ + Copy model gradients to master gradients. + + Args: + model_params: List of model parameters created by :func:`prep_param_lists`. + master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`model_grads_to_master_grads`. + """ + if flat_master: + # The flattening may incur one more deep copy than is necessary. + master_params[0].grad.data.copy_( + _flatten_dense_tensors([p.grad.data for p in model_params])) + else: + for model, master in zip(model_params, master_params): + if model.grad is not None: + if master.grad is None: + master.grad = Variable(master.data.new(*master.data.size())) + master.grad.data.copy_(model.grad.data) + else: + master.grad = None + + +def master_params_to_model_params(model_params, master_params, flat_master=False): + """ + Copy master parameters to model parameters. + + Args: + model_params: List of model parameters created by :func:`prep_param_lists`. + master_params: List of FP32 master parameters created by :func:`prep_param_lists`. If ``master_params`` was created with ``flat_master=True``, ``flat_master=True`` should also be supplied to :func:`master_params_to_model_params`. + """ + if flat_master: + for model, master in zip(model_params, + _unflatten_dense_tensors(master_params[0].data, model_params)): + model.data.copy_(master) + else: + for model, master in zip(model_params, master_params): + model.data.copy_(master.data) diff --git a/apex/fp16_utils/fused_weight_norm.py b/apex/fp16_utils/fused_weight_norm.py new file mode 100644 index 000000000..8cee0c617 --- /dev/null +++ b/apex/fp16_utils/fused_weight_norm.py @@ -0,0 +1,105 @@ +import torch +from torch.autograd import Variable +from torch.autograd.function import Function, once_differentiable +import apex._C + +def check_contig_cuda(tensors, names): + for tensor, name in zip(tensors, names): + if not tensor.is_contiguous(): + raise RuntimeError(name+" with size {} is not contiguous" + .format(tensor.size())) + if not tensor.is_cuda: + raise RuntimeError(name+".is_cuda = False." + "Currently, only cuda tensors are supported.") + +class Fused_Weight_Norm(Function): + """ + Implements weight norm along a tensor's slowest dimension using fused kernel launches for + the forward and backward pass. + Accepts fp32 or fp16 input; the output type will match the input type. + Within the kernels, all calculations are performed in fp32 for numerical stability, regardless + of input/output precision. + """ + + @staticmethod + def forward(ctx, input, g, dim=0): + """ + :attr:`input` is assumed to be contiguous. + :attr:`input` may be either float or half precision. + The precision of :attr:`output` will match the precision of :attr:`input`. + A float copy of the L2 norm across each slow dimension + is also created and saved for the backward pass. + """ + # torch.cuda.nvtx.range_push("FusedNorm.forward, input.size() = {}" + # .format(input.size())) + + check_contig_cuda((input,g),("input","g")) + + """ + This is ok, new() treats a torch.Size object properly. + No need to unpack with an asterisk via new(*input.size()). + """ + output = input.new(input.size()).contiguous() + + """ + For output with size (slow, faster, faster, ...fastest), we may want + norms with size (slow, 1, 1, ...1), so that if you want retrieve norms + and apply the same normalizing factors to another Tensor "t" with the + same size as output, "t/norms" will broadcast each element of norms + across the corresponding slowest dim of t. + """ + if dim == 0: + norm_size = (output.size(0),) + (1,)*(output.dim() - 1) + elif dim == output.dim() - 1: + norm_size = (1,)*(output.dim() - 1) + (output.size(-1),) + else: + raise RuntimeError("Currently, Fused_Weight_Norm only supports first or last dimension.") + + norms = torch.cuda.FloatTensor(*norm_size).contiguous() + """ + Beware: If you call the following: + norms = torch.cuda.FloatTensor(norm_size).contiguous() + the constructor sees a tuple: + FloatTensor( (output_size(0),1,1,...) ) + and creates a 1D tensor with values from the tuple: + [output_size(0),1,1,...]. + """ + + apex._C.weight_norm_fwd(output, norms, input, g, dim) + ctx.save_for_backward(input, g) + + # save_for_backward can only save input or output tensors, + # use ctx state to save the norms and dimension: + ctx.norms = norms + ctx.dim = dim + + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + """ + :attr:`grad_output` is assumed to be contiguous. + :attr:`grad_output` may be either float or half precision. + The precision of :attr:`grad_input` will match the precision of :attr:`grad_output`. + """ + check_contig_cuda((grad_output), ("grad_output")) + + savedInput, savedg = ctx.saved_tensors + savedNorms = ctx.norms + + # better safe than sorry + grad_output_contig = grad_output.contiguous() + + grad_input = grad_output_contig.new(grad_output.size()).contiguous() + grad_g = savedg.new(savedg.size()).contiguous() + + apex._C.weight_norm_bwd(grad_input, + grad_g, + grad_output_contig, + savedInput, + savedg, + savedNorms, + ctx.dim) + + return grad_input, grad_g, None diff --git a/apex/fp16_utils/loss_scaler.py b/apex/fp16_utils/loss_scaler.py new file mode 100644 index 000000000..4739c0fb4 --- /dev/null +++ b/apex/fp16_utils/loss_scaler.py @@ -0,0 +1,185 @@ +""" +Top of loss_scaler.py stub. Can't figure out a way to get the module file +highlighted in a pretty way, or link back to source. +""" +import torch + +# item() is a recent addition, so this helps with backward compatibility. +def to_python_float(t): + if hasattr(t, 'item'): + return t.item() + else: + return t[0] + +class LossScaler: + """ + Class that manages a static loss scale. This class is intended to interact with + :class:`FP16_Optimizer`, and should not be directly manipulated by the user. + + Use of LossScaler is enabled via the ``static_loss_scale`` argument to + :class:`FP16_Optimizer`'s constructor. + """ + + def __init__(self, scale=1): + self.cur_scale = scale + + # `params` is a list / generator of torch.Variable + def has_overflow(self, params): + return False + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + return False + + # `overflow` is boolean indicating whether we overflowed in gradient + def update_scale(self, overflow): + pass + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss): + scaled_loss = loss*self.loss_scale + scaled_loss.backward() + +class DynamicLossScaler: + """ + Class that manages dynamic loss scaling. It is recommended to use :class:`DynamicLossScaler` + indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of + :class:`FP16_Optimizer`. However, it's important to understand how :class:`DynamicLossScaler` + operates, because the default options can be changed using the + the ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor. + + Loss scaling is designed to combat the problem of underflowing gradients encountered at long + times when training FP16 networks. Dynamic loss scaling begins by attempting a very high loss + scale. Ironically, this may result in OVERflowing gradients. If overflowing gradients are + encountered, DynamicLossScaler informs :class:`FP16_Optimizer` that an overflow has occurred. + :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch, + and :class:`DynamicLossScaler` adjusts the loss scale to a lower value. + If a certain number of iterations occur without overflowing gradients detected, + :class:`DynamicLossScaler` increases the loss scale once more. + In this way :class:`DynamicLossScaler` attempts to "ride the edge" of + always using the highest loss scale possible without incurring overflow. + + Args: + init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.` + scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``. + scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale. + """ + + def __init__(self, + init_scale=2**32, + scale_factor=2., + scale_window=1000): + self.cur_scale = init_scale + self.cur_iter = 0 + self.last_overflow_iter = -1 + self.scale_factor = scale_factor + self.scale_window = scale_window + + # `params` is a list / generator of torch.Variable + def has_overflow(self, params): + for p in params: + if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data): + return True + + return False + + # `x` is a torch.Tensor + def _has_inf_or_nan(x): + try: + # Stopgap until upstream fixes sum() on HalfTensors + cpu_sum = float(x.float().sum()) + # cpu_sum = float(x.sum()) + # print(cpu_sum) + except RuntimeError as instance: + # We want to check if inst is actually an overflow exception. + # RuntimeError could come from a different error. + # If so, we still want the exception to propagate. + if "value cannot be converted" not in instance.args[0]: + raise + return True + else: + if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum: + return True + return False + + # `overflow` is boolean indicating whether we overflowed in gradient + def update_scale(self, overflow): + if overflow: + # self.cur_scale /= self.scale_factor + self.cur_scale = max(self.cur_scale/self.scale_factor, 1) + self.last_overflow_iter = self.cur_iter + else: + if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0: + self.cur_scale *= self.scale_factor + self.cur_iter += 1 + + @property + def loss_scale(self): + return self.cur_scale + + def scale_gradient(self, module, grad_in, grad_out): + return tuple(self.loss_scale * g for g in grad_in) + + def backward(self, loss): + scaled_loss = loss*self.loss_scale + scaled_loss.backward() + +############################################################## +# Example usage below here -- assuming it's in a separate file +############################################################## +""" +TO-DO separate out into an example. +if __name__ == "__main__": + import torch + from torch.autograd import Variable + from dynamic_loss_scaler import DynamicLossScaler + + # N is batch size; D_in is input dimension; + # H is hidden dimension; D_out is output dimension. + N, D_in, H, D_out = 64, 1000, 100, 10 + + # Create random Tensors to hold inputs and outputs, and wrap them in Variables. + x = Variable(torch.randn(N, D_in), requires_grad=False) + y = Variable(torch.randn(N, D_out), requires_grad=False) + + w1 = Variable(torch.randn(D_in, H), requires_grad=True) + w2 = Variable(torch.randn(H, D_out), requires_grad=True) + parameters = [w1, w2] + + learning_rate = 1e-6 + optimizer = torch.optim.SGD(parameters, lr=learning_rate) + loss_scaler = DynamicLossScaler() + + for t in range(500): + y_pred = x.mm(w1).clamp(min=0).mm(w2) + loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale + print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale)) + print('Iter {} scaled loss: {}'.format(t, loss.data[0])) + print('Iter {} unscaled loss: {}'.format(t, loss.data[0] / loss_scaler.loss_scale)) + + # Run backprop + optimizer.zero_grad() + loss.backward() + + # Check for overflow + has_overflow = DynamicLossScaler.has_overflow(parameters) + + # If no overflow, unscale grad and update as usual + if not has_overflow: + for param in parameters: + param.grad.data.mul_(1. / loss_scaler.loss_scale) + optimizer.step() + # Otherwise, don't do anything -- ie, skip iteration + else: + print('OVERFLOW!') + + # Update loss scale for next iteration + loss_scaler.update_scale(has_overflow) + +""" diff --git a/apex/parallel/__init__.py b/apex/parallel/__init__.py new file mode 100644 index 000000000..c59a7f7da --- /dev/null +++ b/apex/parallel/__init__.py @@ -0,0 +1 @@ +from .distributed import DistributedDataParallel diff --git a/apex/parallel/distributed.py b/apex/parallel/distributed.py new file mode 100644 index 000000000..9f5a9c02c --- /dev/null +++ b/apex/parallel/distributed.py @@ -0,0 +1,197 @@ +import torch +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +import torch.distributed as dist +from torch.nn.modules import Module +from torch.autograd import Variable + +def flat_dist_call(tensors, call, extra_args=None): + flat_dist_call.warn_on_half = True + buckets = {} + for tensor in tensors: + tp = tensor.type() + if tp not in buckets: + buckets[tp] = [] + buckets[tp].append(tensor) + + if flat_dist_call.warn_on_half: + if torch.cuda.HalfTensor in buckets: + print("WARNING: gloo dist backend for half parameters may be extremely slow." + + " It is recommended to use the NCCL backend in this case.") + flat_dist_call.warn_on_half = False + + for tp in buckets: + bucket = buckets[tp] + coalesced = _flatten_dense_tensors(bucket) + if extra_args is not None: + call(coalesced, *extra_args) + else: + call(coalesced) + coalesced /= dist.get_world_size() + for buf, synced in zip(bucket, _unflatten_dense_tensors(coalesced, bucket)): + buf.copy_(synced) + +class DistributedDataParallel(Module): + """ + :class:`DistributedDataParallel` is a simpler version of upstream :class:` + DistributedDataParallel` that is optimized for use with NCCL. Its usage is designed + to be used in conjunction with apex.parallel.multiproc.py. It assumes that your run + is using multiprocess with 1 GPU/process, that the model is on the correct device, + and that torch.set_device has been used to set the device. Parameters are broadcasted + to the other processes on initialization of DistributedDataParallel, and will be + allreduced in buckets durring the backward pass. + + See https://github.com/csarofeen/examples/tree/apex/distributed for detailed usage. + + Args: + module: Network definition to be run in multi-gpu/distributed mode. + message_size (Default = 10000000): Minimum number of elements in a communication bucket. + + + """ + + def __init__(self, module, message_size=10000000): + super(DistributedDataParallel, self).__init__() + self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False + + self.message_size = message_size + + #reference to last iterations parameters to see if anything has changed + self.param_refs = [] + + self.reduction_stream = torch.cuda.Stream() + + self.module = module + self.param_list = list(self.module.parameters()) + + if dist._backend == dist.dist_backend.NCCL: + for param in self.param_list: + assert param.is_cuda, "NCCL backend only supports model parameters to be on GPU." + + self.record = [] + self.create_hooks() + + flat_dist_call([param.data for param in self.module.parameters()], dist.broadcast, (0,) ) + + def create_hooks(self): + #all reduce gradient hook + def allreduce_params(): + if(self.needs_reduction): + self.needs_reduction = False + self.needs_refresh = False + else: + return + grads = [param.grad.data for param in self.module.parameters() if param.grad is not None] + flat_dist_call(grads, dist.all_reduce) + t_record = torch.cuda.IntTensor(self.record) + dist.broadcast(t_record, 0) + self.record = [int(entry) for entry in t_record] + + + def flush_buckets(): + if not self.needs_reduction: + return + self.needs_reduction = False + + ready = [] + for i in range(len(self.param_state)): + if self.param_state[i] == 1: + param = self.param_list[self.record[i]] + if param.grad is not None: + ready.append(param.grad.data) + + if(len(ready)>0): + orig_stream = torch.cuda.current_stream() + with torch.cuda.stream(self.reduction_stream): + self.reduction_stream.wait_stream(orig_stream) + flat_dist_call(ready, dist.all_reduce) + + torch.cuda.current_stream().wait_stream(self.reduction_stream) + + for param_i, param in enumerate(list(self.module.parameters())): + def wrapper(param_i): + + def allreduce_hook(*unused): + if self.needs_refresh: + self.record.append(param_i) + Variable._execution_engine.queue_callback(allreduce_params) + else: + Variable._execution_engine.queue_callback(flush_buckets) + self.param_state[self.record.index(param_i)] = 1 + self.comm_ready_buckets() + + + if param.requires_grad: + param.register_hook(allreduce_hook) + wrapper(param_i) + + + def comm_ready_buckets(self): + + ready = [] + counter = 0 + + while counter < len(self.param_state) and self.param_state[counter] == 2: + counter += 1 + + while counter < len(self.param_state) and self.param_state[counter] == 1: + ready.append(counter) + counter += 1 + + if not ready: + return + + grads = [] + for ind in ready: + param_ind = self.record[ind] + if self.param_list[param_ind].grad is not None: + grads.append(self.param_list[param_ind].grad.data) + + bucket = [] + bucket_inds = [] + while grads: + bucket.append(grads.pop(0)) + bucket_inds.append(ready.pop(0)) + + cumm_size = 0 + for ten in bucket: + cumm_size += ten.numel() + + if cumm_size < self.message_size: + continue + + evt = torch.cuda.Event() + evt.record(torch.cuda.current_stream()) + evt.wait(stream=self.reduction_stream) + + with torch.cuda.stream(self.reduction_stream): + flat_dist_call(bucket, dist.all_reduce) + + for ind in bucket_inds: + self.param_state[ind] = 2 + + def forward(self, *inputs, **kwargs): + """ + Forward function for DDP. + Args: + inputs: inputs that match the module's passed in for initialization. + kwargs: kwargs that match the module's passed in for initialization. + + """ + + param_list = [param for param in list(self.module.parameters()) if param.requires_grad] + + + + self.needs_refresh = True if not self.param_refs else any( + [param1 is not param2 for param1, param2 in zip(param_list, self.param_refs)] + ) + + if self.needs_refresh: + self.record = [] + + + self.param_state = [0 for i in range(len(param_list))] + self.param_refs = param_list + self.needs_reduction = True + + return self.module(*inputs, **kwargs) diff --git a/apex/parallel/multiproc.py b/apex/parallel/multiproc.py new file mode 100644 index 000000000..808b31d94 --- /dev/null +++ b/apex/parallel/multiproc.py @@ -0,0 +1,35 @@ +import torch +import sys +import subprocess + +def docstring_hack(): + """ + Multiproc file which will launcch a set of processes locally for multi-gpu + usage: python -m apex.parallel.multiproc main.py ... + """ + pass + +argslist = list(sys.argv)[1:] +world_size = torch.cuda.device_count() + +if '--world-size' in argslist: + argslist[argslist.index('--world-size')+1] = str(world_size) +else: + argslist.append('--world-size') + argslist.append(str(world_size)) + +workers = [] + +for i in range(world_size): + if '--rank' in argslist: + argslist[argslist.index('--rank')+1] = str(i) + else: + argslist.append('--rank') + argslist.append(str(i)) + stdout = None if i == 0 else open("GPU_"+str(i)+".log", "w") + print(argslist) + p = subprocess.Popen([str(sys.executable)]+argslist, stdout=stdout) + workers.append(p) + +for p in workers: + p.wait() diff --git a/apex/reparameterization/__init__.py b/apex/reparameterization/__init__.py new file mode 100644 index 000000000..90f977883 --- /dev/null +++ b/apex/reparameterization/__init__.py @@ -0,0 +1,127 @@ +from .weight_norm import WeightNorm +from .reparameterization import Reparameterization + +def apply_weight_norm(module, name='', dim=0, hook_child=True): + """ + Applies weight normalization to a parameter in the given module. + If no parameter is provided, applies weight normalization to all + parameters in model (except 1-d vectors and scalars). + + .. math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} + + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by `name` (e.g. "weight") with two parameters: one specifying the magnitude + (e.g. "weight_g") and one specifying the direction (e.g. "weight_v"). + Weight normalization is implemented via a hook that recomputes the weight + tensor from the magnitude and direction before every :meth:`~Module.forward` + call. + + By default, with `dim=0`, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + `dim=None`. + + See https://arxiv.org/abs/1602.07868 + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute the norm + hook_child (boolean, optional): adds reparameterization hook to direct parent of the + parameters. If False, it's added to `module` instead. Default: True + + Returns: + The original module with the weight norm hook + + Example:: + + >>> m = apply_weight_norm(nn.Linear(20, 40), name='weight') + Linear (20 -> 40) + >>> m.weight_g.size() + torch.Size([40, 1]) + >>> m.weight_v.size() + torch.Size([40, 20]) + + """ + return apply_reparameterization(module, reparameterization=WeightNorm, hook_child=hook_child, + name=name, dim=dim) + +def remove_weight_norm(module, name='', remove_all=False): + """ + Removes the weight normalization reparameterization of a parameter from a module. + If no parameter is supplied then all weight norm parameterizations are removed. + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + Example: + >>> m = apply_weight_norm(nn.Linear(20, 40)) + >>> remove_weight_norm(m) + """ + return remove_reparameterization(module, reparameterization=WeightNorm, + name=name, remove_all=remove_all) + +def apply_reparameterization(module, reparameterization=None, name='', dim=0, hook_child=True): + """ + Applies a given weight reparameterization (such as weight normalization) to + a parameter in the given module. If no parameter is given, applies the reparameterization + to all parameters in model (except 1-d vectors and scalars). + + Args: + module (nn.Module): containing module + reparameterization (Reparameterization): reparamaterization class to apply + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to perform reparameterization op + hook_child (boolean, optional): adds reparameterization hook to direct parent of the + parameters. If False, it's added to `module` instead. Default: True + + Returns: + The original module with the reparameterization hook + + Example:: + + >>> m = apply_reparameterization(nn.Linear(20, 40), WeightNorm) + Linear (20 -> 40) + + """ + assert reparameterization is not None + if name != '': + Reparameterization.apply(module, name, dim, reparameterization, hook_child) + else: + names = list(module.state_dict().keys()) + for name in names: + apply_reparameterization(module, reparameterization, name, dim, hook_child) + return module + +def remove_reparameterization(module, reparameterization=Reparameterization, + name='', remove_all=False): + """ + Removes the given reparameterization of a parameter from a module. + If no parameter is supplied then all reparameterizations are removed. + Args: + module (nn.Module): containing module + reparameterization (Reparameterization): reparamaterization class to apply + name (str, optional): name of weight parameter + remove_all (bool, optional): if True, remove all reparamaterizations of given type. Default: False + Example: + >>> m = apply_reparameterization(nn.Linear(20, 40),WeightNorm) + >>> remove_reparameterization(m) + """ + if name != '' or remove_all: + to_remove = [] + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, reparameterization) and (hook.name == name or remove_all): + hook.remove(module) + to_remove.append(k) + if len(to_remove) > 0: + for k in to_remove: + del module._forward_pre_hooks[k] + return module + if not remove_all: + raise ValueError("reparameterization of '{}' not found in {}" + .format(name, module)) + else: + modules = [module]+[x for x in module.modules()] + for m in modules: + remove_reparameterization(m, reparameterization=reparameterization, remove_all=True) + return module diff --git a/apex/reparameterization/reparameterization.py b/apex/reparameterization/reparameterization.py new file mode 100644 index 000000000..5f70f857a --- /dev/null +++ b/apex/reparameterization/reparameterization.py @@ -0,0 +1,151 @@ +import torch +from torch.nn.parameter import Parameter +import sys +class Reparameterization(object): + """ + Class interface for performing weight reparameterizations + Arguments: + name (str): name of weight parameter + dim (int): dimension over which to compute the norm + module (nn.Module): parent module to which param `name` is registered to + retain_forward (bool, optional): if False deletes weight on call to + module.backward. Used to avoid memory leaks with DataParallel Default: True + Attributes: + reparameterization_names (list, str): contains names of all parameters + needed to compute reparameterization. + backward_hook_key (int): torch.utils.hooks.RemovableHandle.id for hook used in module backward pass. + """ + + def __init__(self, name, dim, module, retain_forward=True): + self.name = name + self.dim = dim + self.evaluated = False + self.retain_forward = retain_forward + self.reparameterization_names = [] + self.backward_hook_key = None + self.module = module + + def compute_weight(self, module=None, name=None): + """ + Computes reparameterized weight value to assign value to module attribute + with name `name`. + See WeightNorm class for example. + Arguments: + module (nn.Module): module with weight we'd like to reparameterize + Returns: + w (Tensor): Tensor object containing value of reparameterized weight + """ + raise NotImplementedError + + def reparameterize(self, name, weight, dim): + """ + Creates Parameters to be used for reparameterization and creates names that + for attributes for the module these Parameters will correspond to. + The parameters will be registered according to the names provided. + See WeightNorm class for example. + Arguments: + module (nn.Module): module with weight we'd like to reparameterize + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute parameterization + Returns: + names (list, str): names of Parameters to be used for reparameterization + params (list, Parameter): Parameters to be used for reparameterization + """ + raise NotImplementedError + + @staticmethod + def apply(module, name, dim, reparameterization=None, hook_child=True): + """ + Applies reparametrization to module's `name` parameter and modifies instance attributes as appropriate. + `hook_child` adds reparameterization hook to direct parent of the parameters. If False, it's added to `module` instead. + """ + if reparameterization is None: + reparameterization = Reparameterization + module2use, name2use = Reparameterization.get_module_and_name(module, name) + # does not work on sparse + if name2use is None or isinstance(module2use, (torch.nn.Embedding, torch.nn.EmbeddingBag)): + return + + if hook_child: + fn = reparameterization(name2use, dim, module2use) + else: + fn = reparameterization(name, dim, module) + + weight = getattr(module2use, name2use) + if weight.dim() <= 1: + return + + # remove weight from parameter list + del module2use._parameters[name2use] + + # add parameters of reparameterization of parameter to module + names, params = fn.reparameterize(name2use, weight, dim) + for n, p in zip(names, params): + module2use.register_parameter(n, p) + + # add parameters to reparameterization so they can be removed later + fn.reparameterization_names = names + + setattr(module2use, name2use, None) + + hook_module = module2use + if not hook_child: + hook_module = module + # recompute weight before every forward() + hook_module.register_forward_pre_hook(fn) + + # remove weight during backward + handle = hook_module.register_backward_hook(fn.backward_hook) + # get hook key so we can delete it later + fn.backward_hook_key = handle.id + + return fn + + @staticmethod + def get_module_and_name(module, name): + """ + recursively fetches (possible) child module and name of weight to be reparameterized + """ + name2use = None + module2use = None + names = name.split('.') + if len(names) == 1 and names[0] != '': + name2use = names[0] + module2use = module + elif len(names) > 1: + module2use = module + name2use = names[0] + for i in range(len(names)-1): + module2use = getattr(module2use, name2use) + name2use = names[i+1] + return module2use, name2use + + def get_params(self, module): + """gets params of reparameterization based on known attribute names""" + return [getattr(module, n) for n in self.reparameterization_names] + + def remove(self, module): + """removes reparameterization and backward hook (does not remove forward hook)""" + module2use, name2use = Reparameterization.get_module_and_name(module, self.name) + for p in self.get_params(module2use): + p.requires_grad = False + weight = self.compute_weight(module2use, name2use) + delattr(module2use, name2use) + for n in self.reparameterization_names: + del module2use._parameters[n] + module2use.register_parameter(name2use, Parameter(weight.data)) + del module._backward_hooks[self.backward_hook_key] + + def __call__(self, module, inputs): + """callable hook for forward pass""" + module2use, name2use = Reparameterization.get_module_and_name(module, self.name) + _w = getattr(module2use, name2use) + if not self.evaluated or _w is None: + setattr(module2use, name2use, self.compute_weight(module2use, name2use)) + self.evaluated = True + + def backward_hook(self, module, grad_input, grad_output): + """callable hook for backward pass""" + module2use, name2use = Reparameterization.get_module_and_name(module, self.name) + wn = getattr(module2use, name2use) + self.evaluated = False diff --git a/apex/reparameterization/weight_norm.py b/apex/reparameterization/weight_norm.py new file mode 100644 index 000000000..eb35b0d8a --- /dev/null +++ b/apex/reparameterization/weight_norm.py @@ -0,0 +1,78 @@ +import torch +from torch.nn.parameter import Parameter +from ..fp16_utils import Fused_Weight_Norm +import time + +from .reparameterization import Reparameterization + +def _norm(p, dim): + """Computes the norm over all dimensions except dim""" + if dim is None: + return p.norm() + elif dim == 0: + output_size = (p.size(0),) + (1,) * (p.dim() - 1) + return p.contiguous().view(p.size(0), -1).norm(dim=1).view(*output_size) + elif dim == p.dim() - 1: + output_size = (1,) * (p.dim() - 1) + (p.size(-1),) + return p.contiguous().view(-1, p.size(-1)).norm(dim=0).view(*output_size) + return _norm(p.transpose(0, dim), 0).transpose(0, dim) + +HALF_TYPES = (torch.cuda.HalfTensor, torch.HalfTensor) + +class WeightNorm(Reparameterization): + """ + Weight normalization is a reparameterization that decouples the magnitude + of a weight tensor from its direction. This replaces the parameter specified + by `name` (e.g. "weight") with two parameters: one specifying the magnitude + (e.g. "weight_g") and one specifying the direction (e.g. "weight_v"). + Weight normalization is implemented via a hook that recomputes the weight + tensor from the magnitude and direction before every :meth:`~Module.forward` + call. + + .. math:: + \mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|} + + By default, with `dim=0`, the norm is computed independently per output + channel/plane. To compute a norm over the entire weight tensor, use + `dim=None`. + """ + def compute_weight(self, module=None, name=None): + """ + Computes weight normalized weight value to assign value to module attribute + with name `name`. + Arguments: + module (nn.Module): module with weight we'd like to reparameterize + Returns: + w (Tensor): Tensor object containing value of reparameterized weight + """ + if module is None: + module = self.module + if name is None: + name = self.name + module, name = Reparameterization.get_module_and_name(module, name) + g = getattr(module, name + '_g') + v = getattr(module, name + '_v') + + fused_weight_norm = Fused_Weight_Norm.apply + v = v.contiguous() + w = fused_weight_norm(v, g, self.dim) + + return w + + def reparameterize(self, name, weight, dim): + """ + Creates Parameters v and gto be used for weight normalization + and creates names that for attributes for the module these Parameters + will correspond to. The parameters will be registered according to the names + provided. + Arguments: + module (nn.Module): module with weight we'd like to reparameterize + name (str, optional): name of weight parameter + dim (int, optional): dimension over which to compute parameterization + Returns: + names (list, str): names of Parameters to be used for reparameterization + params (list, Parameter): Parameters to be used for reparameterization + """ + names = [name + '_g', name + '_v'] + params = [Parameter(_norm(weight, dim).data), Parameter(weight.data)] + return names, params diff --git a/csrc/Module.cpp b/csrc/Module.cpp new file mode 100644 index 000000000..82b9f024d --- /dev/null +++ b/csrc/Module.cpp @@ -0,0 +1,276 @@ +#define PY_SSIZE_T_CLEAN +#define ARG_OFFSET 5 + +#include + +#include +#include + +#include +#include +#include +#include +#include + +// #define USE_NVTX +#ifdef USE_NVTX +#include "nvToolsExt.h" +#endif + +//Meta-data format we will use +#include + +//Cuda kernels +#include + +#define ERROR_MSG cout << "Error at " << __FILE__ << ":" << __LINE__ << "\n"; + +using namespace std; + +TensorInfo PyOb_2_tinfo(PyObject* tensor, float_types data_type) +{ + PyObject* PyStrides = PyObject_CallMethod(tensor, "stride", NULL); + if(PyStrides == NULL) + { + ERROR_MSG; + cout << "PyStrides = NULL" << endl; + } + + PyObject* PySizes = PyObject_CallMethod(tensor, "size", NULL); + if(PySizes == NULL) + { + ERROR_MSG; + cout << "PySizes = NULL" << endl; + } + + PyObject* PyDataPtr = PyObject_CallMethod(tensor, "data_ptr", NULL); + if(PyDataPtr == NULL) + { + ERROR_MSG; + cout << "PyDataPtr = NULL" << endl; + } + + void* data_ptr = (void*) PyLong_AsLong(PyDataPtr); + Py_ssize_t ndims = PyList_GET_SIZE(PySizes); + + //TODO put proper checking on ndims < MAX_CUTORCH_DIMS + idxType strides[MAX_CUTORCH_DIMS], sizes[MAX_CUTORCH_DIMS]; + + for(int i = 0; i < ndims; i++) + { + strides[i] = PyLong_AsLong(PyTuple_GetItem(PyStrides, i)); + sizes[i] = PyLong_AsLong(PyTuple_GetItem(PySizes, i)); + } + + // Reference counts still behave strangely, but at least these appear to cap + // the process' memory usage. + Py_DECREF(PyStrides); + Py_DECREF(PySizes); + Py_DECREF(PyDataPtr); + + return TensorInfo(data_ptr, ndims, sizes, strides, data_type); +} + +vector > get_TInfos(PyObject* args) +{ + vector > info_vec; +#ifdef DEBUG_ANY + cout << "Processing " << PyTuple_GET_SIZE(args) << " arguments" << endl; +#endif + +#ifdef CHECK_MEMLEAK + for(int iter = 0; iter < 1e7; iter++ ) +#endif + for(Py_ssize_t i = 0; iob_type->tp_name); + + PyObject* pyObjTypeCall = PyObject_CallMethod(pyTensor, "type", NULL); + if(pyObjTypeCall == NULL) + { + ERROR_MSG; + cout << "For args item " << i << ", pyObjTypeCall = NULL" << endl; + } + + // This gives a segfault: + // cout << "pyObjTypeCall direct conversion attempt = " << + // PyBytes_AsString(pyObjTypeCall) << endl; + + PyObject* pyObjASCII = PyUnicode_AsASCIIString(pyObjTypeCall); + if(pyObjASCII == NULL) + { + ERROR_MSG; + cout << "For args item " << i << ", pyObjASCII = NULL " << endl; + } + + // cout << "Py_REFCNT(pyObjTypeCall) = " << Py_REFCNT(pyObjTypeCall) << endl; + Py_DECREF(pyObjTypeCall); + + string objTypeCall(PyBytes_AsString(pyObjASCII)); + + // cout << "Py_REFCNT(pyObjASCII) = " << Py_REFCNT(pyObjASCII) << endl; + Py_DECREF(pyObjASCII); + +#ifdef DEBUG_ANY + cout << "arg " << i << endl; + cout << "objType = " << objType << endl; + cout << "objTypeCall = " << objTypeCall << endl; +#endif + + if(objTypeCall == "torch.cuda.FloatTensor") +#ifdef CHECK_MEMLEAK + if(iter == 0 ) +#endif + info_vec.push_back(PyOb_2_tinfo(pyTensor, FLOAT)); +#ifdef CHECK_MEMLEAK + else + info_vec[i] = PyOb_2_tinfo(pyTensor, FLOAT); +#endif + else if(objTypeCall == "torch.cuda.HalfTensor") + info_vec.push_back(PyOb_2_tinfo(pyTensor, HALF)); + // Could add double + else + { + ERROR_MSG; + cout << "For args item " << i << ", unsupported .type() found: " + << objTypeCall << "\n" + "Supported types:\n" + "torch.cuda.FloatTensor\n" + "torch.cuda.HalfTensor\n" + "torch.autograd.variable.Variable containing FloatTensor\n" + "torch.autograd.variable.Variable containing HalfTensor\n" + "torch.nn.parameter.Parameter containing FloatTensor\n" + "torch.nn.parameter.Parameter containing HalfTensor\n" + << endl; + } + } + + // PyErr_SetString(PyExc_RuntimeError, "Exception set in "); + + return info_vec; +} + +int getLastArg_AsInt(PyObject* args) +{ + // None of these should return new references so I don't think this leaks memory. + int dims = PyLong_AsLong(PyTuple_GetItem(args, PyTuple_GET_SIZE(args) - 1)); + return dims; +} + +// Stepping stone, can evolve to be more general (argument forwarding?) +template +void dispatch +( + float_types rtti, + vector>& tensors, + int dim +) +{ + switch(rtti) + { + case FLOAT: + wrapper::template call(tensors, dim); + break; + case HALF: + wrapper::template call(tensors, dim); + break; + default: + std::cout << "Unsupported rtti in Module.cpp:dispatch()" << std::endl; + PyErr_SetString(PyExc_RuntimeError, "Unsupported data type in Module.cpp:dispatch, " + "supported data types are half and float"); + exit(-1); + } +} + +//Will extract all tensors in order. Assumes flat structure, tensors can not be wrapped in lists +//tuples or any other iterator structure. +static PyObject* weight_norm_fwd(PyObject* self, PyObject* args) +{ +#ifdef USE_NVTX +nvtxRangePushA("weight_norm_fwd C backend"); +#endif + + vector > tensors = get_TInfos(args); + int dim = getLastArg_AsInt(args); + + if(dim != 0 && dim != tensors[2].dims - 1) + PyErr_SetString(PyExc_RuntimeError, "weight_norm_fwd currently only " + "supports first or last dimension."); + else + { +#ifdef DEBUG_ANY + cout << "tensors.size() = " << tensors.size() << ", dim = " << dim << endl; +#endif + + dispatch(tensors[0].type, tensors, dim); + +#ifdef USE_NVTX + nvtxRangePop(); +#endif + } + + Py_RETURN_NONE; +} + +static PyObject* weight_norm_bwd(PyObject* self, PyObject* args) +{ +#ifdef USE_NVTX + nvtxRangePushA("weight_norm_bwd C backend"); +#endif + + vector >tensors = get_TInfos(args); + int dim = getLastArg_AsInt(args); + + if(dim != 0 && dim != tensors[3].dims - 1) + PyErr_SetString(PyExc_RuntimeError, "weight_norm_bwd currently only " + "supports first or last dimension."); + else + { +#ifdef DEBUG_ANY + cout << "tensors.size() = " << tensors.size() << ", dim = " << dim << endl; +#endif + + dispatch(tensors[0].type, tensors, dim); + +#ifdef USE_NVTX + nvtxRangePop(); +#endif + } + + Py_RETURN_NONE; +} + +//*******************PYTHON BOILER PLATE******************* +static PyMethodDef apex_methods[] = { + {"weight_norm_fwd", (PyCFunction) weight_norm_fwd, METH_VARARGS, "Slowest-dim norm, forward pass."}, + {"weight_norm_bwd", (PyCFunction) weight_norm_bwd, METH_VARARGS, "Slowest-dim norm, backward pass."}, + {NULL, NULL, 0, NULL} +}; + +#if PY_MAJOR_VERSION >= 3 + +//Module Definitions +static struct PyModuleDef apex = { + PyModuleDef_HEAD_INIT, "apex._C", "Module to add CUDA extensions to Pytorch.", -1, apex_methods +}; +//Initialization Function +PyMODINIT_FUNC PyInit__C(void){ + + //Let's throw an error if we can't find pytorch. + PyImport_ImportModule("torch"); + Py_Initialize(); + return PyModule_Create(&apex); +} +#else +PyMODINIT_FUNC initMODULE(void){ + //Let's throw an error if we can't find pytorch. + PyImport_ImportModule("torch"); + (void) Py_InitModule3("apex._C", apex, "A PyTorch Extension."); +} + +#endif +//********************************************************* + diff --git a/csrc/kernel.cu b/csrc/kernel.cu new file mode 100644 index 000000000..c7976a140 --- /dev/null +++ b/csrc/kernel.cu @@ -0,0 +1,473 @@ +#include "../include/kernel.h" + +template struct TtoInt { static const int test = -1; }; +template<> struct TtoInt { static const int test = 0; }; +template<> struct TtoInt { static const int test = 0; }; +template<> struct TtoInt { static const int test = 0; }; + +#if __CUDACC_VER_MAJOR__ >= 9 +#define __SHFL_DOWN(var, delta) __shfl_down_sync(0xffffffff, var, delta) +#else +#define __SHFL_DOWN(var, delta) __shfl_down(var, delta) +#endif + +#if __CUDACC_VER_MAJOR__ >= 9 +#define __SYNCWARP __syncwarp() +#else +#define __SYNCWARP +#endif + +// Block size for weight_norm_*_first_dim_kernel. +// Currently, kernels are non-persistent. +// Dialing up the block size to, say 1024, can improve performance by +// increase the amount of cache available per block, which can improve cache hit rate. +// However, this is less efficient for short rows. 256 is pretty versatile. +// Implement some heuristics later? +#define BLOCK 256 + +// Block size for weight_norm_*_last_dim_kernel. +// This is tricker than the first_dim case because we must make blocks +// at least 16 fast elements wide to ensure fully-coalesced half-precision accesses. +// Since output-element parallelism is along the fast dimension, this reduces the number of +// blocks we can launch by 16X. +#define TILE_W 16 +// Somewhat versatile strategy: max out intra-block parallelism by extending +// blocks across the slow dimension up to the hardware-max block size of 1024. +#define TILE_H 64 + +using namespace std; + +// lanes is intended to be <= 32. +template +__device__ __forceinline__ void reduce_block_into_lanes(T *x, T val, int lanes) +{ + int tid = threadIdx.x + threadIdx.y*blockDim.x; + int blockSize = blockDim.x*blockDim.y; + + if(blockSize >= 64) + { + x[tid] = val; + __syncthreads(); + } + + #pragma unroll + for(int i = (blockSize >> 1); i >= 64; i >>= 1) + { + if(tid < i) + x[tid] += x[tid+i]; // JoinOp + __syncthreads(); + } + + if(tid < 32) + { + T final; + if(blockSize >= 64) + final = x[tid] + x[tid+32]; // JoinOp + else + final = val; + // __SYNCWARP(); + + #pragma unroll + for(int i = 16; i >= lanes; i >>= 1) + final += __SHFL_DOWN(final, i); + + if(tid < lanes) + x[tid] = final; // EpilogueOp + } + + // Make sure the smem result is visible to all warps. + __syncthreads(); +} + +template +__global__ void weight_norm_fwd_first_dim_kernel +( + TensorInfo w, + TensorInfo norms, + TensorInfo v, + TensorInfo g, + IndexType rowSize +) +{ + // We are norming each slowest-dim row of the tensor separately. + // For now, assign one block to each row. + IndexType tid = threadIdx.x; + IndexType row = blockIdx.x; + IndexType stride = blockDim.x; + + // Logical index offset for this flattened row + IndexType rowStart = row*rowSize; + + extern __shared__ float s[]; + + float thread_sum = 0.f; + for(IndexType i = tid; i < rowSize; i += stride ) + { + float val_f = ScalarConvert::to(DEVICE_LINEAR_GET(v, i + rowStart)); + thread_sum += val_f*val_f; // AccumOp, could do Kahan here + } + + reduce_block_into_lanes(s, thread_sum, 1); + float result = s[0]; + + result = sqrtf(result); + + if(tid == 0) + DEVICE_LINEAR_GET_F(norms, row) = result; + + // Broadcast load, could use shared memory instead. + float g_this_row = ScalarConvert::to(DEVICE_LINEAR_GET(g, row)); + + float rnorm = 1.f/result; // for consistency with backward kernel + + // Write data to output + for(IndexType i = tid; i < rowSize; i += stride ) + { + float val_f = ScalarConvert::to(DEVICE_LINEAR_GET(v, i + rowStart)); + DEVICE_LINEAR_GET(w, i + rowStart) = ScalarConvert::to(g_this_row*val_f*rnorm); + } +} + +template +__global__ void weight_norm_fwd_last_dim_kernel +( + TensorInfo w, + TensorInfo norms, + TensorInfo v, + TensorInfo g, + IndexType fast_dim_size, + IndexType slower_dims_size +) +{ + IndexType fast_dim_location = threadIdx.x + blockIdx.x*blockDim.x; + + extern __shared__ float alloc[]; + float* s = &alloc[0]; + float* rnorms_this_block = &alloc[blockDim.x*blockDim.y]; + + float thread_sum = 0.f; + + IndexType slower_dims_location = threadIdx.y; + IndexType currentIdx = fast_dim_location + fast_dim_size*slower_dims_location; + if(fast_dim_location < fast_dim_size) + while(slower_dims_location < slower_dims_size) + { + float val_f = ScalarConvert::to(DEVICE_LINEAR_GET(v, currentIdx)); + thread_sum += val_f*val_f; // AccumOp, could do Kahan here + currentIdx += blockDim.y*fast_dim_size; + slower_dims_location += blockDim.y; + } + + reduce_block_into_lanes(s, thread_sum, blockDim.x); + + // Better to pass an EpilogueOp to reduce_block_into_lanes, can try later + if(threadIdx.y == 0) + { + float result = s[threadIdx.x]; + float norm_this_col = sqrtf(result); + DEVICE_LINEAR_GET_F(norms, fast_dim_location) = norm_this_col; + rnorms_this_block[threadIdx.x] = 1.f/norm_this_col; + // printf("blockIdx.x = %d, threadIdx.x = %d, norm_this_col = %f\n", + // blockIdx.x, threadIdx.x, norm_this_col); + } + + __syncthreads(); + + float g_this_col = ScalarConvert::to(DEVICE_LINEAR_GET(g, fast_dim_location)); + + float rnorm = rnorms_this_block[threadIdx.x]; + + slower_dims_location = threadIdx.y; + currentIdx = fast_dim_location + fast_dim_size*slower_dims_location; + if(fast_dim_location < fast_dim_size) + while(slower_dims_location < slower_dims_size) + { + float val_f = ScalarConvert::to(DEVICE_LINEAR_GET(v, currentIdx)); + DEVICE_LINEAR_GET(w, currentIdx) = ScalarConvert::to(g_this_col*val_f*rnorm); + currentIdx += blockDim.y*fast_dim_size; + slower_dims_location += blockDim.y; + } +} + +template +__global__ void weight_norm_bwd_first_dim_kernel +( + TensorInfo pLpv, + TensorInfo pLpg, + TensorInfo pLpw, + TensorInfo savedv, + TensorInfo savedg, + TensorInfo savedNorms, + IndexType rowSize +) +{ + // For now, assign one block to each row. + IndexType tid = threadIdx.x; + IndexType row = blockIdx.x; + IndexType stride = blockDim.x; + + // Logical index offset for this flattened row + IndexType rowStart = row*rowSize; + + extern __shared__ float s[]; + + float thread_sum = 0.f; + for(IndexType i = tid; i < rowSize; i += stride ) + { + float pLpwi = ScalarConvert::to(DEVICE_LINEAR_GET(pLpw, i + rowStart)); + float savedvi = ScalarConvert::to(DEVICE_LINEAR_GET(savedv, i + rowStart)); + thread_sum += pLpwi*savedvi; // AccumOp, could do Kahan here + } + + reduce_block_into_lanes(s, thread_sum, 1); + float result = s[0]; + + // Could choose to save reciprocal of norm instead I suppose, but norms is probably + // more handy to keep around. + // Broadcast load; could use shared memory instead. + float rnorm = 1.f/DEVICE_LINEAR_GET_F(savedNorms, row); + float rnorm3 = rnorm*rnorm*rnorm; + + // Write g gradients. + if(tid == 0) + DEVICE_LINEAR_GET(pLpg, row) = ScalarConvert::to(result*rnorm); + + // Broadcast load, could use shared memory instead. + float g_this_row = ScalarConvert::to(DEVICE_LINEAR_GET(savedg, row)); + + // Write v gradients. We are reusing values that were loaded earlier, so there + // is an optimization opportunity here (store values persistently). + for(IndexType j = tid; j < rowSize; j += stride ) + { + float pLpwj = ScalarConvert::to(DEVICE_LINEAR_GET(pLpw, j + rowStart)); + float savedvj = ScalarConvert::to(DEVICE_LINEAR_GET(savedv, j + rowStart)); + float pLpvj = g_this_row*(rnorm*pLpwj - rnorm3*savedvj*result); + DEVICE_LINEAR_GET(pLpv, j + rowStart) = ScalarConvert::to(pLpvj); + } +} + +template +__global__ void weight_norm_bwd_last_dim_kernel +( + TensorInfo pLpv, + TensorInfo pLpg, + TensorInfo pLpw, + TensorInfo savedv, + TensorInfo savedg, + TensorInfo savedNorms, + IndexType fast_dim_size, + IndexType slower_dims_size +) +{ + IndexType fast_dim_location = threadIdx.x + blockIdx.x*blockDim.x; + + extern __shared__ float s[]; + + float thread_sum = 0.f; + + IndexType slower_dims_location = threadIdx.y; + IndexType currentIdx = fast_dim_location + fast_dim_size*slower_dims_location; + if(fast_dim_location < fast_dim_size) + while(slower_dims_location < slower_dims_size) + { + float pLpwi = ScalarConvert::to(DEVICE_LINEAR_GET(pLpw, currentIdx)); + float savedvi = ScalarConvert::to(DEVICE_LINEAR_GET(savedv, currentIdx)); + thread_sum += pLpwi*savedvi; // AccumOp, could do Kahan here + currentIdx += blockDim.y*fast_dim_size; + slower_dims_location += blockDim.y; + } + + reduce_block_into_lanes(s, thread_sum, blockDim.x); + float result = s[threadIdx.x]; + + // Broadcast load; could use shared memory instead. + float rnorm = 1.f/DEVICE_LINEAR_GET_F(savedNorms, fast_dim_location); + float rnorm3 = rnorm*rnorm*rnorm; + + // Write g gradients. + if(threadIdx.y == 0) + DEVICE_LINEAR_GET(pLpg, fast_dim_location) = ScalarConvert::to(result*rnorm); + + // Entire block pulls these values, could use shared memory instead. + float g_this_col = ScalarConvert::to(DEVICE_LINEAR_GET(savedg, fast_dim_location)); + + // Write v gradients. + slower_dims_location = threadIdx.y; + currentIdx = fast_dim_location + fast_dim_size*slower_dims_location; + if(fast_dim_location < fast_dim_size) + while(slower_dims_location < slower_dims_size) + { + float pLpwj = ScalarConvert::to(DEVICE_LINEAR_GET(pLpw, currentIdx)); + float savedvj = ScalarConvert::to(DEVICE_LINEAR_GET(savedv, currentIdx)); + float pLpvj = g_this_col*(rnorm*pLpwj - rnorm3*savedvj*result); + DEVICE_LINEAR_GET(pLpv, currentIdx) = ScalarConvert::to(pLpvj); + currentIdx += blockDim.y*fast_dim_size; + slower_dims_location += blockDim.y; + } +} + +template +void send_to_fwd_wrapper::call +( + vector>& tensors, + int dim +) +{ +#ifdef DEBUG_ANY + cout << "hello from send_to_fwd with v.type = " << v.type << endl; +#endif + + auto w (*((TensorInfo*)&tensors[0])); + auto norms(*((TensorInfo*)&tensors[1])); + auto v (*((TensorInfo*)&tensors[2])); + auto g (*((TensorInfo*)&tensors[3])); + + if(dim == 0) + { + // Find logical size of each flattened slowest-dim row + IndexType rowSize = 1; + for(IndexType i = v.dims - 1; i > 0; i--) + rowSize *= v.sizes[i]; + + weight_norm_fwd_first_dim_kernel<<>> + ( + w, + norms, + v, + g, + rowSize + ); + } + else if(dim == v.dims - 1) + { + // Precompute slower_dims_size and fast_dim_size because they involve dynamically indexing an array. + IndexType slower_dims_size = 1; + for(IndexType i = 0; i < v.dims - 1; i++) + slower_dims_size *= v.sizes[i]; + + int fast_dim_size = v.sizes[v.dims-1]; + + weight_norm_fwd_last_dim_kernel<<<(fast_dim_size+TILE_W-1)/TILE_W, + dim3(TILE_W,TILE_H), + (TILE_W*TILE_H + TILE_W)*sizeof(float)>>> + ( + w, + norms, + v, + g, + fast_dim_size, + slower_dims_size + ); + } + // else + // { + // intermediate dim kernel. Error checking on the dim was already done in + // Module.cpp:weight_norm_fwd. Could put that logic here instead, if we include + // in both files. + // } + +#ifdef DEBUG_PROFILE + cudaDeviceSynchronize(); +#endif +} + +// template +template +void send_to_bwd_wrapper::call +( + vector>& tensors, + int dim +) +{ +#ifdef DEBUG_ANY + cout << "Hello from send_to_bwd with pLpw.type = " << pLpw.type << endl; +#endif + + // this feels sinful + auto pLpv (*((TensorInfo*)&tensors[0])); + auto pLpg (*((TensorInfo*)&tensors[1])); + auto pLpw (*((TensorInfo*)&tensors[2])); + auto savedv (*((TensorInfo*)&tensors[3])); + auto savedg (*((TensorInfo*)&tensors[4])); + auto savedNorms(*((TensorInfo*)&tensors[5])); + + if(dim == 0) + { + // Find logical size of each flattened slowest-dim row + IndexType rowSize = 1; + for(IndexType i = savedv.dims - 1; i > 0; i--) + rowSize *= savedv.sizes[i]; + + weight_norm_bwd_first_dim_kernel<<>> + ( + pLpv, + pLpg, + pLpw, + savedv, + savedg, + savedNorms, + rowSize + ); + } + else if(dim == savedv.dims - 1) + { + // Precompute slower_dims_size and fast_dim_size because they involve dynamically indexing an array. + IndexType slower_dims_size = 1; + for(IndexType i = 0; i < savedv.dims - 1; i++) + slower_dims_size *= savedv.sizes[i]; + + int fast_dim_size = savedv.sizes[savedv.dims-1]; + + weight_norm_bwd_last_dim_kernel<<<(fast_dim_size+TILE_W-1)/TILE_W, + dim3(TILE_W,TILE_H), + (TILE_W*TILE_H + TILE_W)*sizeof(float)>>> + ( + pLpv, + pLpg, + pLpw, + savedv, + savedg, + savedNorms, + fast_dim_size, + slower_dims_size + ); + } + // else + // { + // intermediate dim kernel. Error checking on the dim was already done in + // Module.cpp:weight_norm_bwd. Could put that logic here instead, if we include + // in both files. + // } + +#ifdef DEBUG_PROFILE + cudaDeviceSynchronize(); +#endif +} + +#define INSTANTIATE_SEND_TO_FWD(DATATYPE, ACCUMTYPE, IDXTYPE) \ +template void send_to_fwd_wrapper::call \ +( \ + vector>&, \ + int \ +); +INSTANTIATE_SEND_TO_FWD(float, float, idxType) +INSTANTIATE_SEND_TO_FWD(half, float, idxType) +#undef INSTANTIATE_SEND_TO_FWD + +#define INSTANTIATE_SEND_TO_BWD(DATATYPE, ACCUMTYPE, IDXTYPE) \ +template void send_to_bwd_wrapper::call \ +( \ + vector>&, \ + int \ +); +INSTANTIATE_SEND_TO_BWD(float, float, idxType) +INSTANTIATE_SEND_TO_BWD(half, float, idxType) +#undef INSTANTIATE_SEND_TO_BWD + +#undef BLOCK +#undef TILE_W +#undef TILE_H diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 000000000..36484d21c --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,39 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SPHINXPROJ = PyTorch +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +docset: html + doc2dash --name $(SPHINXPROJ) --icon $(SOURCEDIR)/_static/img/nv-pytorch2.png --enable-js --online-redirect-url http://pytorch.org/docs/ --force $(BUILDDIR)/html/ + + # Manually fix because Zeal doesn't deal well with `icon.png`-only at 2x resolution. + cp $(SPHINXPROJ).docset/icon.png $(SPHINXPROJ).docset/icon@2x.png + convert $(SPHINXPROJ).docset/icon@2x.png -resize 16x16 $(SPHINXPROJ).docset/icon.png + +gh-pages: + git checkout gh-pages + rm -rf build + rm -rf source + git checkout master -- . + make html + rm -rf ../_modules ../_sources ../_static + mv -fv build/html/* ../ + rm -rf build + git add -A + git commit -m "Generated gh-pages for `git log master -1 --pretty=short --abbrev-commit`" && git push origin gh-pages ; git checkout master + +.PHONY: help Makefile docset + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/source/RNN.rst b/docs/source/RNN.rst new file mode 100644 index 000000000..f2d3f346a --- /dev/null +++ b/docs/source/RNN.rst @@ -0,0 +1,26 @@ +.. role:: hidden + :class: hidden-section + +apex.RNN +=================================== + +This sumbodule is an in development API aimed to supply parity to torch.nn.RNN, +but be easier to extend. This module is not ready for use and still lacks important +features and validation. + +.. automodule:: apex.RNN +.. currentmodule:: apex.RNN + +.. RNN + ---------- + +.. autofunction:: LSTM + +.. autofunction:: mLSTM + +.. autofunction:: GRU + +.. autofunction:: ReLU + +.. autofunction:: Tanh + diff --git a/docs/source/_static/css/pytorch_theme.css b/docs/source/_static/css/pytorch_theme.css new file mode 100644 index 000000000..45e984c90 --- /dev/null +++ b/docs/source/_static/css/pytorch_theme.css @@ -0,0 +1,118 @@ +body { + font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; +} + +/* Default header fonts are ugly */ +h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { + font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; +} + +/* Use white for docs background */ +.wy-side-nav-search { + background-color: #fff; +} + +.wy-nav-content-wrap, .wy-menu li.current > a { + background-color: #fff; +} + +@media screen and (min-width: 1400px) { + .wy-nav-content-wrap { + background-color: rgba(0, 0, 0, 0.0470588); + } + + .wy-nav-content { + background-color: #fff; + } +} + +/* Fixes for mobile */ +.wy-nav-top { + background-color: #fff; + background-image: url('../img/apex.jpg'); + background-repeat: no-repeat; + background-position: center; + padding: 0; + margin: 0.4045em 0.809em; + color: #333; +} + +.wy-nav-top > a { + display: none; +} + +@media screen and (max-width: 768px) { + .wy-side-nav-search>a img.logo { + height: 60px; + } +} + +/* This is needed to ensure that logo above search scales properly */ +.wy-side-nav-search a { + display: block; +} + +/* This ensures that multiple constructors will remain in separate lines. */ +.rst-content dl:not(.docutils) dt { + display: table; +} + +/* Use our red for literals (it's very similar to the original color) */ +.rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal { + color: #F05732; +} + +.rst-content tt.xref, a .rst-content tt, .rst-content tt.xref, +.rst-content code.xref, a .rst-content tt, a .rst-content code { + color: #404040; +} + +/* Change link colors (except for the menu) */ + +a { + color: #F05732; +} + +a:hover { + color: #F05732; +} + + +a:visited { + color: #D44D2C; +} + +.wy-menu a { + color: #b3b3b3; +} + +.wy-menu a:hover { + color: #b3b3b3; +} + +/* Default footer text is quite big */ +footer { + font-size: 80%; +} + +footer .rst-footer-buttons { + font-size: 125%; /* revert footer settings - 1/80% = 125% */ +} + +footer p { + font-size: 100%; +} + +/* For hidden headers that appear in TOC tree */ +/* see http://stackoverflow.com/a/32363545/3343043 */ +.rst-content .hidden-section { + display: none; +} + +nav .hidden-section { + display: inherit; +} + +.wy-side-nav-search>div.version { + color: #000; +} diff --git a/docs/source/_static/img/nv-pytorch2.png b/docs/source/_static/img/nv-pytorch2.png new file mode 100644 index 000000000..981268c60 Binary files /dev/null and b/docs/source/_static/img/nv-pytorch2.png differ diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 000000000..5e24ccf06 --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,247 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# PyTorch documentation build configuration file, created by +# sphinx-quickstart on Fri Dec 23 13:31:47 2016. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys +sys.path.insert(0, os.path.abspath('.')) +# sys.path.insert(0, os.path.abspath('../../apex/parallel/')) +import apex +# import multiproc +import sphinx_rtd_theme + + +# -- General configuration ------------------------------------------------ + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.doctest', + 'sphinx.ext.intersphinx', + 'sphinx.ext.todo', + 'sphinx.ext.coverage', + 'sphinx.ext.mathjax', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', +] + +napoleon_use_ivar = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +# source_suffix = ['.rst', '.md'] +source_suffix = '.rst' + +# The master toctree document. +master_doc = 'index' + +# General information about the project. +project = 'APEx' +copyright = '2018' +author = 'Christian Sarofeen, Natalia Gimelshein, Michael Carilli, Raul Puri' + +# The version info for the project you're documenting, acts as replacement for +# |version| and |release|, also used in various other places throughout the +# built documents. +# +# The short X.Y version. +# TODO: change to [:2] at v1.0 +# version = 'master (' + torch.__version__ + ' )' +version = '0.0' +# The full version, including alpha/beta/rc tags. +# TODO: verify this works as expected +release = '0.0.0' + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This patterns also effect to html_static_path and html_extra_path +exclude_patterns = [] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = True + + +# -- Options for HTML output ---------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + 'collapse_navigation': False, + 'display_version': True, + 'logo_only': True, +} + +html_logo = '_static/img/nv-pytorch2.png' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# html_style_path = 'css/pytorch_theme.css' +html_context = { + 'css_files': [ + 'https://fonts.googleapis.com/css?family=Lato', + '_static/css/pytorch_theme.css' + ], +} + + +# -- Options for HTMLHelp output ------------------------------------------ + +# Output file base name for HTML help builder. +htmlhelp_basename = 'PyTorchdoc' + + +# -- Options for LaTeX output --------------------------------------------- + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'apex.tex', 'APEx Documentation', + 'Torch Contributors', 'manual'), +] + + +# -- Options for manual page output --------------------------------------- + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [ + (master_doc, 'APEx', 'APEx Documentation', + [author], 1) +] + + +# -- Options for Texinfo output ------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + (master_doc, 'APEx', 'APEx Documentation', + author, 'APEx', 'One line description of project.', + 'Miscellaneous'), +] + + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + 'python': ('https://docs.python.org/', None), + 'numpy': ('http://docs.scipy.org/doc/numpy/', None), +} + +# -- A patch that prevents Sphinx from cross-referencing ivar tags ------- +# See http://stackoverflow.com/a/41184353/3343043 + +from docutils import nodes +from sphinx.util.docfields import TypedField +from sphinx import addnodes + + +def patched_make_field(self, types, domain, items, **kw): + # `kw` catches `env=None` needed for newer sphinx while maintaining + # backwards compatibility when passed along further down! + + # type: (List, unicode, Tuple) -> nodes.field + def handle_item(fieldarg, content): + par = nodes.paragraph() + par += addnodes.literal_strong('', fieldarg) # Patch: this line added + # par.extend(self.make_xrefs(self.rolename, domain, fieldarg, + # addnodes.literal_strong)) + if fieldarg in types: + par += nodes.Text(' (') + # NOTE: using .pop() here to prevent a single type node to be + # inserted twice into the doctree, which leads to + # inconsistencies later when references are resolved + fieldtype = types.pop(fieldarg) + if len(fieldtype) == 1 and isinstance(fieldtype[0], nodes.Text): + typename = u''.join(n.astext() for n in fieldtype) + typename = typename.replace('int', 'python:int') + typename = typename.replace('long', 'python:long') + typename = typename.replace('float', 'python:float') + typename = typename.replace('type', 'python:type') + par.extend(self.make_xrefs(self.typerolename, domain, typename, + addnodes.literal_emphasis, **kw)) + else: + par += fieldtype + par += nodes.Text(')') + par += nodes.Text(' -- ') + par += content + return par + + fieldname = nodes.field_name('', self.label) + if len(items) == 1 and self.can_collapse: + fieldarg, content = items[0] + bodynode = handle_item(fieldarg, content) + else: + bodynode = self.list_type() + for fieldarg, content in items: + bodynode += nodes.list_item('', handle_item(fieldarg, content)) + fieldbody = nodes.field_body('', bodynode) + return nodes.field('', fieldname, fieldbody) + +TypedField.make_field = patched_make_field diff --git a/docs/source/fp16_utils.rst b/docs/source/fp16_utils.rst new file mode 100644 index 000000000..ad5ecb5df --- /dev/null +++ b/docs/source/fp16_utils.rst @@ -0,0 +1,50 @@ +.. role:: hidden + :class: hidden-section + +apex.fp16_utils +=================================== + +This submodule contains utilities designed to streamline the mixed precision training recipe +presented by NVIDIA `on Parallel Forall`_ and in GTC 2018 Sessions +`Training Neural Networks with Mixed Precision: Theory and Practice`_ and +`Training Neural Networks with Mixed Precision: Real Examples`_. +For Pytorch users, Real Examples in particular is recommended. + +.. _`on Parallel Forall`: + https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/ +.. _`Training Neural Networks with Mixed Precision: Theory and Practice`: + http://on-demand.gputechconf.com/gtc/2018/video/S8923/ +.. _`Training Neural Networks with Mixed Precision: Real Examples`: + http://on-demand.gputechconf.com/gtc/2018/video/S81012/ + +.. automodule:: apex.fp16_utils +.. currentmodule:: apex.fp16_utils + +.. FusedNorm + ---------- + +.. autofunction:: prep_param_lists + +.. autofunction:: master_params_to_model_params + +.. autofunction:: model_grads_to_master_grads + +.. autoclass:: FP16_Optimizer + :members: + +.. autoclass:: Fused_Weight_Norm + :members: + +.. .. automodule:: apex.fp16_utils.loss_scaler + +.. autoclass:: LossScaler + :members: + +.. autoclass:: DynamicLossScaler + :members: + +.. .. automodule:: apex.fp16_utils.fp16util + :members: + + + diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 000000000..f70a9ff41 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,38 @@ +.. PyTorch documentation master file, created by + sphinx-quickstart on Fri Dec 23 13:31:47 2016. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +:github_url: https://gitlab-master.nvidia.com/csarofeen/apex + +APEx (A PyTorch Extension) +=================================== + +This is a repo is designed to hold PyTorch modules and utilities that are under active development and experimental. This repo is not designed as a long term solution or a production solution. Things placed in here are intended to be eventually moved to upstream PyTorch. + +A major focus of this extension is the training of neural networks using 16-bit precision floating point math, which offers significant performance benefits on latest NVIDIA GPU architectures. The reduced dynamic range of half precision, however, is more vulnerable to numerical overflow/underflow. + +APEX is an NVIDIA-maintained repository of utilities, including some that are targeted to improve the accuracy and stability of half precision networks, while maintaining high performance. The utilities are designed to be minimally invasive and easy to use. + +Installation requires CUDA9, PyTorch 0.3 or later, and Python 3. Installation can be done by running +:: + git clone https://www.github.com/nvidia/apex + cd apex + python setup.py install + + + +.. toctree:: + :maxdepth: 1 + :caption: apex + + parallel + reparameterization + RNN + fp16_utils + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` diff --git a/docs/source/parallel.rst b/docs/source/parallel.rst new file mode 100644 index 000000000..782b6c357 --- /dev/null +++ b/docs/source/parallel.rst @@ -0,0 +1,16 @@ +.. role:: hidden + :class: hidden-section + +apex.parallel +=================================== + +.. automodule:: apex.parallel +.. currentmodule:: apex.parallel + +Still need to figure out how to document multiproc.py. + +.. DistributedDataParallel + ---------- + +.. autoclass:: DistributedDataParallel + :members: diff --git a/docs/source/reparameterization.rst b/docs/source/reparameterization.rst new file mode 100644 index 000000000..0839022fa --- /dev/null +++ b/docs/source/reparameterization.rst @@ -0,0 +1,14 @@ +.. role:: hidden + :class: hidden-section + +apex.reparameterization +=================================== + +.. automodule:: apex.reparameterization +.. currentmodule:: apex.reparameterization + +.. autoclass:: Reparameterization + :members: + +.. autoclass:: WeightNorm + :members: diff --git a/examples/distributed/README.md b/examples/distributed/README.md new file mode 100644 index 000000000..d6dd1d2d5 --- /dev/null +++ b/examples/distributed/README.md @@ -0,0 +1,21 @@ +# Basic Multirpocess Example based on the MNIST example + +This version of this examples requires APEx which can be installed from https://www.github.com/nvidia/apex. This example demonstrates how to modify a network to use a basic but effective distributed data parallel module. This parallel method is designed to easily run multi-gpu runs on a single node. It was created as current parallel methods integraded into pytorch can induce significant overhead due to python GIL lock. This method will reduce the influence of those overheads and potentially provide a benefit in performance, especially for networks with a significant number of fast running operations. + +## Getting started +Prior to running please run +```pip install -r requirements.txt``` + +and start a single process run to allow the dataset to be downloaded (This will not work properly in multi-gpu. You can stop this job as soon as it starts iterating.). +```python main.py``` + +You can now the code multi-gpu with +```python -m apex.parallelmultiproc main.py ...``` +adding any normal option you'd like. + +## Converting your own model +To understand how to convert your own model to use the distributed module included, please see all sections of main.py within ```#=====START: ADDED FOR DISTRIBUTED======``` and ```#=====END: ADDED FOR DISTRIBUTED======``` flags. + +## Requirements +Pytorch master branch built from source. This requirement is to use NCCL as a distributed backend. +APEx installed from https://www.github.com/nvidia/apex \ No newline at end of file diff --git a/examples/distributed/main.py b/examples/distributed/main.py new file mode 100644 index 000000000..0ccdc3ee3 --- /dev/null +++ b/examples/distributed/main.py @@ -0,0 +1,200 @@ +from __future__ import print_function +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torchvision import datasets, transforms +from torch.autograd import Variable + +#=====START: ADDED FOR DISTRIBUTED====== +'''Add custom module for distributed''' + +try: + from apex.parallel import DistributedDataParallel as DDP +except ImportError: + raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.") + +'''Import distributed data loader''' +import torch.utils.data +import torch.utils.data.distributed + +'''Import torch.distributed''' +import torch.distributed as dist + +#=====END: ADDED FOR DISTRIBUTED====== + +# Training settings +parser = argparse.ArgumentParser(description='PyTorch MNIST Example') +parser.add_argument('--batch-size', type=int, default=64, metavar='N', + help='input batch size for training (default: 64)') +parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', + help='input batch size for testing (default: 1000)') +parser.add_argument('--epochs', type=int, default=10, metavar='N', + help='number of epochs to train (default: 10)') +parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') +parser.add_argument('--momentum', type=float, default=0.5, metavar='M', + help='SGD momentum (default: 0.5)') +parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') +parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') +parser.add_argument('--log-interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + +#======START: ADDED FOR DISTRIBUTED====== +''' +Add some distributed options. For explanation of dist-url and dist-backend please see +http://pytorch.org/tutorials/intermediate/dist_tuto.html + +--world-size and --rank are required parameters as they will be used by the multiproc.py launcher +but do not have to be set explicitly. +''' + +parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str, + help='url used to set up distributed training') +parser.add_argument('--dist-backend', default='nccl', type=str, + help='distributed backend') +parser.add_argument('--world-size', default=1, type=int, + help='Number of GPUs to use. Can either be manually set ' + + 'or automatically set by using \'python -m multiproc\'.') +parser.add_argument('--rank', default=0, type=int, + help='Used for multi-process training. Can either be manually set ' + + 'or automatically set by using \'python -m multiproc\'.') +#=====END: ADDED FOR DISTRIBUTED====== + +args = parser.parse_args() +args.cuda = not args.no_cuda and torch.cuda.is_available() + +#======START: ADDED FOR DISTRIBUTED====== +'''Add a convenience flag to see if we are running distributed''' +args.distributed = args.world_size > 1 + +'''Check that we are running with cuda, as distributed is only supported for cuda.''' +if args.distributed: + assert args.cuda, "Distributed mode requires running with CUDA." + +if args.distributed: + ''' + Set cuda device so everything is done on the right GPU. + THIS MUST BE DONE AS SOON AS POSSIBLE. + ''' + torch.cuda.set_device(args.rank % torch.cuda.device_count()) + + '''Initialize distributed communication''' + dist.init_process_group(args.dist_backend, init_method=args.dist_url, + world_size=args.world_size) + +#=====END: ADDED FOR DISTRIBUTED====== + +torch.manual_seed(args.seed) +if args.cuda: + torch.cuda.manual_seed(args.seed) + + +kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {} + +#=====START: ADDED FOR DISTRIBUTED====== +''' +Change sampler to distributed if running distributed. +Shuffle data loader only if distributed. +''' +train_dataset = datasets.MNIST('../data', train=True, download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])) + +if args.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) +else: + train_sampler = None + +train_loader = torch.utils.data.DataLoader( + train_dataset, sampler=train_sampler, + batch_size=args.batch_size, shuffle=(train_sampler is None), **kwargs +) + +#=====END: ADDED FOR DISTRIBUTED====== + +test_loader = torch.utils.data.DataLoader( + datasets.MNIST('../data', train=False, transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ])), + batch_size=args.test_batch_size, shuffle=True, **kwargs) + + +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 10, kernel_size=5) + self.conv2 = nn.Conv2d(10, 20, kernel_size=5) + self.conv2_drop = nn.Dropout2d() + self.fc1 = nn.Linear(320, 50) + self.fc2 = nn.Linear(50, 10) + + def forward(self, x): + x = F.relu(F.max_pool2d(self.conv1(x), 2)) + x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) + x = x.view(-1, 320) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x) + +model = Net() +if args.cuda: + model.cuda() + +#=====START: ADDED FOR DISTRIBUTED====== +''' +Wrap model in our version of DistributedDataParallel. +This must be done AFTER the model is converted to cuda. +''' + +if args.distributed: + model = DDP(model) +#=====END: ADDED FOR DISTRIBUTED====== + +optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum) + +def train(epoch): + model.train() + for batch_idx, (data, target) in enumerate(train_loader): + if args.cuda: + data, target = data.cuda(), target.cuda() + data, target = Variable(data), Variable(target) + optimizer.zero_grad() + output = model(data) + loss = F.nll_loss(output, target) + loss.backward() + optimizer.step() + if batch_idx % args.log_interval == 0: + print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( + epoch, batch_idx * len(data), len(train_loader.dataset), + 100. * batch_idx / len(train_loader), loss.data[0])) + +def test(): + model.eval() + test_loss = 0 + correct = 0 + for data, target in test_loader: + if args.cuda: + data, target = data.cuda(), target.cuda() + data, target = Variable(data, volatile=True), Variable(target) + output = model(data) + test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss + pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability + correct += pred.eq(target.data.view_as(pred)).cpu().sum() + + test_loss /= len(test_loader.dataset) + print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( + test_loss, correct, len(test_loader.dataset), + 100. * correct / len(test_loader.dataset))) + + +for epoch in range(1, args.epochs + 1): + train(epoch) + test() diff --git a/examples/distributed/requirements.txt b/examples/distributed/requirements.txt new file mode 100644 index 000000000..ac988bdf8 --- /dev/null +++ b/examples/distributed/requirements.txt @@ -0,0 +1,2 @@ +torch +torchvision diff --git a/examples/distributed/run_distributed.sh b/examples/distributed/run_distributed.sh new file mode 100644 index 000000000..ca5aeab0c --- /dev/null +++ b/examples/distributed/run_distributed.sh @@ -0,0 +1 @@ +python -m apex.parallel.multiproc main.py diff --git a/include/THCTensorInfo.cuh b/include/THCTensorInfo.cuh new file mode 100644 index 000000000..dec6a5438 --- /dev/null +++ b/include/THCTensorInfo.cuh @@ -0,0 +1,142 @@ +#ifndef THC_TENSOR_INFO_INC +#define THC_TENSOR_INFO_INC + +#include +#include +#include + +// Maximum number of dimensions allowed for cutorch +#define MAX_CUTORCH_DIMS 10 + +// Warning string for tensor arguments that are too large or have too +// many dimensions +#define CUTORCH_STR(X) #X +#define CUTORCH_DIM_WARNING "tensor too large or too many (>" \ + CUTORCH_STR(MAX_CUTORCH_DIMS) ") dimensions" + +enum float_types { FLOAT = 0 , HALF = 1, DOUBLE = 2 }; + +// CUDA kernel argument that defines tensor layout +template +struct TensorInfo { + + TensorInfo(T* p, + int dim, + IndexType sz[MAX_CUTORCH_DIMS], + IndexType st[MAX_CUTORCH_DIMS]); + + TensorInfo(T* p, + int dim, + IndexType sz[MAX_CUTORCH_DIMS], + IndexType st[MAX_CUTORCH_DIMS], + float_types type); + + //Good way to cast from another format + //template > + //TensorInfo(TensorInfo &tinfo_in){ + // data = reinterpret_cast(tinfo_in.data); + //} + + T* data; + IndexType sizes[MAX_CUTORCH_DIMS]; + IndexType strides[MAX_CUTORCH_DIMS]; + int dims; + float_types type; +}; + +//Expand our combinations as convenient typedefs +typedef TensorInfo t_hi; +typedef TensorInfo t_hl; +typedef TensorInfo t_fi; +typedef TensorInfo t_fl; + + +template +TensorInfo::TensorInfo(T* p, + int dim, + IndexType sz[MAX_CUTORCH_DIMS], + IndexType st[MAX_CUTORCH_DIMS]) { + data = p; + dims = dim; + assert(dims > 0 && dims < MAX_CUTORCH_DIMS); + + for (int i = 0; i < dim; ++i) { + sizes[i] = sz[i]; + strides[i] = st[i]; + } +} + +template +TensorInfo::TensorInfo(T* p, + int dim, + IndexType sz[MAX_CUTORCH_DIMS], + IndexType st[MAX_CUTORCH_DIMS], + float_types _type){ + data = p; + dims = dim; + assert(dims > 0 && dims < MAX_CUTORCH_DIMS); + + for (int i = 0; i < dim; ++i) { + sizes[i] = sz[i]; + strides[i] = st[i]; + } + type=_type; +} + + + +// Translate a linear index for the apply to a T* offset; +// specialized on `Dims` to reduce nvcc compilation time +template +struct IndexToOffset { + static __forceinline__ __host__ __device__ IndexType get( + IndexType linearId, + const TensorInfo& info) { + IndexType offset = 0; + + // Use static dims + for (int i = Dims - 1; i > 0; --i) { + for (int i = Dims - 1; i > 0; --i) { + offset += linearId % info.sizes[i] * info.strides[i]; + linearId /= info.sizes[i]; + } + + offset += linearId * info.strides[0]; + return offset; + } + } +}; + + + +// For contiguous tensors, the offset = index +template +struct IndexToOffset { + static __forceinline__ __host__ __device__ IndexType + get(IndexType linearId, const TensorInfo& info) { + return linearId; + } +}; + +template +struct IndexToOffset { + static __forceinline__ __host__ __device__ IndexType get( + IndexType linearId, + const TensorInfo& info) { + + IndexType offset = 0; + + // Use dynamic dims + for (int i = info.dims - 1; i >= 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + IndexType curDimOffset = curDimIndex * info.strides[i]; + offset += curDimOffset; + + linearId /= info.sizes[i]; + } + + return offset; + } +}; + +#endif // THC_TENSOR_INFO_INC diff --git a/include/kernel.h b/include/kernel.h new file mode 100644 index 000000000..47872a7db --- /dev/null +++ b/include/kernel.h @@ -0,0 +1,73 @@ +#include "THCTensorInfo.cuh" +#include +#include +#include +#include +// this is suboptimal, try forward declarations later +#include + +#define Dims -2 +#define DEVICE_LINEAR_GET(D_TENSOR, INDEX) D_TENSOR.data[IndexToOffset::get(INDEX, D_TENSOR)] +#define DEVICE_LINEAR_GET_F(D_TENSOR, INDEX) D_TENSOR.data[IndexToOffset::get(INDEX, D_TENSOR)] + +// template +// void send_to_kernel( +// TensorInfo Input_1, +// TensorInfo Input_2, +// IndexType totalElems +// ); + +typedef int idxType; + +struct send_to_fwd_wrapper +{ + template + static void call(std::vector>& tensors, int dim); +}; + +struct send_to_bwd_wrapper +{ + template + static void call(std::vector>& tensors, int dim); +}; + +template +struct ScalarConvert { + static __host__ __device__ __forceinline__ Out to(const In v) { return (Out) v; } +}; + +#ifdef CUDA_HALF_TENSOR +template +struct ScalarConvert { + static __host__ __device__ __forceinline__ Out to(const half v) { +#ifdef __CUDA_ARCH__ + return (Out) __half2float(v); +#else + return (Out) THC_half2float(v); +#endif + } +}; + +template +struct ScalarConvert { + static __host__ __device__ __forceinline__ half to(const In v) { +#ifdef __CUDA_ARCH__ + return __float2half((float) v); +#else + return THC_float2half((float) v); +#endif + } +}; + +template <> +struct ScalarConvert { + static __host__ __device__ __forceinline__ half to(const half v) { + return v; + } +}; + +#endif diff --git a/setup.py b/setup.py new file mode 100644 index 000000000..e4a64750c --- /dev/null +++ b/setup.py @@ -0,0 +1,175 @@ +import re + +import os +import shutil +import inspect + +import distutils +import distutils.spawn +from distutils.command.clean import clean + +from setuptools import setup, Extension, find_packages +from setuptools.command.install import install + +import subprocess +import ctypes.util + +import torch + +#Takes a path to walk +#A function to decide if to keep +#collection if we want a list of all occurances +def find(path, regex_func, collect=False): + collection = [] if collect else None + for root, dirs, files in os.walk(path): + for file in files: + if regex_func(file): + if collect: + collection.append(os.path.join(root, file)) + else: + return os.path.join(root, file) + return list(set(collection)) + +def findcuda(): + CUDA_HOME = os.getenv('CUDA_HOME', '/usr/local/cuda') + if not os.path.exists(CUDA_HOME): + # We use nvcc path on Linux and cudart path on macOS + osname = platform.system() + if osname == 'Linux': + cuda_path = find_nvcc() + else: + cudart_path = ctypes.util.find_library('cudart') + if cudart_path is not None: + cuda_path = os.path.dirname(cudart_path) + else: + cuda_path = None + if cuda_path is not None: + CUDA_HOME = os.path.dirname(cuda_path) + else: + CUDA_HOME = None + WITH_CUDA = CUDA_HOME is not None + return CUDA_HOME + +#Get some important paths +curdir = os.path.dirname(os.path.abspath(inspect.stack()[0][1])) +buildir = curdir+os.sep+"build" +if not os.path.exists(buildir): + os.makedirs(buildir) + +torch_dir = os.path.split(torch.__file__)[0] + os.sep + "lib" + +cuda_files = find(curdir, lambda file: file.endswith(".cu"), True) +cuda_headers = find(curdir, lambda file: file.endswith(".cuh"), True) +headers = find(curdir, lambda file: file.endswith(".h"), True) + +libaten = find(torch_dir, re.compile("libaten", re.IGNORECASE).search, False) +aten_h = find(torch_dir, re.compile("aten.h", re.IGNORECASE).search, False) + +include_dirs = [os.path.dirname(os.path.dirname(aten_h))] +library_dirs = [] +for file in cuda_headers+headers: + dir = os.path.dirname(file) + if dir not in include_dirs: + include_dirs.append(dir) + +assert libaten, "Could not find PyTorch's libATen." +assert aten_h, "Could not find PyTorch's ATen header." + +library_dirs.append(os.path.dirname(libaten)) + +#create some places to collect important things +object_files = [] +extra_link_args=[] +main_libraries = [] +main_libraries += ['cudart', 'cuda', 'ATen'] +extra_compile_args = ["--std=c++11",] + +#findcuda returns root dir of CUDA +#include cuda/include and cuda/lib64 for python module build. +CUDA_HOME=findcuda() +library_dirs.append(os.path.join(CUDA_HOME, "lib64")) +include_dirs.append(os.path.join(CUDA_HOME, 'include')) + +class RMBuild(clean): + def run(self): + #BE VERY CAUTIOUS WHEN USING RMTREE!!! + #These are some carefully written/crafted directories + if os.path.exists(buildir): + shutil.rmtree(buildir) + + distdir = curdir+os.sep+"dist" + if os.path.exists(distdir): + shutil.rmtree(distdir) + + eggdir = curdir+os.sep+"apex.egg-info" + if os.path.exists(eggdir): + shutil.rmtree(eggdir) + clean.run(self) + +def CompileCudaFiles(): + + print() + print("Compiling cuda modules with nvcc:") + #Need arches to compile for. Compiles for 70 which requires CUDA9 + nvcc_cmd = ['nvcc', + '-Xcompiler', + '-fPIC', + '-gencode', 'arch=compute_52,code=sm_52', + '-gencode', 'arch=compute_60,code=sm_60', + '-gencode', 'arch=compute_61,code=sm_61', + '-gencode', 'arch=compute_70,code=sm_70', + '-gencode', 'arch=compute_70,code=compute_70', + '--std=c++11', + '-O3', + ] + + for dir in include_dirs: + nvcc_cmd.append("-I"+dir) + + for file in cuda_files: + object_name = os.path.basename( + os.path.splitext(file)[0]+".o" + ) + + object_file = os.path.join(buildir, object_name) + object_files.append(object_file) + + file_opts = ['-c', file, '-o', object_file] + + print(' '.join(nvcc_cmd+file_opts)) + subprocess.check_call(nvcc_cmd+file_opts) + + for object_file in object_files: + extra_link_args.append(object_file) + +print() +print("Arguments used to build CUDA extension:") +print("extra_compile_args :", extra_compile_args) +print("include_dirs: ", include_dirs) +print("extra_link_args: ", extra_link_args) +print("library_dirs: ", library_dirs) +print("libraries: ", main_libraries) +print() +CompileCudaFiles() + +print("Building CUDA extension.") +cuda_ext = Extension('apex._C', + [os.path.join('csrc', 'Module.cpp')], + extra_compile_args = extra_compile_args, + include_dirs=include_dirs, + extra_link_args=extra_link_args, + library_dirs=library_dirs, + runtime_library_dirs = library_dirs, + libraries=main_libraries + ) + +print("Building module.") +setup( + name='apex', version='0.1', + cmdclass={ + 'clean' : RMBuild, + }, + ext_modules=[cuda_ext,], + description='PyTorch Extensions written by NVIDIA', + packages=find_packages(exclude=("build", "csrc", "include", "tests")), +) diff --git a/tests/RNN/RNN_tests.py b/tests/RNN/RNN_tests.py new file mode 100644 index 000000000..d413923c4 --- /dev/null +++ b/tests/RNN/RNN_tests.py @@ -0,0 +1,118 @@ +import torch +import torch.nn as nn +from torch.autograd import Variable +import apex +from apex.RNN.models import bidirectionalRNN, stackedRNN, RNNCell +from torch.nn._functions.rnn import LSTMCell +import itertools + + +torch.backends.cudnn.enabled=False + +batch_first = False #not implemented yet +dropout = 0.0 #How to validate? +bidirectional = False #True works, but differs in definition to PyTorch + +rnn_types = ['LSTM', 'GRU', 'ReLU', 'Tanh'] +sizes = [8,4,2] + +seq_sizes = sizes +hidden_sizes = sizes +inp_sizes = sizes +batch_sizes = sizes +num_layerss = sizes + +biases = [True] + +def copy_param_set(pyt_rnn, my_rnn, layer=0, reverse=False): + my_params = None + + rnn = None + if isinstance(my_rnn, bidirectionalRNN): + rnn = my_rnn.fwd.rnns[layer] if not reverse else my_rnn.bckwrd.rnns[layer] + elif isinstance(my_rnn, stackedRNN): + rnn = my_rnn.rnns[layer] + else: + raise RuntimeError() + + param_names = ['w_ih', 'w_hh', 'b_ih', 'b_hh'] + + if not hasattr(rnn, 'b_hh'): + param_names = param_names[:2] + my_params = [getattr(rnn, param_name) for param_name in param_names] + + pyt_params = None + param_names = ['weight_ih_', 'weight_hh_', 'bias_ih_', 'bias_hh_'] + reverse_str = '_reverse' if reverse else '' + + if not hasattr(pyt_rnn, 'bias_hh_l0'): + param_names=param_names[:2] + pyt_params =[getattr(pyt_rnn, param_name + 'l' + str(layer) + reverse_str ) + for param_name in param_names ] + for pyt_param, my_param in zip(pyt_params, my_params): + pyt_param.data.copy_(my_param.data) + +def copy_all_params(pyt_rnn, my_rnn): + for layer in range(num_layers): + copy_param_set(pyt_rnn, my_rnn, layer) + if bidirectional: + copy_param_set(pyt_rnn, my_rnn, layer, bidirectional) + + +def compare_variables(v1, v2, msg, params): + diff = float((v1.data-v2.data).abs().max()) + if diff > 1e-5: + print("Error of ", diff, " found for ", msg, " for case: ", str(params)) + +def compare_tuple_variables(t1, t2, msg, params): + for var1, var2 in zip(t1, t2): + compare_variables(var1, var2, msg, params) + +def maybe_compare(v1, v2, msg, params): + if isinstance(v1, Variable) and isinstance(v2, Variable): + compare_variables(v1, v2, msg, params) + else: + compare_tuple_variables(v1, v2, msg, params) + +product = list(itertools.product(rnn_types, seq_sizes, hidden_sizes, inp_sizes, batch_sizes, num_layerss, biases)) + +for test_case in product: + rnn_type, seq_size, hidden_size, inp_size, batch_size, num_layers, bias = test_case + + inp = torch.cuda.FloatTensor(seq_size, batch_size, inp_size).uniform_() + + if rnn_type == 'ReLU' or rnn_type == 'Tanh': + pytorch_rnn = nn.RNN(inp_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, nonlinearity=rnn_type.lower()).cuda() + else: + pytorch_rnn = getattr(nn, rnn_type)(inp_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional).cuda() + my_rnn = getattr(apex.RNN.models, rnn_type)(inp_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional).cuda() + + copy_all_params(pytorch_rnn, my_rnn) + + pyt_inp = Variable(inp, requires_grad=True) + my_inp = Variable(inp, requires_grad=True) + + my_out, my_hiddens = my_rnn(my_inp) + pyt_out, pyt_hiddens = pytorch_rnn(pyt_inp) + + pyt_out.sum().backward() + my_out.sum().backward() + + + maybe_compare(pyt_out, my_out, "out", test_case) + + #If there's only one hidden state PyTorch doesn't return it in a tuple, + #apex does, so we wrap PyTorch's returned hidden state in a tuple. + if not isinstance(pyt_hiddens, tuple): + pyt_hiddens = (pyt_hiddens,) + + try: + for i, (pyt_hid, my_hid) in enumerate(zip(pyt_hiddens, my_hiddens)): + maybe_compare(pyt_hid, my_hid , "hx_"+str(i), test_case) + except ValueError: + maybe_compare(pyt_hiddens, my_hiddens , "hx_0", test_case) + + + maybe_compare(pyt_inp.grad, my_inp.grad, "inp.grad", test_case) + +print("Test passed.") diff --git a/tests/raw_ops/backward_gotchas.py b/tests/raw_ops/backward_gotchas.py new file mode 100644 index 000000000..3ed620b6c --- /dev/null +++ b/tests/raw_ops/backward_gotchas.py @@ -0,0 +1,47 @@ +import torch +from torch.autograd import Variable +# import apex +import numpy as np + +torch.manual_seed(2) +torch.cuda.manual_seed(2) +# torch.cuda.manual_seed_all(2) +torch.set_printoptions(precision=10) + +rows = 3 +cols = 20 +dims = rows, cols + +# Incoming gradient vectors we will use later +# Need to create the fp16 versions as a half() copy of a Tensor first rather than +# a Variable, because if you create pt_input_control as a Variable then say +# pt_input_fp16 = pt_input_control.half(), you are accidentally making pt_input_fp16 part of +# pLpOutput_control's computational graph, so it will not be a leaf! +pt_input_control = Variable(torch.randn(*dims).cuda(), requires_grad=True) +# pt_input_control = torch.ones(*dims).cuda() +pt_input_fp16 = pt_input_control.half() + +pt_output_fp16 = pt_input_fp16.sum() +pt_output_control = pt_input_control.sum() +print("After sum()s, before backwards:") +print("pt_output_control.requires_grad = ", pt_output_control.requires_grad) +print("pt_output_control.volatile = ", pt_output_control.volatile) +print("pt_input_control.grad = ", pt_input_control.grad) +print("pt_input_fp16.grad = ", pt_input_fp16.grad) +print("\n\n") + +pt_output_fp16.backward() # pt_input_fp16 is not the leaf of this graph, pt_input_control is. +print("After pt_output_fp16.backward():") +print("pt_input_control.grad = ", pt_input_control.grad) +print("pt_input_fp16.grad = ", pt_input_fp16.grad) +print("\n\n") +pt_output_control.backward() # Both backward() calls have pt_input_control as leaves, and so + # will accumulate gradients into pt_input_control.grad +print("After pt_output_control.backward():") +print("pt_input_control.grad = ", pt_input_control.grad) +print("pt_input_fp16.grad = ", pt_input_fp16.grad) +print("\n\n") +print("pt_output_control = ", pt_output_control) +print("pt_output_fp16 = ", pt_output_fp16) + + diff --git a/tests/raw_ops/compare.py b/tests/raw_ops/compare.py new file mode 100644 index 000000000..72af1817c --- /dev/null +++ b/tests/raw_ops/compare.py @@ -0,0 +1,42 @@ +import torch +import numpy as np + +def compare(cuda_out, pt_out, pt_out_control, rows): + + print( "Pytorch ops in fp16: ", pt_out ) + print( "Kernel result: ", cuda_out ) + print("Control (Pytorch ops, sticking to fp32): ", pt_out_control) + + # Make upconverted copies for error check against fp32 control + cuda_out_fp32 = cuda_out.float() + pt_out_fp32 = pt_out.float() + + # Flatten all but the slowest dimension + cuda_out = cuda_out.view(rows,-1) + pt_out = pt_out.view(rows,-1) + cuda_out_fp32 = cuda_out_fp32.view(rows,-1) + pt_out_fp32 = pt_out_fp32.view(rows,-1) + pt_out_control = pt_out_control.view(rows,-1) + + cuda_maxdiffs, cuda_maxdiff_locs = torch.max((pt_out_control - cuda_out_fp32).abs(),1) + pt_maxdiffs, pt_maxdiff_locs = torch.max((pt_out_control - pt_out_fp32 ).abs(),1) + + print( "cuda_maxdiffs = ", cuda_maxdiffs ) + print("cuda_maxdiff_locs = ", cuda_maxdiff_locs) + print( "pt_maxdiffs = ", pt_maxdiffs ) + print( "pt_maxdiff_locs = ", pt_maxdiff_locs ) + + row_indices = torch.LongTensor(np.arange(rows)) + + print("cuda_out at cuda_maxdiff_locs in each row:") + # bizarrely, this will work if you do it at the python prompt: + # print(cuda_out[row_indices,cuda_maxdiff_locs]) + # ...but it only seems to work here if you wrap with numpy arrays: + print( cuda_out[np.array(row_indices),np.array(cuda_maxdiff_locs)]) + print("pt_out_control at cuda_maxdiff_locs in each row:") + print(pt_out_control[np.array(row_indices),np.array(cuda_maxdiff_locs)]) + + print("pt_out at pt_maxdiff_locs in each row:" ) + print( pt_out[np.array(row_indices),np.array(pt_maxdiff_locs)]) + print("pt_out_control at pt_maxdiff_locs in each row:" ) + print(pt_out_control[np.array(row_indices),np.array(pt_maxdiff_locs)]) diff --git a/tests/raw_ops/norm.py b/tests/raw_ops/norm.py new file mode 100644 index 000000000..03c21eb94 --- /dev/null +++ b/tests/raw_ops/norm.py @@ -0,0 +1,20 @@ +import torch + +def get_norm_shape(p, dim): + if dim == 0: + output_size = (p.size(0),) + (1,) * (p.dim() - 1) + return output_size + elif dim == p.dim() - 1: + output_size = (1,) * (p.dim() - 1) + (p.size(-1),) + return output_size + return None + +def pt_norm(p, dim): + """Computes the norm over all dimensions except dim""" + if dim is None: + return p.norm() + elif dim == 0: + return p.contiguous().view(p.size(0), -1).norm(2,dim=1).view(*get_norm_shape(p, dim)) + elif dim == p.dim() - 1: + return p.contiguous().view(-1, p.size(-1)).norm(2,dim=0).view(*get_norm_shape(p, dim)) + return pt_norm(p.transpose(0, dim), 0).transpose(0, dim) diff --git a/tests/raw_ops/test_autograd.py b/tests/raw_ops/test_autograd.py new file mode 100644 index 000000000..058f73dab --- /dev/null +++ b/tests/raw_ops/test_autograd.py @@ -0,0 +1,146 @@ +import torch +from torch.autograd import Variable +from apex.fp16_utils import Fused_Weight_Norm +from compare import compare +from norm import pt_norm, get_norm_shape + +torch.manual_seed(2) +torch.cuda.manual_seed(2) +# torch.cuda.manual_seed_all(2) +torch.set_printoptions(precision=10) + +rows = 1 # 321 +cols = 4096 # 33 +fast = 4096 # 185 +dims = rows, cols, fast + +dim = 2 +CUDA_HALF = False +RAND = True # If false, input gradients (the result of the backward pass) + # should be analytically zero. + +# Loss will be computed via (output*elementwise).sum(). +# This means that output gradients in the backward pass will be equal +# to elementwise, so by manipulating elementwise, we have easy +# fine-grained control over the output gradients we'd like to use for +# testing purposes. +# +# The alternative is just to create the output_gradients manually +# and call output.backward(gradient=output_gradients), +# as is done in test_backward.py. +# But I wanted a minimal working sample similar to an "actual" use case, +# where gradients are computed by calling backward() on a scalar Loss. + +if RAND: + # With std=6.0, I observe the pytorch fp16 ops going unstable + # while the fused kernel remains stable (sometimes). + pt_in_fp32 = torch.cuda.FloatTensor(*dims ).normal_(std=1.0) + norm_shape = get_norm_shape(pt_in_fp32, dim) + pt_g_fp32 = torch.cuda.FloatTensor(*norm_shape).normal_(std=1.0) + elementwise_fp32 = torch.cuda.FloatTensor(*dims ).normal_(std=1.0) +else: + pt_in_fp32 = torch.cuda.FloatTensor(*dims ).fill_(1.0) + norm_shape = get_norm_shape(pt_in_fp32, dim) + pt_g_fp32 = torch.cuda.FloatTensor(*norm_shape).fill_(2.0) + elementwise_fp32 = torch.cuda.FloatTensor(*dims ).fill_(0.5) + +pt_in_fp16 = pt_in_fp32.half() +cd_in_prec = pt_in_fp32.clone() +pt_g_fp16 = pt_g_fp32.half() +cd_g_prec = pt_g_fp32.clone() +elementwise_fp16 = elementwise_fp32.half() +elementwise_prec = elementwise_fp32.clone() + +if CUDA_HALF: + cd_in_prec = cd_in_prec.half() + cd_g_prec = cd_g_prec.half() + elementwise_prec = elementwise_prec.half() + +pt_in_fp32 = Variable(pt_in_fp32 , requires_grad=True) +pt_in_fp16 = Variable(pt_in_fp16 , requires_grad=True) +cd_in_prec = Variable(cd_in_prec , requires_grad=True) + +pt_g_fp32 = Variable(pt_g_fp32 , requires_grad=True) +pt_g_fp16 = Variable(pt_g_fp16 , requires_grad=True) +cd_g_prec = Variable(cd_g_prec , requires_grad=True) + +elementwise_fp32 = Variable(elementwise_fp32, requires_grad=False) +elementwise_fp16 = Variable(elementwise_fp16, requires_grad=False) +elementwise_prec = Variable(elementwise_prec, requires_grad=False) + +torch.cuda.nvtx.range_push("fp16 forward, {}".format(pt_in_fp16.size())) +pt_norms_fp16 = pt_norm(pt_in_fp16, dim) +pt_out_fp16 = pt_in_fp16*(pt_g_fp16/pt_norms_fp16) +torch.cuda.nvtx.range_pop() +# torch.cuda.synchronize() + +torch.cuda.nvtx.range_push("fp32 forward, {}".format(pt_in_fp32.size())) +pt_norms_fp32 = pt_norm(pt_in_fp32, dim) +pt_out_fp32 = pt_in_fp32*(pt_g_fp32/pt_norms_fp32) +torch.cuda.nvtx.range_pop() +# torch.cuda.synchronize() + +# print("pt_norms_fp16 = ", pt_norms_fp16 ) +# print("pt_norms_fp32 = ", pt_norms_fp32) + +# print( "cd_in_prec.data_ptr = {:x}".format(cd_in_prec.data_ptr())) + +# print("elementwise_fp16 = ", elementwise_fp16) + +cd_in_contig = cd_in_prec.contiguous() +# Deliberately make noncontig to see if fused_norm +# will handle the error +# cd_in_contig = cd_in_contig[:,0:5] +# print(type(cd_in_contig)) +torch.cuda.nvtx.range_push("kernel forward") +fused_weight_norm = Fused_Weight_Norm.apply +cd_out_prec = fused_weight_norm(cd_in_contig, cd_g_prec, dim) +torch.cuda.nvtx.range_pop() +# torch.cuda.synchronize() + +# print("type(cd_out_prec.data) = ", type(cd_out_prec.data)) +# print("cd_out_prec.data_ptr = {:x}".format(cd_out_prec.data_ptr())) + +print("\n\n\nCOMPARING FORWARD PASS RESULTS\n\n\n") +compare(cd_out_prec.data, + pt_out_fp16.data, + pt_out_fp32.data, + rows) + +# It's ok to use elementwise_fp16 as a leaf in both the cuda and pytorch graphs. +# This sharing should not affect the computed gradients wrt pt_in_fp16 and cd_in_prec. +# However, just remember: +# If we set requires_grad=True for elementwise_fp16, elementwise_fp16.grad.data +# will accumulate gradients during the backward passes for both the cd and pytorch Losses. +# +# I do need v these parentheses v +Loss_cd_prec = (cd_out_prec*elementwise_prec).sum() +# print(L_cd_fp16) +Loss_pt_fp16 = (pt_out_fp16*elementwise_fp16).sum() +# print(L_pt_fp16) +Loss_pt_fp32 = (pt_out_fp32*elementwise_fp32).sum() +# print(L_pt_fp32) + +torch.cuda.nvtx.range_push("kernel backward") +Loss_cd_prec.backward() +torch.cuda.nvtx.range_pop() +torch.cuda.nvtx.range_push("fp16 backward") +Loss_pt_fp16.backward() +torch.cuda.nvtx.range_pop() +torch.cuda.nvtx.range_push("fp32 backward") +Loss_pt_fp32.backward() +torch.cuda.nvtx.range_pop() + +print("\n\n\nCOMPARING v GRADIENT RESULTS\n\n\n") +compare(cd_in_prec.grad.data, + pt_in_fp16.grad.data, + pt_in_fp32.grad.data, + rows) + +print("\n\n\nCOMPARING g GRADIENT RESULTS\n\n\n") +compare(cd_g_prec.grad.data, + pt_g_fp16.grad.data, + pt_g_fp32.grad.data, + cd_g_prec.size(0)) + + diff --git a/tests/raw_ops/test_backward.py b/tests/raw_ops/test_backward.py new file mode 100644 index 000000000..1f8a8b609 --- /dev/null +++ b/tests/raw_ops/test_backward.py @@ -0,0 +1,129 @@ +import torch +from torch.autograd import Variable +import apex._C +import numpy as np +from compare import compare +from norm import pt_norm, get_norm_shape + +torch.manual_seed(2) +torch.cuda.manual_seed(2) +# torch.cuda.manual_seed_all(2) +torch.set_printoptions(precision=10) + +sizes = [ + # (3, 512, 1024), + # (3, 512, 1536), + (3, 768, 1536), + # (3, 768, 2048), + # (3, 1024, 2048), + # (1, 1024, 4096), + # (1, 2048, 8192), + # (1, 4096, 4096), # this is not one of natalia's sizes, just a reference benchmark. + # (4096, 4096, 1), # this is not one of natalia's sizes, just a reference benchmark. + ] + +# rows = 3 +# cols = 512 +# fast = 1024 +HALF = True +RAND = True +dim = 2 + +for rows, cols, fast in sizes: + dims = rows, cols, fast + # Incoming gradient vectors we will use later + # Need to create the fp16 versions as a half() copy of a Tensor first rather than + # a Variable, because if you create pt_input_control as a Variable then say + # pt_input_fp16 = pt_input_control.half(), you are accidentally making pt_input_fp16 part of + # pLpOutput_control's computational graph, instead of the leaf of its own separate graph. + + # Careful: if you initialize with torch.ones, the gradient wrt input becomes analytically zero :P + if RAND: + pLpOutput_control = torch.cuda.FloatTensor(*dims ).uniform_()*1.0 + norm_shape = get_norm_shape(pLpOutput_control, dim) + pLpg_control = torch.cuda.FloatTensor(*norm_shape).uniform_() + pt_input_control = torch.cuda.FloatTensor(*dims ).uniform_() + pt_g_control = torch.cuda.FloatTensor(*norm_shape).uniform_() + else: + pLpOutput_control = torch.cuda.FloatTensor(*dims ).fill_(1.) + norm_shape = get_norm_shape(pLpOutput_control, dim) + pLpg_control = torch.cuda.FloatTensor(*norm_shape).fill_(2.) + pt_input_control = torch.cuda.FloatTensor(*dims ).fill_(4.0) + pt_g_control = torch.cuda.FloatTensor(*norm_shape).fill_(3.0) + + pLpOutput_fp16 = pLpOutput_control.clone() + pLpg_fp16 = pLpg_control .clone() + pt_input_fp16 = pt_input_control .clone() + pt_g_fp16 = pt_g_control .clone() + + if HALF: + pLpOutput_fp16 = pLpOutput_fp16.half() + pLpg_fp16 = pLpg_fp16 .half() + pt_input_fp16 = pt_input_fp16 .half() + pt_g_fp16 = pt_g_fp16 .half() + + pLpOutput_control = Variable(pLpOutput_control) + pLpg_control = Variable(pLpg_control ) + pLpOutput_fp16 = Variable(pLpOutput_fp16 ) + pLpg_fp16 = Variable(pLpg_fp16 ) + + pt_input_control = Variable(pt_input_control, requires_grad=True) + pt_g_control = Variable(pt_g_control , requires_grad=True) + pt_input_fp16 = Variable(pt_input_fp16 , requires_grad=True) + pt_g_fp16 = Variable(pt_g_fp16 , requires_grad=True) + + # Do forward pass in fp16 and fp32 + pt_norms_fp16 = pt_norm(pt_input_fp16, dim) + pt_norms_control = pt_norm(pt_input_control, dim) + + pt_output_fp16 = pt_input_fp16 *(pt_g_fp16 /pt_norms_fp16 ) + pt_output_control = pt_input_control*(pt_g_control/pt_norms_control) + + # Run the Cuda version + pLpInput_cuda = torch.cuda.FloatTensor(*dims ).fill_(0.) + pLpg_cuda = torch.cuda.FloatTensor(*norm_shape).fill_(0.) + + if HALF: + pLpInput_cuda = pLpInput_cuda.half() + pLpg_cuda = pLpg_cuda .half() + + torch.cuda.nvtx.range_push("kernel weight norm backward") + apex._C.weight_norm_bwd(pLpInput_cuda, + pLpg_cuda, + pLpOutput_fp16, + pt_input_fp16, + pt_g_fp16, + pt_norms_control.data, + dim) + torch.cuda.nvtx.range_pop() + + print("grad_output: ", pLpOutput_fp16.data) + print(" grad_input: ", pLpInput_cuda) + print(" savedInput: ", pt_input_fp16.data) + print("pt_norms_control: ", pt_norms_control.data) + print("pt_norms_fp16: ", pt_norms_fp16.data) + + torch.cuda.nvtx.range_push("pytorch fp16 backward") + pt_output_fp16 .backward(gradient=pLpOutput_fp16 , create_graph=True) + torch.cuda.nvtx.range_pop() + torch.cuda.nvtx.range_push("pytorch fp32 backward") + pt_output_control.backward(gradient=pLpOutput_control, create_graph=True) + torch.cuda.nvtx.range_pop() + + # pt_output_fp16 and pt_output_control are still saved, but + # pt_output_fp16.grad and pt_output_control.grad are None at this point + # because the graph is freed in the backwards pass. + # Specifying create_/retain_ graph don't seem to force saving of + # either the intermediate variables or their gradients. + + print("Comparing gradients wrt v") + torch.cuda.nvtx.range_push("compare pLpv") + compare(pLpInput_cuda, pt_input_fp16.grad.data, pt_input_control.grad.data, rows) + torch.cuda.nvtx.range_pop() + + print("Comparing gradients wrt g") + torch.cuda.nvtx.range_push("compare pLpg") + compare(pLpg_cuda, pt_g_fp16.grad.data, pt_g_control.grad.data, pLpg_cuda.size(0)) + torch.cuda.nvtx.range_pop() + + diff --git a/tests/raw_ops/test_forward.py b/tests/raw_ops/test_forward.py new file mode 100644 index 000000000..0d581ecf1 --- /dev/null +++ b/tests/raw_ops/test_forward.py @@ -0,0 +1,81 @@ +import torch +import sys +import apex._C +import numpy as np +from compare import compare +from norm import pt_norm, get_norm_shape + + +torch.manual_seed(2) +torch.cuda.manual_seed(2) +# torch.cuda.manual_seed_all(2) +torch.set_printoptions(precision=10) + +sizes = [ + # (3, 512, 1024), + # (3, 512, 1536), + # (3, 768, 1536), + # (3, 768, 2048), + # (3, 1024, 2048), + # (1, 1024, 4096), + # (1, 2048, 8192), + # (1, 4096, 4096), # this is not one of natalia's sizes, just a reference benchmark. + (4096, 4096, 1), # this is not one of natalia's sizes, just a reference benchmark. + # (353, 55, 353), # this is not one of natalia's sizes, just a reference benchmark. + ] + +# rows = 3 +# cols = 512 +# fast = 1024 +HALF = True +RAND = True +dim = 0 + + +for rows, cols, fast in sizes: + dims = rows, cols, fast + + print("\n\nTESTING dims = {}\n\n".format(dims)) + + if RAND: + pt_in = 1.*torch.cuda.FloatTensor(*dims).uniform_() + g = torch.cuda.FloatTensor(*get_norm_shape(pt_in, dim)).uniform_() + else: + pt_in = torch.cuda.FloatTensor(*dims).fill_(1.) + g = torch.cuda.FloatTensor(*get_norm_shape(pt_in, dim)).fill_(6.0) + + # per_col = torch.arange(1,cols+1).cuda() + # print((rows*per_col*per_col).sqrt()) + # pt_in *= per_col + + cuda_out = torch.cuda.FloatTensor(*dims).fill_(0.) + cuda_norms = torch.cuda.FloatTensor(*get_norm_shape(pt_in, dim)).fill_(0.) + + # Save a copy of the input as float + pt_in_fp32 = pt_in.clone() + g_fp32 = g.clone() + + if HALF: + pt_in = pt_in.half() + g = g.half() + cuda_out = cuda_out.half() + + apex._C.weight_norm_fwd(cuda_out, cuda_norms, pt_in, g, dim) + torch.cuda.synchronize() + # quit() + + print("type(cuda_out) = {}\n".format(type(cuda_out))) + + rownorms = pt_norm(pt_in, dim) + rownorms_fp32 = pt_norm(pt_in_fp32, dim) + + print("rownorms_fp32:") + print(rownorms_fp32) + print("cuda_norms" ) + print(cuda_norms ) + + # rownorms is broadcast; torch.div(pt_in, rownorms) and pt_in/rownorms work the same way + pt_out = pt_in*(g/rownorms) + pt_out_control = pt_in_fp32*(g_fp32/rownorms_fp32) + + compare(cuda_out, pt_out, pt_out_control, rows)