From 880ab925bce9f817a93988b021e12db5f67f7787 Mon Sep 17 00:00:00 2001 From: Marek Kolodziej Date: Tue, 13 Aug 2019 09:21:42 -0700 Subject: [PATCH] Adding PyProf to Apex (#404) Co-authored-by: Aditya Agrawal Co-authored-by: Marek Kolodziej --- README.md | 5 + apex/__init__.py | 2 + apex/pyprof/FAQs.md | 21 + apex/pyprof/README.md | 252 +++++++++ apex/pyprof/__init__.py | 3 + apex/pyprof/examples/.gitignore | 4 + apex/pyprof/examples/apex/README.md | 1 + apex/pyprof/examples/apex/fused_adam.py | 20 + apex/pyprof/examples/apex/fused_layer_norm.py | 28 + apex/pyprof/examples/apex/test.sh | 30 + .../examples/custom_func_module/README.md | 1 + .../custom_func_module/custom_function.py | 33 ++ .../custom_func_module/custom_module.py | 27 + .../examples/custom_func_module/test.sh | 30 + apex/pyprof/examples/imagenet/imagenet.py | 137 +++++ apex/pyprof/examples/imagenet/test.sh | 36 ++ apex/pyprof/examples/jit/README.md | 14 + .../examples/jit/jit_script_function.py | 30 + apex/pyprof/examples/jit/jit_script_method.py | 31 ++ .../pyprof/examples/jit/jit_trace_function.py | 30 + apex/pyprof/examples/jit/jit_trace_method.py | 36 ++ apex/pyprof/examples/jit/test.sh | 30 + apex/pyprof/examples/lenet.py | 65 +++ apex/pyprof/examples/operators.py | 145 +++++ apex/pyprof/examples/simple.py | 38 ++ .../pyprof/examples/user_annotation/README.md | 21 + .../pyprof/examples/user_annotation/resnet.py | 215 +++++++ apex/pyprof/examples/user_annotation/test.sh | 31 ++ apex/pyprof/nvtx/__init__.py | 2 + apex/pyprof/nvtx/nvmarker.py | 215 +++++++ apex/pyprof/parse/__init__.py | 0 apex/pyprof/parse/__main__.py | 10 + apex/pyprof/parse/db.py | 61 ++ apex/pyprof/parse/kernel.py | 210 +++++++ apex/pyprof/parse/nvvp.py | 282 ++++++++++ apex/pyprof/parse/parse.py | 122 ++++ apex/pyprof/prof/__init__.py | 0 apex/pyprof/prof/__main__.py | 10 + apex/pyprof/prof/activation.py | 65 +++ apex/pyprof/prof/base.py | 47 ++ apex/pyprof/prof/blas.py | 326 +++++++++++ apex/pyprof/prof/conv.py | 233 ++++++++ apex/pyprof/prof/convert.py | 62 +++ apex/pyprof/prof/data.py | 54 ++ apex/pyprof/prof/dropout.py | 50 ++ apex/pyprof/prof/embedding.py | 71 +++ apex/pyprof/prof/index_slice_join_mutate.py | 419 ++++++++++++++ apex/pyprof/prof/linear.py | 188 +++++++ apex/pyprof/prof/loss.py | 84 +++ apex/pyprof/prof/misc.py | 219 ++++++++ apex/pyprof/prof/normalization.py | 54 ++ apex/pyprof/prof/optim.py | 65 +++ apex/pyprof/prof/output.py | 149 +++++ apex/pyprof/prof/pointwise.py | 166 ++++++ apex/pyprof/prof/pooling.py | 59 ++ apex/pyprof/prof/prof.py | 256 +++++++++ apex/pyprof/prof/randomSample.py | 43 ++ apex/pyprof/prof/recurrentCell.py | 207 +++++++ apex/pyprof/prof/reduction.py | 150 +++++ apex/pyprof/prof/softmax.py | 115 ++++ apex/pyprof/prof/usage.py | 73 +++ apex/pyprof/prof/utility.py | 58 ++ requirements.txt | 5 + setup.py | 13 + tests/L0/run_pyprof_nvtx/__init__.py | 1 + tests/L0/run_pyprof_nvtx/test_pyprof_nvtx.py | 526 ++++++++++++++++++ tests/L0/run_test.py | 2 +- 67 files changed, 5987 insertions(+), 1 deletion(-) create mode 100644 apex/pyprof/FAQs.md create mode 100644 apex/pyprof/README.md create mode 100644 apex/pyprof/__init__.py create mode 100644 apex/pyprof/examples/.gitignore create mode 100644 apex/pyprof/examples/apex/README.md create mode 100644 apex/pyprof/examples/apex/fused_adam.py create mode 100644 apex/pyprof/examples/apex/fused_layer_norm.py create mode 100755 apex/pyprof/examples/apex/test.sh create mode 100644 apex/pyprof/examples/custom_func_module/README.md create mode 100644 apex/pyprof/examples/custom_func_module/custom_function.py create mode 100644 apex/pyprof/examples/custom_func_module/custom_module.py create mode 100755 apex/pyprof/examples/custom_func_module/test.sh create mode 100644 apex/pyprof/examples/imagenet/imagenet.py create mode 100755 apex/pyprof/examples/imagenet/test.sh create mode 100644 apex/pyprof/examples/jit/README.md create mode 100644 apex/pyprof/examples/jit/jit_script_function.py create mode 100644 apex/pyprof/examples/jit/jit_script_method.py create mode 100644 apex/pyprof/examples/jit/jit_trace_function.py create mode 100644 apex/pyprof/examples/jit/jit_trace_method.py create mode 100755 apex/pyprof/examples/jit/test.sh create mode 100755 apex/pyprof/examples/lenet.py create mode 100755 apex/pyprof/examples/operators.py create mode 100755 apex/pyprof/examples/simple.py create mode 100644 apex/pyprof/examples/user_annotation/README.md create mode 100644 apex/pyprof/examples/user_annotation/resnet.py create mode 100755 apex/pyprof/examples/user_annotation/test.sh create mode 100644 apex/pyprof/nvtx/__init__.py create mode 100644 apex/pyprof/nvtx/nvmarker.py create mode 100644 apex/pyprof/parse/__init__.py create mode 100644 apex/pyprof/parse/__main__.py create mode 100644 apex/pyprof/parse/db.py create mode 100644 apex/pyprof/parse/kernel.py create mode 100644 apex/pyprof/parse/nvvp.py create mode 100755 apex/pyprof/parse/parse.py create mode 100644 apex/pyprof/prof/__init__.py create mode 100644 apex/pyprof/prof/__main__.py create mode 100644 apex/pyprof/prof/activation.py create mode 100644 apex/pyprof/prof/base.py create mode 100644 apex/pyprof/prof/blas.py create mode 100644 apex/pyprof/prof/conv.py create mode 100644 apex/pyprof/prof/convert.py create mode 100644 apex/pyprof/prof/data.py create mode 100644 apex/pyprof/prof/dropout.py create mode 100644 apex/pyprof/prof/embedding.py create mode 100644 apex/pyprof/prof/index_slice_join_mutate.py create mode 100644 apex/pyprof/prof/linear.py create mode 100644 apex/pyprof/prof/loss.py create mode 100644 apex/pyprof/prof/misc.py create mode 100644 apex/pyprof/prof/normalization.py create mode 100644 apex/pyprof/prof/optim.py create mode 100644 apex/pyprof/prof/output.py create mode 100644 apex/pyprof/prof/pointwise.py create mode 100644 apex/pyprof/prof/pooling.py create mode 100755 apex/pyprof/prof/prof.py create mode 100644 apex/pyprof/prof/randomSample.py create mode 100644 apex/pyprof/prof/recurrentCell.py create mode 100644 apex/pyprof/prof/reduction.py create mode 100644 apex/pyprof/prof/softmax.py create mode 100644 apex/pyprof/prof/usage.py create mode 100644 apex/pyprof/prof/utility.py create mode 100644 requirements.txt create mode 100644 tests/L0/run_pyprof_nvtx/__init__.py create mode 100644 tests/L0/run_pyprof_nvtx/test_pyprof_nvtx.py diff --git a/README.md b/README.md index f954fb73b..e51e91edb 100644 --- a/README.md +++ b/README.md @@ -94,6 +94,11 @@ A Python-only build omits: - Fused kernels that improve the performance of `apex.parallel.DistributedDataParallel` and `apex.amp`. `DistributedDataParallel`, `amp`, and `SyncBatchNorm` will still be usable, but they may be slower. +To enable PyProf support, you need to install the packages required by PyProf. To do so, add the "--pyprof" option at installation time: +``` +$ pip install -v --no-cache-dir --global-option="--pyprof" --global-option="--cpp_ext" --global-option="--cuda_ext" ./ +``` + ### Windows support Windows support is experimental, and Linux is recommended. `pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .` may work if you were able to build Pytorch from source on your system. `pip install -v --no-cache-dir .` (without CUDA/C++ extensions) is more likely to work. If you installed Pytorch in a Conda environment, make sure to install Apex in that same environment. diff --git a/apex/__init__.py b/apex/__init__.py index 3a75a1819..3b2086650 100644 --- a/apex/__init__.py +++ b/apex/__init__.py @@ -1,5 +1,6 @@ # May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten import torch +import warnings from . import parallel from . import amp @@ -14,3 +15,4 @@ # load time) the error message is timely and visible. from . import optimizers from . import normalization +from . import pyprof diff --git a/apex/pyprof/FAQs.md b/apex/pyprof/FAQs.md new file mode 100644 index 000000000..27a0e97c5 --- /dev/null +++ b/apex/pyprof/FAQs.md @@ -0,0 +1,21 @@ +1. How do I intercept the Adam optimizer in APEX ? + + ```python + from apex import pyprof + import fused_adam_cuda + pyprof.nvtx.wrap(fused_adam_cuda, 'adam') + ``` + +2. If you are using JIT and/or AMP, the correct initialization sequence is + 1. Let any JIT to finish. + 2. Initlialize pyprof `pyprof.nvtx.init()`. + 3. Initialize AMP. + +3. How do I profile with `torch.distributed.launch` ? + + ```python + nvprof -f -o net%p.sql \ + --profile-from-start off \ + --profile-child-processes \ + python -m torch.distributed.launch net.py + ``` diff --git a/apex/pyprof/README.md b/apex/pyprof/README.md new file mode 100644 index 000000000..6681f94e4 --- /dev/null +++ b/apex/pyprof/README.md @@ -0,0 +1,252 @@ +## PyProf - PyTorch Profiling tool + +### What does this tool do? + +Analyzing the performance of deep neural networks is hard. Getting kernels out of [NvProf]([https://developer.nvidia.com/nvidia-visual-profiler](https://developer.nvidia.com/nvidia-visual-profiler)) or [NSight Compute]([https://developer.nvidia.com/nsight-compute](https://developer.nvidia.com/nsight-compute)) provides some generic kernel name and its execution time, but not detailed information regarding the following: + + - Which layer launched it: e.g. the association of `ComputeOffsetsKernel` with a concrete PyTorch layer or API is not obvious. + - What the tensor dimensions and precision were: without knowing the tensor dimensions and precision, it's impossible to reason about whether the actual (silicon) kernel time is close to maximum performance of such a kernel on the GPU. Knowing the tensor dimensions and precision, we can figure out the FLOPs and bandwidth required by a layer, and then determine how close to maximum performance the kernel is for that operation. + - Forward-backward correlation: currently it's very hard to determine what the forward pass step was that resulted in the particular weight and data gradients (wgrad, dgrad), which makes it difficult to determine the tensor dimensions required by these backprop steps to assess their performance. + - Did the kernel use [Tensor Cores]([https://www.youtube.com/watch?v=yyR0ZoCeBO8](https://www.youtube.com/watch?v=yyR0ZoCeBO8))? + - Which line in the user's code resulted in launching this particular kernel (program trace)? + +PyProf addresses all of the issues above by: + + 1. Instrumenting PyTorch operations to capture the tensor dimensions and precision using [NVTX](https://devblogs.nvidia.com/cuda-pro-tip-generate-custom-application-profile-timelines-nvtx). This information is recorded at profile capture time, e.g. using [NvProf](https://developer.nvidia.com/nvidia-visual-profiler). + 2. Querying the record produced by the profiler to correlate the kernel name and duration with PyTorch API/layer name, tensor dimensions, tensor precision, as well as calculating FLOPs and bandwidth for common operations. In addition, extra information from the profile is added for use by CUDA professionals, such as CUDA launch parameters (block/grid dimensions). + +Regarding FLOP and bandwidth implementations, these are usually quite straightforward. For example, for matrices AMxK and BKxN, the FLOP count for a matrix multiplication is 2 * M * N * K, and bandwidth is M * K + N * K + M * N. Note that these numbers are based on the algorithm, not the actual performance of the specific kernel. For more details, see NVIDIA's [Deep Learning Performance Guide](https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html). + +Armed with such information, the user can determine various issues to help them tune the network. For instance, according to the [Tensor Core Performance Guide]([https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html](https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html)), the M, N and K dimensions that result in Tensor Core usage need to be divisible by 8. In fact, PyProf comes with a flag that lets the user obtain information regarding whether Tensor Cores were used by the kernel. Other useful information might include knowing that a particular kernel did not exploit much thread parallelism, as determined by the grid/block dimensions. Since many PyTorch kernels are open-source (or even custom written by the user, as in [CUDA Extensions]([https://pytorch.org/tutorials/advanced/cpp_extension.html](https://pytorch.org/tutorials/advanced/cpp_extension.html))), this provides the user with information that helps root cause performance issues and prioritize optimization work. + + +### How to get started? + +1. Add the following lines to your PyTorch network: + + ```python + import torch.cuda.profiler as profiler + from apex import pyprof + pyprof.nvtx.init() + ``` + + Run the training/inference loop with the [PyTorch's NVTX context manager](https://pytorch.org/docs/stable/_modules/torch/autograd/profiler.html#emit_nvtx) + `with torch.autograd.profiler.emit_nvtx()`. Optionally, you can + use `profiler.start()` and `profiler.stop()` to pick an iteration + (say after warm-up) for which you would like to capture data. + Here's an example: + + ```python + iters = 500 + iter_to_capture = 100 + + # Define network, loss function, optimizer etc. + + # PyTorch NVTX context manager + with torch.autograd.profiler.emit_nvtx(): + + for iter in range(iters): + + if iter == iter_to_capture: + profiler.start() + + output = net(images) + loss = criterion(output, labels) + loss.backward() + optimizer.step() + + if iter == iter_to_capture: + profiler.stop() + ``` + +2. Run NVprof to generate a SQL (NVVP) file. This file can be opened with NVVP, as usual. + ```sh + # If you used profiler.start() and profiler.stop() in net.py + nvprof -f -o net.sql --profile-from-start off -- python net.py + + # Profile everything + nvprof -f -o net.sql -- python net.py + ``` + +**Note:** if you're experiencing issues with hardware counters and you get a message such as `**_ERR_NVGPUCTRPERM The user running does not have permission to access NVIDIA GPU Performance Counters on the target device_**`, please follow the steps described in [Hardware Counters](#hardware-counters). + +3. Run parser on the SQL file. The output is an ASCII file. Each line +is a python dictionary which contains information about the kernel name, +duration, parameters etc. This file can be used as input to other custom +scripts as well. + + ```sh + python -m apex.pyprof.parse net.sql > net.dict + ``` + +4. Run the profiler. The input is the python dictionary created above. The tool can produce a CSV output, a columnated output (similar to `column -t` for terminal readability) and a space separated output (for post processing by AWK for instance). The tool produces 20 columns of information for every GPU kernel but you can select a subset of columns using the `-c` flag. Note that a few columns might have the value "na" implying either its a work in progress or the tool was unable to extract that information. Assuming the directory is `prof`, here are a few examples of how to use `prof.py`. + + ```sh + # Print usage and help. Lists all available output columns. + python -m apex.pyprof.prof -h + + # Columnated output of width 150 with some default columns. + python -m apex.pyprof.prof -w 150 net.dict + + # CSV output. + python -m apex.pyprof.prof --csv net.dict + + # Space seperated output. + python -m apex.pyprof.prof net.dict + + # Columnated output of width 130 with columns index,direction,kernel name,parameters,silicon time. + python -m apex.pyprof.prof -w 130 -c idx,dir,kernel,params,sil net.dict + + # CSV output with columns index,direction,kernel name,parameters,silicon time. + python -m apex.pyprof.prof --csv -c idx,dir,kernel,params,sil net.dict + + # Space separated output with columns index,direction,kernel name,parameters,silicon time. + python -m apex.pyprof.prof -c idx,dir,kernel,params,sil net.dict + + # Input redirection. + python -m apex.pyprof.prof < net.dict + ``` + +5. Profile-guided optimization + +If kernels that do matrix multiplication/GEMM or convolution use half precision (fp16) data but do not use Tensor Cores (the TC column in the profile analysis output doesn't show a "1"), one can follow some basic steps to increase the likelihood that a Tensor Core-compatible kernel will be chosen. For example, for GEMMs, M, N and K should be divisible by 8, and for convolutions, the number of input and output channels shuold be divisible by 8. For more information, see detailed Tensor Core guides such as: +- Blog Post: [Tips for Optimizing GPU Performance Using Tensor Cores](https://devblogs.nvidia.com/optimizing-gpu-performance-tensor-cores/) +- GTC Talk: [Tensor Core Deep Learning Performance Guide](https://developer.download.nvidia.com/video/gputechconf/gtc/2019/presentation/s9926-tensor-core-performance-the-ultimate-guide.pdf) + +For both Tensor Core and non-Tensor Core Deep Learning performance optimization tips, see NVIDIA's [Deep Learning Performance Guide](https://docs.nvidia.com/deeplearning/sdk/dl-performance-guide/index.html). + +### TODOs +1. The support for conv transpose is currently missing. +2. PyProf currently works only with NvProf, but Nsight Compute support will be added in the future. + +### Example + +1. Run `nvprof` on the LeNet model in `examples/lenet.py`. This will output a SQL file called `net.sql`. + +```sh +nvprof -f -o net.sql --profile-from-start off -- python examples/lenet.py +``` + +**Note**: DO NOT add --analysis-metrics since that will change which table nvprof writes the kernels to (`CUPTI_ACTIVITY_KIND_KERNEL` instead of the usual `CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL`). Support for running with metrics may be added in the future. + +If you don't care about a full correlation analysis and you'd just like to view the timeline with detailed NVTX annotations, you can do so, e.g. in the NVIDIA Visual Profiler (NVVP). For example, you can call `nvvp net.sql` to view the annotated timeline. + +2. Run the `parse.py` script on `net.sql` to extract kernel and runtime information and +save it as `net.dict`. + +```sh +python -m apex.pyprof.parse net.sql > net.dict +``` + +This will produce a text file, which can be parsed by any external tool, but it can also be directly read one line at a time by Python by calling `eval` on the line being read. + +**Note: you do not need to process this output manually.** Here the output is just shown as an example of modularity - you can process the raw data yourself, or let the next step enrich the information further and dump a CSV. + +The output of this step will look as follows. Note that the dictionary has a lot more keys than the ones shown in the example. + +``` +>>> with open('torchvision.resnet50.adam.64.dict') as f: +... for line in f: +... d = eval(line) +... print(d['kShortName'], d['op'], d['kDuration'], d['block'], d['grid'], d['device'], d['stream'], d['trace']) +... +nchwToNhwc3To4Kernel ['conv2d'] 376324 (256, 1, 1) (1568, 1, 64) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:195'] +generic4Channel_kernel ['conv2d'] 10720 (512, 1, 1) (19, 1, 1) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:195'] +first_layer_fwd_kernel ['conv2d'] 411204 (128, 1, 1) (2, 7, 64) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:195'] +nhwcToNchwKernel ['conv2d'] 342371 (256, 1, 1) (392, 2, 64) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:195'] +elementwise_kernel ['__iadd__'] 2816 (128, 1, 1) (1, 1, 1) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:196'] +batch_norm_collect_statistics_kernel ['batch_norm', 'batch_norm'] 929513 (512, 1, 1) (64, 1, 1) 0 7 ['imagenet.py:137', 'imagenet.py:129', '/opt/conda/lib/python3.6/site-packages/torchvision/models/resnet.py:196'] +``` + +3. Run the `prof.py` script on `net.dict` to summarize the results into a CSV file, or to display the pretty-printed results on the screen. This step processes the raw output from step 2 to generate a nice output, but it also adds a lot of extra useful information inferred from the previous step, such as: +- FLOPs +- bandwidth (bytes in and out of GPU DRAM) +- tensor core usage + +```sh +python -m apex.pyprof.prof --csv net.dict > results.csv +``` + +You can choose which columns you'd like to display. Here's a list from calling `python -m apex.pyprof.prof -h`: + +``` + idx: Index + seq: PyTorch Sequence Id + altseq: PyTorch Alternate Sequence Id + tid: Thread Id + layer: User annotated NVTX string (can be nested) + trace: Function Call Trace + dir: Direction + sub: Sub Sequence Id + mod: Module + op: Operattion + kernel: Kernel Name + params: Parameters + sil: Silicon Time (in ns) + tc: Tensor Core Usage + device: GPU Device Id + stream: Stream Id + grid: Grid Dimensions + block: Block Dimensions + flops: Floating point ops (FMA = 2 FLOPs) + bytes: Number of bytes in and out of DRAM +``` + +Let's have a look at the pretty-printed output: +``` +python -m apex.pyprof.prof -w 100 -c kernel,op,sil,tc,flops,bytes,device,stream,block,grid torchvision.resnet50.adam.64.dict + +Kernel Op Sil(ns) TC FLOPs Bytes Dev Str Block Grid +elementwise_kernel relu 381028 - 51380224 205520896 0 7 512,1,1 100352,1,1 +volta_fp16_s884cudn conv2d 160002 1 1644167168 51388416 0 7 256,1,1 784,1,1 +elementwise_kernel relu 96545 - 12845056 51380224 0 7 512,1,1 25088,1,1 +volta_fp16_s884cudn conv2d 346083 1 6576668672 128483328 0 7 256,1,1 784,2,1 +``` + +Not using the pretty-print width (`-w`) option and adding `--csv` results in a CSV output instead: + +``` +python -m apex.pyprof.prof --csv -c kernel,mod,op,dir,sil,tc,flops,bytes,device,stream,block,grid torchvision.resnet50.adam.64.dict + +"Kernel","Module","Op","Direction","Sil(ns)","TC","FLOPs","Bytes","Device","Stream","Block","Grid" +"nchwToNhwc3To4Kernel","torch.nn.functional","conv2d","fprop","376324","-","0","0","0","7","256,1,1","1568,1,64" +"generic4Channel_kernel","torch.nn.functional","conv2d","fprop","10720","-","0","0","0","7","512,1,1","19,1,1" +"first_layer_fwd_kernel","torch.nn.functional","conv2d","fprop","411204","-","0","0","0","7","128,1,1","2,7,64" +"nhwcToNchwKernel","torch.nn.functional","conv2d","fprop","342371","-","0","0","0","7","256,1,1","392,2,64" +"elementwise_kernel","Tensor","__iadd__","fprop","2816","-","1.0","8","0","7","128,1,1","1,1,1" +"batch_norm_collect_statistics_kernel","torch.nn.functional","batch_norm","fprop","929513","-","411041792","411041792","0","7","512,1,1","64,1,1" +"batch_norm_transform_input_kernel","torch.nn.functional","batch_norm","fprop","377539","-","411041792","411041792","0","7","512,1,1","64,64,1" +"elementwise_kernel","torch.nn.functional","relu","fprop","381028","-","51380224","205520896","0","7","512,1,1","100352,1,1" +"MaxPoolForward","torch.nn.functional","max_pool2d","fprop","406531","-","0","0","0","7","256,1,1","50176,1,1" +"cudnn::gemm::computeOffsetsKernel","torch.nn.functional","conv2d","fprop","2464","-","0","0","0","7","128,1,1","25,1,1" +``` + +### Hardware Counters + +Profiling GPU workloads may require access to [hardware performance counters]([https://en.wikipedia.org/wiki/Hardware_performance_counter](https://en.wikipedia.org/wiki/Hardware_performance_counter)). Due to a [fix](https://nvidia.custhelp.com/app/answers/detail/a_id/4738) in recent NVIDIA drivers addressing [CVE‑2018‑6260](https://nvd.nist.gov/vuln/detail/CVE-2018-6260), the hardware counters are disabled by default, and require elevated privileges to be enabled again. If you're using a recent driver, you may see the following message when trying to run nvprof: + +```**_ERR_NVGPUCTRPERM The user running does not have permission to access NVIDIA GPU Performance Counters on the target device._**``` + +For details, see [here](https://developer.nvidia.com/nvidia-development-tools-solutions-ERR_NVGPUCTRPERM-permission-issue-performance-counters). + +_Permanent solution_ + +Follow the steps [here]([https://developer.nvidia.com/nvidia-development-tools-solutions-ERR_NVGPUCTRPERM-permission-issue-performance-counters](https://developer.nvidia.com/nvidia-development-tools-solutions-ERR_NVGPUCTRPERM-permission-issue-performance-counters)). The current steps for Linux are: +``` +sudo systemctl isolate multi-user +sudo modprobe -r nvidia_uvm nvidia_drm nvidia_modeset nvidia-vgpu-vfio nvidia +sudo modprobe nvidia NVreg_RestrictProfilingToAdminUsers=0 +sudo systemctl isolate graphical +``` +The above steps should result in a permanent change. + +_Temporary solution_ + +When running on bare metal, you can run nvprof with `sudo`. + +If you're running in a Docker image, you can temporarily elevate your privileges with one of the following (oldest to newest syntax): +
+nvidia-docker run --privileged
+docker run --runtime nvidia --privileged
+docker run --gpus all --privileged
+
diff --git a/apex/pyprof/__init__.py b/apex/pyprof/__init__.py new file mode 100644 index 000000000..5c8e3d84f --- /dev/null +++ b/apex/pyprof/__init__.py @@ -0,0 +1,3 @@ +import warnings + +from . import nvtx diff --git a/apex/pyprof/examples/.gitignore b/apex/pyprof/examples/.gitignore new file mode 100644 index 000000000..b6e239e2e --- /dev/null +++ b/apex/pyprof/examples/.gitignore @@ -0,0 +1,4 @@ +__pycache__ +*.sql +*.dict +*.csv diff --git a/apex/pyprof/examples/apex/README.md b/apex/pyprof/examples/apex/README.md new file mode 100644 index 000000000..6af798d34 --- /dev/null +++ b/apex/pyprof/examples/apex/README.md @@ -0,0 +1 @@ +This directory has examples of how to use `pyprof` with APEX extensions e.g. `fused_adam_cuda` and `fused_layer_norm_cuda`. diff --git a/apex/pyprof/examples/apex/fused_adam.py b/apex/pyprof/examples/apex/fused_adam.py new file mode 100644 index 000000000..0da8b3f1c --- /dev/null +++ b/apex/pyprof/examples/apex/fused_adam.py @@ -0,0 +1,20 @@ +import torch +import fused_adam_cuda +from apex.optimizers import FusedAdam, FP16_Optimizer +from apex import pyprof + +pyprof.nvtx.init() +pyprof.nvtx.wrap(fused_adam_cuda, 'adam') + +model = torch.nn.Linear(10, 20).cuda().half() +criterion = torch.nn.CrossEntropyLoss().cuda() +optimizer = FusedAdam(model.parameters()) +optimizer = FP16_Optimizer(optimizer) + +x = torch.ones(32, 10).cuda().half() +target = torch.empty(32, dtype=torch.long).random_(20).cuda() +y = model(x) +loss = criterion(y, target) +optimizer.zero_grad() +loss.backward() +optimizer.step() diff --git a/apex/pyprof/examples/apex/fused_layer_norm.py b/apex/pyprof/examples/apex/fused_layer_norm.py new file mode 100644 index 000000000..2c924e472 --- /dev/null +++ b/apex/pyprof/examples/apex/fused_layer_norm.py @@ -0,0 +1,28 @@ +import torch +import fused_layer_norm_cuda +from apex.normalization import FusedLayerNorm +from apex import pyprof + +pyprof.nvtx.init() +pyprof.nvtx.wrap(fused_layer_norm_cuda, 'forward') +pyprof.nvtx.wrap(fused_layer_norm_cuda, 'backward') +pyprof.nvtx.wrap(fused_layer_norm_cuda, 'forward_affine') +pyprof.nvtx.wrap(fused_layer_norm_cuda, 'backward_affine') + +input = torch.randn(20, 5, 10, 10).cuda() + +# With Learnable Parameters +m = FusedLayerNorm(input.size()[1:]).cuda() +output = m(input) + +# Without Learnable Parameters +m = FusedLayerNorm(input.size()[1:], elementwise_affine=False).cuda() +output = m(input) + +# Normalize over last two dimensions +m = FusedLayerNorm([10, 10]).cuda() +output = m(input) + +# Normalize over last dimension of size 10 +m = FusedLayerNorm(10).cuda() +output = m(input) diff --git a/apex/pyprof/examples/apex/test.sh b/apex/pyprof/examples/apex/test.sh new file mode 100755 index 000000000..dc02f388d --- /dev/null +++ b/apex/pyprof/examples/apex/test.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +set -e + +SCRIPT=`realpath $0` +SCRIPTPATH=`dirname $SCRIPT` +PYPROF="$SCRIPTPATH/../.." + +parse="python $PYPROF/parse/parse.py" +prof="python $PYPROF/prof/prof.py" + +for f in *.py +do + base=`basename $f .py` + sql=$base.sql + dict=$base.dict + + #NVprof + echo "nvprof -fo $sql python $f" + nvprof -fo $sql python $f + + #Parse + echo $parse $sql + $parse $sql > $dict + + #Prof + echo $prof $dict + $prof -w 130 $dict + \rm $sql $dict +done diff --git a/apex/pyprof/examples/custom_func_module/README.md b/apex/pyprof/examples/custom_func_module/README.md new file mode 100644 index 000000000..695b02b3b --- /dev/null +++ b/apex/pyprof/examples/custom_func_module/README.md @@ -0,0 +1 @@ +This directory has examples which show how to intercept (monkey patch) custom functions and modules with `pyprof`. No changes are required in `pyprof/parse`, however, users can add support for bytes and flops calculation for custom functions and modules in `pyprof/prof` by extending the `OperatorLayerBase` class. diff --git a/apex/pyprof/examples/custom_func_module/custom_function.py b/apex/pyprof/examples/custom_func_module/custom_function.py new file mode 100644 index 000000000..535df2fc2 --- /dev/null +++ b/apex/pyprof/examples/custom_func_module/custom_function.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +import torch +import torch.cuda.profiler as profiler +from apex import pyprof +#Initialize pyprof +pyprof.nvtx.init() + +class Foo(torch.autograd.Function): + @staticmethod + def forward(ctx, in1, in2): + out = in1 + in2 #This could be a custom C/C++ function. + return out + + @staticmethod + def backward(ctx, grad): + in1_grad = grad #This could be a custom C/C++ function. + in2_grad = grad #This could be a custom C/C++ function. + return in1_grad, in2_grad + +#Hook the forward and backward functions to pyprof +pyprof.nvtx.wrap(Foo, 'forward') +pyprof.nvtx.wrap(Foo, 'backward') + +foo = Foo.apply + +x = torch.ones(4,4).cuda() +y = torch.ones(4,4).cuda() + +with torch.autograd.profiler.emit_nvtx(): + profiler.start() + z = foo(x,y) + profiler.stop() diff --git a/apex/pyprof/examples/custom_func_module/custom_module.py b/apex/pyprof/examples/custom_func_module/custom_module.py new file mode 100644 index 000000000..c03f5cf3d --- /dev/null +++ b/apex/pyprof/examples/custom_func_module/custom_module.py @@ -0,0 +1,27 @@ +#!/usr/bin/env python3 + +import torch +import torch.cuda.profiler as profiler +from apex import pyprof +pyprof.nvtx.init() + +class Foo(torch.nn.Module): + def __init__(self, size): + super(Foo, self).__init__() + self.n = torch.nn.Parameter(torch.ones(size)) + self.m = torch.nn.Parameter(torch.ones(size)) + + def forward(self, input): + return self.n*input + self.m + +#Hook the forward function to pyprof +pyprof.nvtx.wrap(Foo, 'forward') + +foo = Foo(4) +foo.cuda() +x = torch.ones(4).cuda() + +with torch.autograd.profiler.emit_nvtx(): + profiler.start() + z = foo(x) + profiler.stop() diff --git a/apex/pyprof/examples/custom_func_module/test.sh b/apex/pyprof/examples/custom_func_module/test.sh new file mode 100755 index 000000000..dc02f388d --- /dev/null +++ b/apex/pyprof/examples/custom_func_module/test.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +set -e + +SCRIPT=`realpath $0` +SCRIPTPATH=`dirname $SCRIPT` +PYPROF="$SCRIPTPATH/../.." + +parse="python $PYPROF/parse/parse.py" +prof="python $PYPROF/prof/prof.py" + +for f in *.py +do + base=`basename $f .py` + sql=$base.sql + dict=$base.dict + + #NVprof + echo "nvprof -fo $sql python $f" + nvprof -fo $sql python $f + + #Parse + echo $parse $sql + $parse $sql > $dict + + #Prof + echo $prof $dict + $prof -w 130 $dict + \rm $sql $dict +done diff --git a/apex/pyprof/examples/imagenet/imagenet.py b/apex/pyprof/examples/imagenet/imagenet.py new file mode 100644 index 000000000..ec4eb1b56 --- /dev/null +++ b/apex/pyprof/examples/imagenet/imagenet.py @@ -0,0 +1,137 @@ +#!/usr/bin/env python3 + +""" +Example to run pyprof with imagenet models. +""" + +import sys +import torch +import torch.nn as nn +import torchvision.models as models +import torch.cuda.profiler as profiler +import argparse + +from apex import pyprof +from apex.optimizers import FusedAdam, FP16_Optimizer +import fused_adam_cuda + +def parseArgs(): + parser = argparse.ArgumentParser(prog=sys.argv[0], description="Run popular imagenet models.") + + parser.add_argument("-m", + type=str, + default="resnet50", + choices=["alexnet", "densenet121", "densenet161", "densenet169", "densenet201", "googlenet", "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3", "mobilenet_v2", "resnet18", "resnet34", "resnet50", "resnet101", "resnet152", "resnext50_32x4d", "resnext101_32x8d", "wide_resnet50_2", "wide_resnet101_2", "shufflenet_v2_x0_5", "shufflenet_v2_x1_0", "shufflenet_v2_x1_5", "shufflenet_v2_x2_0", "squeezenet1_0", "squeezenet1_1", "vgg11", "vgg11_bn", "vgg13", "vgg13_bn", "vgg16", "vgg16_bn", "vgg19", "vgg19_bn", "inception_v3"], + help="Model.") + + parser.add_argument("-b", + type=int, + default=32, + help="Batch size.") + + parser.add_argument("-o", + type=str, + default="adam", + choices=["adam", "sgd"], + help="Optimizer.") + + args = parser.parse_args() + return args + +d = { + "alexnet": {'H': 224, 'W': 224, 'opts': {}}, + + "densenet121": {'H': 224, 'W': 224, 'opts': {}}, + "densenet161": {'H': 224, 'W': 224, 'opts': {}}, + "densenet169": {'H': 224, 'W': 224, 'opts': {}}, + "densenet201": {'H': 224, 'W': 224, 'opts': {}}, + + "googlenet": {'H': 224, 'W': 224, 'opts': {'aux_logits': False}}, + + "mnasnet0_5": {'H': 224, 'W': 224, 'opts': {}}, + "mnasnet0_75": {'H': 224, 'W': 224, 'opts': {}}, + "mnasnet1_0": {'H': 224, 'W': 224, 'opts': {}}, + "mnasnet1_3": {'H': 224, 'W': 224, 'opts': {}}, + + "mobilenet_v2": {'H': 224, 'W': 224, 'opts': {}}, + + "resnet18": {'H': 224, 'W': 224, 'opts': {}}, + "resnet34": {'H': 224, 'W': 224, 'opts': {}}, + "resnet50": {'H': 224, 'W': 224, 'opts': {}}, + "resnet101": {'H': 224, 'W': 224, 'opts': {}}, + "resnet152": {'H': 224, 'W': 224, 'opts': {}}, + + "resnext50_32x4d": {'H': 224, 'W': 224, 'opts': {}}, + "resnext101_32x8d": {'H': 224, 'W': 224, 'opts': {}}, + + "wide_resnet50_2": {'H': 224, 'W': 224, 'opts': {}}, + "wide_resnet101_2": {'H': 224, 'W': 224, 'opts': {}}, + + "shufflenet_v2_x0_5": {'H': 224, 'W': 224, 'opts': {}}, + "shufflenet_v2_x1_0": {'H': 224, 'W': 224, 'opts': {}}, + "shufflenet_v2_x1_5": {'H': 224, 'W': 224, 'opts': {}}, + "shufflenet_v2_x2_0": {'H': 224, 'W': 224, 'opts': {}}, + + "squeezenet1_0": {'H': 224, 'W': 224, 'opts': {}}, + "squeezenet1_1": {'H': 224, 'W': 224, 'opts': {}}, + + "vgg11": {'H': 224, 'W': 224, 'opts': {}}, + "vgg11_bn": {'H': 224, 'W': 224, 'opts': {}}, + "vgg13": {'H': 224, 'W': 224, 'opts': {}}, + "vgg13_bn": {'H': 224, 'W': 224, 'opts': {}}, + "vgg16": {'H': 224, 'W': 224, 'opts': {}}, + "vgg16_bn": {'H': 224, 'W': 224, 'opts': {}}, + "vgg19": {'H': 224, 'W': 224, 'opts': {}}, + "vgg19_bn": {'H': 224, 'W': 224, 'opts': {}}, + + "inception_v3": {'H': 299, 'W': 299, 'opts': {'aux_logits': False}}, + } + +def main(): + args = parseArgs() + + pyprof.nvtx.init() + pyprof.nvtx.wrap(fused_adam_cuda, 'adam') + + N = args.b + C = 3 + H = d[args.m]['H'] + W = d[args.m]['W'] + opts = d[args.m]['opts'] + classes = 1000 + + net = getattr(models, args.m) + net = net(**opts).cuda().half() + net.train() + + x = torch.rand(N, C, H, W).cuda().half() + target = torch.empty(N, dtype=torch.long).random_(classes).cuda() + + criterion = nn.CrossEntropyLoss().cuda() + if (args.o == "sgd"): + optimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9) + elif (args.o == "adam"): + optimizer = FusedAdam(net.parameters()) + optimizer = FP16_Optimizer(optimizer) + else: + assert False + + #Warm up without profiler + for i in range(2): + output = net(x) + loss = criterion(output, target) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + with torch.autograd.profiler.emit_nvtx(): + profiler.start() + output = net(x) + loss = criterion(output, target) + optimizer.zero_grad() + loss.backward() + optimizer.step() + profiler.stop() + +if __name__ == "__main__": + main() diff --git a/apex/pyprof/examples/imagenet/test.sh b/apex/pyprof/examples/imagenet/test.sh new file mode 100755 index 000000000..0c44c05bb --- /dev/null +++ b/apex/pyprof/examples/imagenet/test.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +set -e + +SCRIPT=`realpath $0` +SCRIPTPATH=`dirname $SCRIPT` +PYPROF="$SCRIPTPATH/../.." + +parse="python -m apex.pyprof.parse" +prof="python -m apex.pyprof.prof" + +for net in "resnet50" +do + for optim in adam sgd + do + for batch in 32 64 + do + base="torchvision".$net.$optim.$batch + sql=$base.sql + dict=$base.dict + + #NVprof + echo "nvprof -fo $sql --profile-from-start off python imagenet.py -m ${net} -o $optim -b $batch" + nvprof -fo $sql --profile-from-start off python imagenet.py -m ${net} -o $optim -b $batch + + #Parse + echo $parse $sql + $parse $sql > $dict + + #Prof + echo $prof $dict + $prof -w 130 $dict +# \rm $sql $dict + done + done +done diff --git a/apex/pyprof/examples/jit/README.md b/apex/pyprof/examples/jit/README.md new file mode 100644 index 000000000..101980f00 --- /dev/null +++ b/apex/pyprof/examples/jit/README.md @@ -0,0 +1,14 @@ +*As of this writing, these examples do not work +because of changes being proposed in PyTorch.* + +There are two ways to use PyTorch JIT + - Scripting + - Tracing + +In addition, we can JIT a + - Stand alone function + - Class / class method + +This directory has an example for each of the 4 cases. +Intercepting (monkey patching) JITted code has a few extra steps, +which are explained through comments. diff --git a/apex/pyprof/examples/jit/jit_script_function.py b/apex/pyprof/examples/jit/jit_script_function.py new file mode 100644 index 000000000..0b692c810 --- /dev/null +++ b/apex/pyprof/examples/jit/jit_script_function.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +import torch +import torch.cuda.profiler as profiler +from apex import pyprof + +#The following creates an object "foo" of type ScriptModule +#The new object has a function called "forward" + +@torch.jit.script +def foo(x, y): + return torch.sigmoid(x) + y + +#Initialize pyprof after the JIT step +pyprof.nvtx.init() + +#Assign a name to the object "foo" +foo.__name__ = "foo" + +#Hook up the forward function to pyprof +pyprof.nvtx.wrap(foo, 'forward') + +x = torch.zeros(4,4).cuda() +y = torch.ones(4,4).cuda() + +with torch.autograd.profiler.emit_nvtx(): + profiler.start() + z = foo(x, y) + profiler.stop() + print(z) diff --git a/apex/pyprof/examples/jit/jit_script_method.py b/apex/pyprof/examples/jit/jit_script_method.py new file mode 100644 index 000000000..79189c837 --- /dev/null +++ b/apex/pyprof/examples/jit/jit_script_method.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python3 + +import torch +import torch.cuda.profiler as profiler +from apex import pyprof + +class Foo(torch.jit.ScriptModule): + def __init__(self, size): + super(Foo, self).__init__() + self.n = torch.nn.Parameter(torch.ones(size)) + self.m = torch.nn.Parameter(torch.ones(size)) + + @torch.jit.script_method + def forward(self, input): + return self.n*input + self.m + +#Initialize pyprof after the JIT step +pyprof.nvtx.init() + +#Hook up the forward function to pyprof +pyprof.nvtx.wrap(Foo, 'forward') + +foo = Foo(4) +foo.cuda() +x = torch.ones(4).cuda() + +with torch.autograd.profiler.emit_nvtx(): + profiler.start() + z = foo(x) + profiler.stop() + print(z) diff --git a/apex/pyprof/examples/jit/jit_trace_function.py b/apex/pyprof/examples/jit/jit_trace_function.py new file mode 100644 index 000000000..06a781888 --- /dev/null +++ b/apex/pyprof/examples/jit/jit_trace_function.py @@ -0,0 +1,30 @@ +#!/usr/bin/env python3 + +import torch +import torch.cuda.profiler as profiler +from apex import pyprof + +def foo(x, y): + return torch.sigmoid(x) + y + +x = torch.zeros(4,4).cuda() +y = torch.ones(4,4).cuda() + +#JIT the function using tracing +#This returns an object of type ScriptModule with a forward method. +traced_foo = torch.jit.trace(foo, (x,y)) + +#Initialize pyprof after the JIT step +pyprof.nvtx.init() + +#Assign a name to the object "traced_foo" +traced_foo.__dict__['__name__'] = "foo" + +#Hook up the forward function to pyprof +pyprof.nvtx.wrap(traced_foo, 'forward') + +with torch.autograd.profiler.emit_nvtx(): + profiler.start() + z = traced_foo(x, y) + profiler.stop() + print(z) diff --git a/apex/pyprof/examples/jit/jit_trace_method.py b/apex/pyprof/examples/jit/jit_trace_method.py new file mode 100644 index 000000000..8a0d3fcec --- /dev/null +++ b/apex/pyprof/examples/jit/jit_trace_method.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python3 + +import torch +import torch.cuda.profiler as profiler +from apex import pyprof + +class Foo(torch.nn.Module): + def __init__(self, size): + super(Foo, self).__init__() + self.n = torch.nn.Parameter(torch.ones(size)) + self.m = torch.nn.Parameter(torch.ones(size)) + + def forward(self, input): + return self.n*input + self.m + +foo = Foo(4) +foo.cuda() +x = torch.ones(4).cuda() + +#JIT the class using tracing +traced_foo = torch.jit.trace(foo, x) + +#Initialize pyprof after the JIT step +pyprof.nvtx.init() + +#Assign a name to the object "traced_foo" +traced_foo.__dict__['__name__'] = "foo" + +#Hook up the forward function to pyprof +pyprof.nvtx.wrap(traced_foo, 'forward') + +with torch.autograd.profiler.emit_nvtx(): + profiler.start() + z = traced_foo(x) + profiler.stop() + print(z) diff --git a/apex/pyprof/examples/jit/test.sh b/apex/pyprof/examples/jit/test.sh new file mode 100755 index 000000000..dc02f388d --- /dev/null +++ b/apex/pyprof/examples/jit/test.sh @@ -0,0 +1,30 @@ +#!/bin/bash + +set -e + +SCRIPT=`realpath $0` +SCRIPTPATH=`dirname $SCRIPT` +PYPROF="$SCRIPTPATH/../.." + +parse="python $PYPROF/parse/parse.py" +prof="python $PYPROF/prof/prof.py" + +for f in *.py +do + base=`basename $f .py` + sql=$base.sql + dict=$base.dict + + #NVprof + echo "nvprof -fo $sql python $f" + nvprof -fo $sql python $f + + #Parse + echo $parse $sql + $parse $sql > $dict + + #Prof + echo $prof $dict + $prof -w 130 $dict + \rm $sql $dict +done diff --git a/apex/pyprof/examples/lenet.py b/apex/pyprof/examples/lenet.py new file mode 100755 index 000000000..9e22477ea --- /dev/null +++ b/apex/pyprof/examples/lenet.py @@ -0,0 +1,65 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.cuda.profiler as profiler +import torch.optim as optim + +from apex import pyprof +pyprof.nvtx.init() + +class LeNet5(nn.Module): + def __init__(self): + super(LeNet5, self).__init__() + # 1 input image channel, 6 output channels, 5x5 square convolution + # kernel + self.conv1 = nn.Conv2d(1, 6, 5) + self.conv2 = nn.Conv2d(6, 16, 5) + # an affine operation: y = Wx + b + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + # Max pooling over a (2, 2) window + x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) + # If the size is a square you can only specify a single number + x = F.max_pool2d(F.relu(self.conv2(x)), 2) + x = x.view(-1, self.num_flat_features(x)) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + def num_flat_features(self, x): + size = x.size()[1:] # all dimensions except the batch dimension + num_features = 1 + for s in size: + num_features *= s + return num_features + +with torch.autograd.profiler.emit_nvtx(): + + net = LeNet5().cuda() + + input = torch.randn(1, 1, 32, 32).cuda() + out = net(input) + + target = torch.randn(10) # a dummy target, for example + target = target.view(1, -1).cuda() # make it the same shape as output + criterion = nn.MSELoss() + + # create your optimizer + optimizer = optim.SGD(net.parameters(), lr=0.01) + + # in your training loop: + optimizer.zero_grad() # zero the gradient buffers + + profiler.start() + output = net(input) + loss = criterion(output, target) + loss.backward() + optimizer.step() # Does the update + profiler.stop() + diff --git a/apex/pyprof/examples/operators.py b/apex/pyprof/examples/operators.py new file mode 100755 index 000000000..a033c2ef4 --- /dev/null +++ b/apex/pyprof/examples/operators.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python3 + +""" +This file checks all Python operators. +""" + +import sys +import torch +import torch.cuda.profiler as profiler +import operator +import inspect + +#Import and initialize pyprof +from apex import pyprof +pyprof.nvtx.init() + +X = 1024 +Y = 1024 + +fa = torch.rand(X, Y).cuda() +fb = torch.rand(X, Y).cuda() +fc = torch.rand(X, Y).cuda() + +ia = torch.randint(0, 100, (X, Y)).cuda() +ib = torch.randint(0, 100, (X, Y)).cuda() + +sa = torch.ones(1,1).cuda() +sb = torch.ones(1,1).cuda() + +ba = fa.byte() + +unaryOps = ["abs", "__abs__", "neg", "__neg__",] +invertOps = ["inv", "invert", "__inv__", "__invert__",] #imlemented only for byte tensors +#pos, __pos__ is not implemented for tensors + +binaryOps = [] +binaryOps += [ "lt", "__lt__", "le", "__le__", "eq", "__eq__", "ne", "__ne__", "ge", "__ge__", "gt", "__gt__" ] +binaryOps += [ "add", "__add__", "sub", "__sub__", "mul", "__mul__", "floordiv", "__floordiv__", "truediv", "__truediv__", "pow", "__pow__", "mod", "__mod__"] +binaryOps += [ "and_", "__and__", "or_", "__or__", "xor", "__xor__", "lshift", "__lshift__", "rshift", "__rshift__"] + +inplaceOps = [] +inplaceOps += ["iadd", "__iadd__", "isub", "__isub__", "imul", "__imul__", "ifloordiv", "__ifloordiv__", "itruediv", "__itruediv__", "imod", "__imod__",] +#ipow, __ipow__ is not implemented in pytorch +inplaceOps += [ "iand", "__iand__", "ior", "__ior__", "ixor", "__ixor__", "ilshift", "__ilshift__", "irshift", "__irshift__",] + +matmulOps = [ "matmul", "__matmul__" ] +inplacematmulOps = [ "imatmul", "__imatmul__" ] + +reverseIntBinaryOps = ["__radd__", "__rsub__", "__rmul__", "__rfloordiv__", "__rpow__",] +reverseFloatBinaryOps = ["__radd__", "__rsub__", "__rmul__", "__rdiv__", "__rtruediv__", "__rfloordiv__", "__rpow__",] + +''' +TODO +.concat(a, b) +.__concat__(a, b) +.contains(a, b) +.__contains__(a, b) +.countOf(a, b) +.delitem(a, b) +.__delitem__(a, b) +.getitem(a, b) +.__getitem__(a, b) +.indexOf(a, b) +.setitem(a, b, c) +.__setitem__(a, b, c) +.length_hint(obj, default=0) +.iconcat(a, b) +.__iconcat__(a, b) +.index(a) +.__index__(a) +''' + +#Context manager +with torch.autograd.profiler.emit_nvtx(): + + #Start profiler + profiler.start() + + for op in unaryOps: + assert hasattr(operator, op) + f = getattr(operator, op) + assert inspect.isbuiltin(f) + c = f(ia) + + for op in invertOps: + assert hasattr(operator, op) + f = getattr(operator, op) + assert inspect.isbuiltin(f) + c = f(ba) + + for op in binaryOps: + assert hasattr(operator, op) + f = getattr(operator, op) + assert inspect.isbuiltin(f) + c = f(ia, ib) + c = f(ia, 2) + + for op in inplaceOps: + assert hasattr(operator, op) + f = getattr(operator, op) + assert inspect.isbuiltin(f) + ia = f(ia, ib) + ia = f(ia, 2) + + for op in matmulOps: + assert hasattr(operator, op) + f = getattr(operator, op) + assert inspect.isbuiltin(f) + c = f(fa, fb) + + for op in inplacematmulOps: + assert hasattr(operator, op) + f = getattr(operator, op) + assert inspect.isbuiltin(f) + fa = f(fa, fb) + + for op in reverseIntBinaryOps: + assert hasattr(torch.Tensor, op) + f = getattr(torch.Tensor, op) + ia = f(ia, ib) + + for op in reverseFloatBinaryOps: + assert hasattr(torch.Tensor, op) + f = getattr(torch.Tensor, op) + fa = f(fa, fb) + + ''' + #c = fa[3] + #c = fa[3][3] + #c = torch.min(fa, 3) + c = torch.sum(fa) + c = torch.max(fa) + c = -fa + #fc[2][2] = fa[2][2] + + c = a_scalar and b_scalar + c = a_scalar or b_scalar + c = not a_scalar + + c = a is b + c = a is not b + ''' + + #Stop profiler + profiler.stop() diff --git a/apex/pyprof/examples/simple.py b/apex/pyprof/examples/simple.py new file mode 100755 index 000000000..dbe9c615f --- /dev/null +++ b/apex/pyprof/examples/simple.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python3 + +""" +This simple file provides an example of how to + - import the pyprof library and initialize it + - use the emit_nvtx context manager + - start and stop the profiler + +Only kernels within profiler.start and profiler.stop calls are profiled. +To profile +$ nvprof -f -o simple.sql --profile-from-start off ./simple.py +""" + +import sys +import torch +import torch.cuda.profiler as profiler + +#Import and initialize pyprof +from apex import pyprof +pyprof.nvtx.init() + +a = torch.randn(5, 5).cuda() +b = torch.randn(5, 5).cuda() + +#Context manager +with torch.autograd.profiler.emit_nvtx(): + + #Start profiler + profiler.start() + + c = a + b + c = torch.mul(a,b) + c = torch.matmul(a,b) + c = torch.argmax(a, dim=1) + c = torch.nn.functional.pad(a, (1,1)) + + #Stop profiler + profiler.stop() diff --git a/apex/pyprof/examples/user_annotation/README.md b/apex/pyprof/examples/user_annotation/README.md new file mode 100644 index 000000000..30523615a --- /dev/null +++ b/apex/pyprof/examples/user_annotation/README.md @@ -0,0 +1,21 @@ +Nvidia NVTX range markers (https://docs.nvidia.com/gameworks/content/gameworkslibrary/nvtx/nvidia_tools_extension_library_nvtx.htm) +are a useful tool to capture and observe events and code ranges etc. +Using PyTorch APIs e.g, `torch.cuda.nvtx.range_push("xxx")` and `torch.cuda.nvtx.range_pop()` users can easily add their own NVTX range markers. These markers can then be observed in the Nvidia Visual Profiler (NVVP). + +While inserting NVTX markers (strings), if the users follow a specific string pattern `"layer:your_string_here"` e.g. `"layer:conv1"` or `"layer:encoder_layer_3_self_attention`, then `pyprof` will display the strings `conv1` and `encoder_layer_3_self_attention` next to the associated kernels in the output of `prof.py` when used with the `-c layer` option. + +NVTX range markers can be nested and if users follow the above string pattern, the output of `prof.py` will show all the markers associated with a kernel. + +The file `resnet.py` (a simplified version of the torchvision model) shows an example of how users can add (nested) NVTX markers with information which can greatly aid in understanding and analysis of networks. + +Note that the pattern `"layer:your_string_here"` was chosen to aid information extraction by `pyprof`. The tool will work seamlessly even if there are other markers or no markers at all. + +### To run + +```sh +nvprof -fo resnet.sql --profile-from-start off python resnet.py +parse.py resnet.sql > resnet.dict +prof.py --csv -c idx,layer,dir,mod,op,kernel,params,sil resnet.dict +``` + +The file `resnet.sql` can also be opened with NVVP as usual. diff --git a/apex/pyprof/examples/user_annotation/resnet.py b/apex/pyprof/examples/user_annotation/resnet.py new file mode 100644 index 000000000..ec351aa97 --- /dev/null +++ b/apex/pyprof/examples/user_annotation/resnet.py @@ -0,0 +1,215 @@ +#!/usr/bin/env python3 + +""" +An example showing use of nested NVTX markers. +""" + +import torch +import torch.nn as nn + +import torch.cuda.profiler as profiler +import torch.cuda.nvtx as nvtx +from apex import pyprof +pyprof.nvtx.init() + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + +class Bottleneck(nn.Module): + expansion = 4 + count = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + self.id = Bottleneck.count + Bottleneck.count += 1 + + def forward(self, x): + identity = x + + nvtx.range_push("layer:Bottleneck_{}".format(self.id)) + + nvtx.range_push("layer:Conv1") + out = self.conv1(x) + nvtx.range_pop() + + nvtx.range_push("layer:BN1") + out = self.bn1(out) + nvtx.range_pop() + + nvtx.range_push("layer:ReLU") + out = self.relu(out) + nvtx.range_pop() + + nvtx.range_push("layer:Conv2") + out = self.conv2(out) + nvtx.range_pop() + + nvtx.range_push("layer:BN2") + out = self.bn2(out) + nvtx.range_pop() + + nvtx.range_push("layer:ReLU") + out = self.relu(out) + nvtx.range_pop() + + nvtx.range_push("layer:Conv3") + out = self.conv3(out) + nvtx.range_pop() + + nvtx.range_push("layer:BN3") + out = self.bn3(out) + nvtx.range_pop() + + if self.downsample is not None: + nvtx.range_push("layer:Downsample") + identity = self.downsample(x) + nvtx.range_pop() + + nvtx.range_push("layer:Residual") + out += identity + nvtx.range_pop() + + nvtx.range_push("layer:ReLU") + out = self.relu(out) + nvtx.range_pop() + + nvtx.range_pop() + + return out + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, + groups=1, width_per_group=64, norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def forward(self, x): + + nvtx.range_push("layer:conv1_x") + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + nvtx.range_pop() + + nvtx.range_push("layer:conv2_x") + x = self.layer1(x) + nvtx.range_pop() + + nvtx.range_push("layer:conv3_x") + x = self.layer2(x) + nvtx.range_pop() + + nvtx.range_push("layer:conv4_x") + x = self.layer3(x) + nvtx.range_pop() + + nvtx.range_push("layer:conv5_x") + x = self.layer4(x) + nvtx.range_pop() + + x = self.avgpool(x) + x = torch.flatten(x, 1) + + nvtx.range_push("layer:FC") + x = self.fc(x) + nvtx.range_pop() + + return x + + +def resnet50(): + return ResNet(Bottleneck, [3, 4, 6, 3]) + +#Create model +net = resnet50().cuda().half() +net.train() + +#Create optimizer +criterion = nn.CrossEntropyLoss().cuda() +optimizer = torch.optim.SGD(net.parameters(), lr = 0.01, momentum=0.9) + +#Create synthetic input and label +x = torch.rand(32, 3, 224, 224).cuda().half() +target = torch.empty(32, dtype=torch.long).random_(1000).cuda() + +with torch.autograd.profiler.emit_nvtx(): + profiler.start() + output = net(x) + loss = criterion(output, target) + optimizer.zero_grad() + loss.backward() + optimizer.step() + profiler.stop() diff --git a/apex/pyprof/examples/user_annotation/test.sh b/apex/pyprof/examples/user_annotation/test.sh new file mode 100755 index 000000000..89b8eb0c8 --- /dev/null +++ b/apex/pyprof/examples/user_annotation/test.sh @@ -0,0 +1,31 @@ +#!/bin/bash + +set -e + +SCRIPT=`realpath $0` +SCRIPTPATH=`dirname $SCRIPT` +PYPROF="$SCRIPTPATH/../.." + +parse="python $PYPROF/parse/parse.py" +prof="python $PYPROF/prof/prof.py" + +for f in *.py +do + base=`basename $f .py` + sql=$base.sql + dict=$base.dict + + #NVprof + echo "nvprof -fo --profile-from-start off $sql python $f" + nvprof -fo $sql --profile-from-start off python $f + + #Parse + echo $parse $sql + $parse $sql > $dict + + #Prof + echo $prof $dict + #$prof -w 130 $dict + $prof --csv -c idx,layer,dir,mod,op,kernel,params,sil $dict + \rm $sql $dict +done diff --git a/apex/pyprof/nvtx/__init__.py b/apex/pyprof/nvtx/__init__.py new file mode 100644 index 000000000..774b4c24f --- /dev/null +++ b/apex/pyprof/nvtx/__init__.py @@ -0,0 +1,2 @@ +from .nvmarker import init +from .nvmarker import add_wrapper as wrap diff --git a/apex/pyprof/nvtx/nvmarker.py b/apex/pyprof/nvtx/nvmarker.py new file mode 100644 index 000000000..754c5e00e --- /dev/null +++ b/apex/pyprof/nvtx/nvmarker.py @@ -0,0 +1,215 @@ +""" +This file intercepts (monkey patches) the following functions and adds NVTX markers. + torch.* + torch.Tensor.* + torch.nn.functional.* + torch.nn.*.forward + +The NVTX markers (one or more) contain the following information + call trace (a list of file_name:line_number) + extra_repr() from torch.nn modules + module/class name + function name + inputs (args and kwargs) + scalar: name, type and value + tensor: name, shape and datatype + numpy: name, shape and datatype + list/tuple: a sequence of scalars or tensors or numpy arrays +""" + +import torch +import torch.cuda.nvtx as nvtx +import numpy +import inspect as ins +import traceback +import math + +def isfunc(mod, f): + assert hasattr(mod, f) + attr = getattr(mod, f) + + #Ignore functions like _add + if (len(f) >= 2): + if f[0] == "_" and f[1] != "_": + return False + + #Ignore functions from this list + ignore = ['__all__', '__array__', '__array_priority__', '__array_wrap__', '__bool__', '__builtins__', '__cached__', '__class__', '__deepcopy__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__file__', '__format__', '__getattribute__', '__getitem__', '__hash__', '__index__', '__init__', '__init_subclass__', '__iter__', '__len__', '__loader__', '__module__', '__name__', '__new__', '__nonzero__', '__package__', '__path__', '__reduce__', '__reduce_ex__', '__repr__', '__reversed__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__spec__', '__str__', '__subclasshook__', '__version__', '__weakref__'] + + #Add functions to this list if they cause recursion + ignore += ['size', 'tolist', 'dim', 'is_storage', 'item'] + if f in ignore: + return False + + return ins.ismethod(attr) or ins.isfunction(attr) or ins.ismethoddescriptor(attr) or ins.isbuiltin(attr) + +def traceMarker(stack): + d = {} + cadena = [] + for i in range(len(stack)-1): + fi = stack[i] + t = "{}:{}".format(fi.filename, fi.lineno) + cadena.append(t) + d['traceMarker'] = cadena + return str(d) + +def modMarker(mod, fn_name, args): + """ + Returns the stringified extra_repr() of a module. + """ + assert(fn_name == 'forward') + assert(len(args) > 0) + d = {} + d['mod'] = mod.__name__ + d['strRepr'] = args[0].extra_repr() + return str(d) + +def add_wrapper(mod, fn_name): + assert isfunc(mod, fn_name) + + # Get a pointer to the original function + func = getattr(mod, fn_name) + + # Check if the mod has a string representation + # and is not a Script or Traced module (used by JIT) + s = hasattr(mod, "extra_repr") and (type(mod) is not torch.jit.ScriptModule) and (type(mod) is not torch.jit.TopLevelTracedModule) + + def wrapper_func(*args, **kwargs): + + # Extract the stacktrace + stack = traceback.extract_stack() + + # Push trace marker + nvtx.range_push(traceMarker(stack)) + + # Push module marker + if s: + m = modMarker(mod, fn_name, args) + nvtx.range_push(m) + + # Create and push argument marker + cadena = argMarker(mod, fn_name, args, kwargs) + nvtx.range_push(cadena) + + # Call the original function + result = func(*args, **kwargs) + + # Pop argumet marker + nvtx.range_pop() + + # Pop module marker + if s: + nvtx.range_pop() + + # Pop trace marker + nvtx.range_pop() + + return result + setattr(mod, fn_name, wrapper_func) + +def argMarker(mod, op, args, kwargs): + #For this function args is a tuple and kwargs is a dict + + def tensor(arg, name=""): + a = {} + a['name'] = name + a['type'] = "tensor" + a['shape'] = tuple(arg.size()) + a['dtype'] = str(arg.dtype).split(".")[-1] + cadena['args'].append(a) + + def ndarray(arg, name=""): + a = {} + a['name'] = name + a['type'] = "ndarray" + a['shape'] = arg.shape + a['dtype'] = str(arg.dtype).split(".")[-1] + cadena['args'].append(a) + + def seq(arg, name=""): + assert issequence(arg) + a = {} + a['name'] = name + if isinstance(arg, list): + a['type'] = "list" + a['value'] = arg + else: + a['type'] = "tuple" + # The arg could be torch.Size, which is a subclass of tuple + # Therefore, explicitly convert to tuple + a['value'] = tuple(arg) + + cadena['args'].append(a) + + def scalar(arg, name=""): + a = {} + a['name'] = name + a['type'] = type(arg).__name__ + #handle the case when the argument is +/- inf or nan + if arg == float('inf'): + a['value'] = "inf" + elif arg == float('-inf'): + a['value'] = "-inf" + elif isinstance(arg, float) and math.isnan(arg): + a['value'] = "nan" + else: + a['value'] = arg + cadena['args'].append(a) + + def isscalar(arg): + return (type(arg) is int) or (type(arg) is float) or (type(arg) is bool) or (arg is None) or (type(arg) is str) + + def issequence(arg): + return isinstance(arg, list) or isinstance(arg, tuple) + + def foo(args, name): + #args should be an iterable sequence e.g. list or tuple + for arg in args: + if isinstance(arg, torch.Tensor): + if arg.dim() == 0: + scalar(arg.item(), name) + else: + tensor(arg, name) + elif isinstance(arg, numpy.ndarray): + ndarray(arg, name) + elif (isscalar(arg)): + scalar(arg, name) + elif issequence(arg): + if (len(arg) == 0) or isscalar(arg[0]): #An empty sequence or a sequence of scalars + seq(arg, name) + else: # A sequence of tensors or numpy arrays + foo(arg, name) + ''' + else: + print("The following arg is none of Tensor, numpy array, scalar but a %s" % (str(type(arg)))) + print("Mod: %s" % str(mod.__name__)) + print("Op: %s" % str(op)) + print(dir(arg)) + ''' + + cadena = {} + cadena['mod'] = mod.__name__ + cadena['op'] = op + cadena['args'] = [] + + foo(args, "") + for k,v in kwargs.items(): + foo((v,), k) + + return str(cadena) + +def patchClass(cls): + for f in dir(cls): + if isfunc(cls, f): + add_wrapper(cls, f) + +def init(): + print("Initializing NVTX monkey patches") + for cls in [torch, torch.Tensor, torch.nn.functional,]: + patchClass(cls) + + for cls in [torch.nn.RNN, torch.nn.RNNCell, torch.nn.LSTM, torch.nn.LSTMCell, torch.nn.GRU, torch.nn.GRUCell]: + if isfunc(cls, 'forward'): + add_wrapper(cls, 'forward') + + print("Done with NVTX monkey patching") diff --git a/apex/pyprof/parse/__init__.py b/apex/pyprof/parse/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apex/pyprof/parse/__main__.py b/apex/pyprof/parse/__main__.py new file mode 100644 index 000000000..fa1291992 --- /dev/null +++ b/apex/pyprof/parse/__main__.py @@ -0,0 +1,10 @@ +import warnings + +try: + from .parse import main +except ImportError as e: + warnings.warn("Did you make sure to install PyProf dependencies by using the --pyprof flag during Apex installation?)") + raise e + +if __name__ == '__main__': + main() diff --git a/apex/pyprof/parse/db.py b/apex/pyprof/parse/db.py new file mode 100644 index 000000000..55c5dc571 --- /dev/null +++ b/apex/pyprof/parse/db.py @@ -0,0 +1,61 @@ +import sys, sqlite3 + +class DB(object): + """ + This class provides functions for DB operations + with exception handling. + """ + + def __init__(self, dbFile): + try: + conn = sqlite3.connect(dbFile) + conn.row_factory = sqlite3.Row + c = conn.cursor() + except: + print("Error opening {}".format(dbFile)) + sys.exit(1) + + self.conn = conn + self.c = c + + def select(self, cmd): + try: + self.c.execute(cmd) + #rows = self.c.fetchall() + rows = [dict(row) for row in self.c.fetchall()] + except sqlite3.Error as e: + print(e) + sys.exit(1) + except: + print("Uncaught error in SQLite access while executing {}".format(cmd)) + sys.exit(1) + + #print(rows) + return rows + + def insert(self, cmd, data): + try: + self.c.execute(cmd, data) + except sqlite3.Error as e: + print(e) + sys.exit(1) + except: + print("Uncaught error in SQLite access while executing {}".format(cmd)) + sys.exit(1) + + def execute(self, cmd): + try: + self.c.execute(cmd) + except sqlite3.Error as e: + print(e) + sys.exit(1) + except: + print("Uncaught error in SQLite access while executing {}".format(cmd)) + sys.exit(1) + + def commit(self): + self.conn.commit() + + def close(self): + self.c.close() + self.conn.close() diff --git a/apex/pyprof/parse/kernel.py b/apex/pyprof/parse/kernel.py new file mode 100644 index 000000000..5a66e290d --- /dev/null +++ b/apex/pyprof/parse/kernel.py @@ -0,0 +1,210 @@ +import cxxfilt, struct, binascii + +#Helper functions + +def demangle(name): + """ + Demangle a C++ string + """ + return cxxfilt.demangle(name) + +def encode_object_id(pid, tid): + """ + Given process id (pid) and thread id (tid), return the object id. + object id = pid (little endian 4 bytes) + tid (little endian 8 bytes) + """ + objId = struct.pack(' start, "This assertion can fail for very large profiles. It usually fails when start = end = 0." + self.kStartTime = start + self.kEndTime = end + self.kDuration = end - start + assert (start > Kernel.profStart) + self.device = int(info['deviceId']) + self.stream = int(info['streamId']) + self.grid = (info['gridX'], info['gridY'], info['gridZ']) + self.block = (info['blockX'], info['blockY'], info['blockZ']) + self.timeOffset = Kernel.profStart + + def setKernelName(self, name): + cadena = demangle(name) + self.kLongName = cadena + self.kShortName = getShortName(cadena) + + def setRunTimeInfo(self, info): + start, end, pid, tid = info + self.rStartTime = start + self.rEndTime = end + self.rDuration = end - start + self.pid = pid + self.tid = tid + self.objId = encode_object_id(pid, tid) + + def setMarkerInfo(self, info): + self.layerMarkers, self.traceMarkers, self.reprMarkers, self.pyprofMarkers, self.seqMarkers, self.otherMarkers, self.altMarkers, self.seqId, self.altSeqId, self.layer = info + self.subSeqId = 0 + + def setDirection(self): + """ + Set direction (fprop, bprop) based on PyTorch sequence markers. + It is a heuristic and not a foolproof method. + """ + if any("Backward, seq = " in x for x in self.seqMarkers) or \ + any("backward, seq = " in x for x in self.seqMarkers) or \ + any("Backward0, seq = " in x for x in self.seqMarkers): + self.dir = "bprop" + else: + self.dir = "fprop" + + def setOp(self): + """ + Detect and set the class/module (mod) and operation (op) + of the kernel e.g. torch.nn.functional / linear, torch / sigmoid. + The lookup sequence we use is + NVTX markers inserted by pyprof + NVTX markers inserted by PyTorch in bprop + NVTX markers inserted by PyTorch in fprop + It is a heuristic and not a foolproof method. + """ + + def sanitize(name): + name = name.replace("torch","") \ + .replace("autograd","") \ + .replace("_backward","") \ + .replace("::","") \ + .replace("jit","") \ + .replace("(anonymous namespace)","") + head, sep, tail = name.partition("Backward") + return head + + #Check pyprof markers + for m in self.pyprofMarkers: + assert ("mod" in m) and ("op" in m) and ("args" in m) + t = eval(m) + self.op.append(t['op']) + self.mod.append(t['mod']) + + if len(self.op): + return + + #Check bprop kernel markers + for m in self.seqMarkers: + if ("backward, seq = " in m) or ("Backward, seq = " in m): + op = m.split(",")[0] + op = sanitize(op) + self.op.append(op) + self.mod.append('na') + + if len(self.op): + return + + #Check markers with "seq = " + for m in self.seqMarkers: + if ", seq = " in m: + op = m.split(",")[0] + self.op.append(op) + self.mod.append('na') + + if len(self.op): + return + + #If nothing else + if len(self.otherMarkers): + self.op.append(self.otherMarkers[0]) + self.mod.append('na') + + def print(self): + """ + Print kernel information. This is used by prof.py. + """ + + a = lambda: None + a.kShortName = self.kShortName + a.kDuration = self.kDuration + #a.layerMarkers = self.layerMarkers + a.layer = self.layer + a.trace = self.traceMarkers + a.reprMarkers = self.reprMarkers + a.marker = self.pyprofMarkers + a.seqMarker = self.seqMarkers + + a.seqId = self.seqId + a.subSeqId = self.subSeqId + a.altSeqId = self.altSeqId + + a.dir = self.dir + a.mod = self.mod + a.op = self.op + + a.tid = self.tid + a.device = self.device + a.stream = self.stream + a.grid = self.grid + a.block = self.block + a.kLongName = self.kLongName + + print(a.__dict__) diff --git a/apex/pyprof/parse/nvvp.py b/apex/pyprof/parse/nvvp.py new file mode 100644 index 000000000..7167a892d --- /dev/null +++ b/apex/pyprof/parse/nvvp.py @@ -0,0 +1,282 @@ +import sys + +class NVVP(object): + """ + This class gets kernel information from the SQL (nvvp) database. + """ + + driverT = "CUPTI_ACTIVITY_KIND_DRIVER" + runtimeT = "CUPTI_ACTIVITY_KIND_RUNTIME" + kernelT = "CUPTI_ACTIVITY_KIND_CONCURRENT_KERNEL" + markerT = "CUPTI_ACTIVITY_KIND_MARKER" + stringT = "StringTable" + + def __init__(self, db): + self.db = db + self.markerId = 0 + + def getProfileStart(self): + """ + Get the profile start time + """ + profStart = sys.maxsize + for table in [self.driverT, self.runtimeT, self.kernelT, self.markerT]: + colname = "timestamp" if table is self.markerT else "start" + cmd = "select {} from {} ORDER BY {} ASC LIMIT 1".format(colname, table, colname) + result = self.db.select(cmd) + assert(len(result) <= 1) + if (len(result) == 1): + assert(colname in result[0]) + t = result[0][colname] + if (t < profStart): + profStart = t + assert(profStart < sys.maxsize) + return profStart + + def getString(self, id_): + """ + Get the string associated with an id. + """ + cmd = "select value from {} where _id_ = {}".format(self.stringT, id_) + result = self.db.select(cmd) + assert (len(result) == 1) + return result[0]['value'] + + def createMarkerTable(self): + """ + Create a temporary table and index it to speed up repeated SQL quesries. + The table is an INNER JOIN of CUPTI_ACTIVITY_KIND_MARKER with itself. + """ + cmd = 'CREATE TEMPORARY TABLE marker AS SELECT \ + a._id_ as id, \ + a.timestamp AS startTime, \ + b.timestamp AS endTime, \ + HEX(a.objectId) AS objectId, \ + a.name AS name \ + FROM {} AS a INNER JOIN {} AS b ON \ + a.id = b.id and \ + a.flags = 2 and b.flags = 4'.format(self.markerT, self.markerT) + self.db.execute(cmd) + + self.db.execute('CREATE INDEX start_index ON marker (startTime)') + self.db.execute('CREATE INDEX end_index ON marker (endTime)') + self.db.execute('CREATE INDEX id_index ON marker (id)') + + def getCPUInfo(self, corrId): + """ + Given the correlation id, get CPU start, end, thread id, process id. + The information can be in the runtime table or the driver table. + """ + + #First look in the runtime table + cmd = "select start,end,processId,threadId from {} where correlationId={}".format(self.runtimeT, corrId); + result = self.db.select(cmd) + assert (len(result) <= 1) + + if (len(result) == 0): + #Look in the driver table + cmd = "select start,end,processId,threadId from {} where correlationId={}".format(self.driverT, corrId); + result = self.db.select(cmd) + + assert (len(result) == 1) + info = result[0] + start = info['start'] + end = info['end'] + pid = info['processId'] + tid = info['threadId'] + tid = tid & 0xffffffff #convert to unsigned + assert (end > start) + return [start, end, pid, tid] + + def getKernelInfo(self): + """ + Get GPU kernel info + """ + cmd = "select name,correlationId,start,end,deviceId,streamId,gridX,gridY,gridZ,blockX,blockY,blockZ from {}".format(self.kernelT) + result = self.db.select(cmd) + return result + + def getMarkerInfo(self, objId, startTime, endTime): + """ + This function first finds all NVTX markers encapsulating + a runtime / driver kernel launch. + It then splits the markers into many lists. + layerMarkers : User added NVTX markers + traceMarkers : Call trace markers (inserted by pyprof) + reprMarkers : Markers containing the extra_repr() of a module (inserted by pyprof) + pyprofMarkers: Markers containing args and kwargs (tensor shape, datatype etc.) + seqMarkers : Markers containing PyTorch internal sequence markers (inserted by PyTorch) + altSeqMarkers: Markers inserted by PyTorch between two kernel launches. Needs better explanation. + otherMarkers : Markers not in either of the above categories. + + We extract seqId from the seq and altSeq markers. The seqId is used in bprop. + We also extract information from the layerMarkers. + """ + + layerMarkers = [] + traceMarkers = [] + reprMarkers = [] + pyprofMarkers = [] + seqMarkers = [] + otherMarkers = [] + altSeqMarkers = [] + bprop = False + + #Helper functions + + def delete(objId, sTime): + """ + Delete rows from the temporary SQL table which are no longer required. + This speeds up future queries. + """ + margin = 0 + cmd = 'DELETE FROM marker WHERE objectId = "{}" AND endTime < {}'.format(objId, sTime - margin) + #cmd = 'DELETE FROM marker WHERE endTime < {}'.format(sTime - margin) + self.db.execute(cmd) + + def getLayerName(mlist): + """ + Get layer names from layer marker list. + """ + layers = [] + assert(type(mlist) == list) + for m in mlist: + assert("layer:" in m) + l = m.split(":")[1] + layers.append(l) + return layers + + def getSeqId(mlist): + """ + Get sequence ids from seq / alt seq marker list. + """ + ids = [] + assert(type(mlist) == list) + for m in mlist: + assert(", seq = " in m) + seq = int(m.split("=")[1]) + ids.append(seq) + + #Remove duplicates + ids = list(set(ids)) + ids.sort() + return ids + + def seqcompare(elem): + """ + Sorting function for sequence markers + """ + assert (", seq = " in elem) + #sort by sequence id and then the string + l = elem.split(" = ") + return l[1] + l[0] + + def prune(mlist): + """ + Remove markers with the same seqId and if the strings are similar. + This function works on a sorted sequence. + """ + assert (type(mlist) == list) + assert (len(mlist)) + a = mlist[0:1] + for i in range(1,len(mlist)): + m = mlist[i] + pm = mlist[i-1] + name,seq = m.split(",") + pname,pseq = pm.split(",") + similar = (name in pname) or (pname in name) + if (seq == pseq) and similar: + continue + else: + a.append(m) + return a + + def filterTrace(mlist): + """ + Filter trace markers to remove certain file names. + """ + assert (type(mlist) == list) + if len(mlist) == 0: + return mlist + mlist = mlist[-1] #The last stack trace will be a super set. + mlist = eval(mlist) + mlist = mlist['traceMarker'] + assert (type(mlist) == list) + mlist = list(filter(lambda x : "/torch/nn/modules/" not in x, mlist)) + mlist = list(filter(lambda x : "/torch/nn/functional.py" not in x, mlist)) + mlist = list(filter(lambda x : "/torch/tensor.py" not in x, mlist)) + mlist = list(filter(lambda x : "/torch/autograd/__init__.py" not in x, mlist)) + mlist = list(filter(lambda x : "/torch/_jit_internal.py" not in x, mlist)) + mlist = list(filter(lambda x : "/pyprof/nvtx/nvmarker.py" not in x, mlist)) + mlist = list(filter(lambda x : "/apex/optimizers/" not in x, mlist)) + mlist = list(filter(lambda x : "/torch/_utils.py" not in x, mlist)) + mlist = list(filter(lambda x : "/torch/optim/" not in x, mlist)) + return mlist + + #Find all encapsulating markers + cmd = 'SELECT id,name from marker where \ + objectId = "{}" and \ + startTime < {} and \ + endTime > {} \ + ORDER BY startTime ASC'.format(objId, startTime, endTime) + result = self.db.select(cmd) + + #Bin markers into different lists + for r in result: + m = self.getString(r['name']) + + #Hack: If its a known gradient checkpointing marker, ignore it. + if m.find("CheckpointFunctionBackward") >= 0: + continue + + if ("_backward, seq =" in m) or ("Backward, seq =" in m) or ("Backward0, seq =" in m): + bprop = True + + if ("mod" in m) and ("op" in m) and ("args" in m) and ("type" in m): + pyprofMarkers.append(m) + elif ("layer:" in m): + layerMarkers.append(m) + elif ("traceMarker" in m): + traceMarkers.append(m) + elif ("strRepr" in m): + reprMarkers.append(m) + elif (", seq = " in m): + seqMarkers.append(m) + else: + otherMarkers.append(m) + + #Remove duplicates, sort and prune seqMarkers + if (len(seqMarkers)): + seqMarkers = list(set(seqMarkers)) + seqMarkers.sort(key=seqcompare) + seqMarkers = prune(seqMarkers) + + #Remove duplicates from otherMarkers + otherMarkers = list(set(otherMarkers)) + + #Get markers with seq id (inserted by PyTorch) from the previous kernel to the present kernel + #Only for fprop kernels + if (len(result) and not bprop): + loId = self.markerId + hiId = result[-1]['id'] + self.markerId = hiId + + #Get markers between loId and hiId + cmd = 'SELECT id,name from marker where objectId = "{}" and id > {} and id < {} ORDER BY startTime ASC'.format(objId, loId, hiId) + result1 = self.db.select(cmd) + + for r in result1: + m = self.getString(r['name']) + #Get only markers with seq id + if (", seq=" in m): + altSeqMarkers.append(m) + + #Remove duplicates, sort and prune altSeqMarkers + if (len(altSeqMarkers)): + altSeqMarkers = list(set(altSeqMarkers)) + altSeqMarkers.sort(key=seqcompare) + altSeqMarkers = prune(altSeqMarkers) + + delete(objId, startTime) + + return layerMarkers, filterTrace(traceMarkers), reprMarkers, pyprofMarkers, seqMarkers, otherMarkers, altSeqMarkers, getSeqId(seqMarkers), getSeqId(altSeqMarkers), getLayerName(layerMarkers) diff --git a/apex/pyprof/parse/parse.py b/apex/pyprof/parse/parse.py new file mode 100755 index 000000000..c119fe5e0 --- /dev/null +++ b/apex/pyprof/parse/parse.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python3 + +""" +Parse the SQL db and print a dictionary for every kernel. +""" + +import sys +import argparse +from tqdm import tqdm + +from .db import DB +from .kernel import Kernel +from .nvvp import NVVP + +def parseArgs(): + parser = argparse.ArgumentParser(prog=sys.argv[0], description="Parse SQL (nvvp) db.") + parser.add_argument("file", + type=str, + default=None, + help="SQL db (nvvp) file.") + + args = parser.parse_args() + return args + +def main(): + args = parseArgs() + + db = DB(args.file) + nvvp = NVVP(db) + + kInfo = nvvp.getKernelInfo() + if len(kInfo) == 0: + print("Found 0 kernels. Exiting.", file=sys.stderr) + db.close() + sys.exit(0) + else: + print("Found {} kernels. Getting info for each kernel.".format(len(kInfo)), file=sys.stderr) + + nvvp.createMarkerTable() + + prevSeqId = -1 + prevSubSeqId = -1 + prevOp = "na" + + Kernel.profStart = nvvp.getProfileStart() + + for i in tqdm(range(len(kInfo)), ascii=True): + info = kInfo[i] + k = Kernel() + + #Set kernel info + k.setKernelInfo(info) + + #Get, set kernel name + name = nvvp.getString(k.kNameId) + k.setKernelName(name) + + #Get runtime info + info = nvvp.getCPUInfo(k.corrId) + k.setRunTimeInfo(info) + + #Get and set marker and seqid info + info = nvvp.getMarkerInfo(k.objId, k.rStartTime, k.rEndTime) + k.setMarkerInfo(info) + + #If the seqId contains both 0 and non zero integers, remove 0. + if any(seq != 0 for seq in k.seqId) and (0 in k.seqId): + k.seqId.remove(0) + + #Set direction (it uses seq id) + k.setDirection() + + #Set op + k.setOp() + + #The following code is based on heuristics. + #TODO: Refactor. + #Assign subSeqId, adjust seqId and altSeqId + #seqId can be 0. + #A kernel can have multiple seqIds both in fprop and bprop. + #In bprop, seqIds might not decrease monotonically. I have observed a few blips. + if len(k.seqId): + assert (k.dir in ["fprop", "bprop"]) + if (k.dir == "fprop"): + #Check if there is a sequence id larger than the previous + inc = (k.seqId[-1] > prevSeqId) + if inc: + currSeqId = [x for x in k.seqId if x > prevSeqId][0] + else: + currSeqId = prevSeqId + else: + currSeqId = k.seqId[0] + + #if ((currSeqId == prevSeqId) and (k.op == prevOp)): + if ((currSeqId == prevSeqId) and (k.op == prevOp)) or ((k.op[0] == "forward") and (k.op == prevOp) and (k.mod[0] in ["LSTMCell", "GRUCell", "RNNCell"])): + #The second condition is to trap cases when pytorch does not use cudnn for a LSTMCell. + k.subSeqId = prevSubSeqId + 1 + + prevSeqId = currSeqId + prevSubSeqId = k.subSeqId + prevOp = k.op + + #Keep currSeqId in k.seqId, move everything else to k.altSeqId + for s in k.seqId: + if s != currSeqId: + k.seqId.remove(s) + k.altSeqId.append(s) + + for s in k.altSeqId: + if s == currSeqId: + k.altSeqId.remove(s) + + k.altSeqId = list(set(k.altSeqId)) + if (len(k.altSeqId)): + (k.altSeqId).sort() + + k.print() + + db.close() + +if __name__ == '__main__': + main() diff --git a/apex/pyprof/prof/__init__.py b/apex/pyprof/prof/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/apex/pyprof/prof/__main__.py b/apex/pyprof/prof/__main__.py new file mode 100644 index 000000000..a114f7c94 --- /dev/null +++ b/apex/pyprof/prof/__main__.py @@ -0,0 +1,10 @@ +import warnings + +try: + from .prof import main +except ImportError as e: + warnings.warn("Did you make sure to install PyProf dependencies by using the --pyprof flag during Apex installation?") + raise e + +if __name__ == '__main__': + main() diff --git a/apex/pyprof/prof/activation.py b/apex/pyprof/prof/activation.py new file mode 100644 index 000000000..e2443b093 --- /dev/null +++ b/apex/pyprof/prof/activation.py @@ -0,0 +1,65 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +class Activation(OperatorLayerBase): + """ + This class handles the various activation functions. + """ + + ops = ["celu", "elu", "elu_", "hardshrink", "hardtanh", "hardtanh_", "leaky_relu", "leaky_relu_", "logsigmoid", "prelu", "relu", "relu_", "relu6", "rrelu", "rrelu_", "selu", "sigmoid", "softplus", "softshrink", "softsign", "tanh", "tanhshrink", "threshold", "threshold_"] + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod in ["torch.nn.functional", "torch", "Tensor"]) + + #Filter out named parameters + args = list(filter(lambda x : x['name'] == '', args)) + + assert (len(args) >= 1) + arg = args[0] + assert (arg['type'] == "tensor") + + self.i = arg + self.dir = d.dir + + def params(self): + p = OrderedDict([('T', self.i['shape']),('type', self.i['dtype'])]) + return p + + def flops(self): + direction = self.dir + tensor = self.i['shape'] + t = self.i['dtype'] + + # TODO: revise + elems = Utility.numElems(tensor) + return elems + + def bytes(self): + direction = self.dir + tensor = self.i['shape'] + t = self.i['dtype'] + + elems = Utility.numElems(tensor) + elems = elems * (2 if direction == "fprop" else 3) + + return elems * Utility.typeToBytes(t) + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ diff --git a/apex/pyprof/prof/base.py b/apex/pyprof/prof/base.py new file mode 100644 index 000000000..adfc4ce8a --- /dev/null +++ b/apex/pyprof/prof/base.py @@ -0,0 +1,47 @@ +from abc import ABC, abstractmethod + +class OperatorLayerBase(ABC): + """ + Base class for all layers and operators. + Every derived class should have the following functions. + """ + + @abstractmethod + def tc(self): + """ + Tensor core usage by the kernel. + Return "1" (yes), "0" (no, but possible), "-" (not applicable) + """ + pass + + @abstractmethod + def params(self): + """ + Kernel parameters to be printed. + """ + pass + + @abstractmethod + def flops(self): + """ + Note that 1 FMA = 2 flops. + """ + pass + + @abstractmethod + def bytes(self): + pass + + @abstractmethod + def mod(self): + """ + Name of the module/class e.g. torch.nn.functional. + """ + pass + + @abstractmethod + def op(self): + """ + Name of the operator e.g. sigmoid. + """ + pass diff --git a/apex/pyprof/prof/blas.py b/apex/pyprof/prof/blas.py new file mode 100644 index 000000000..65b445d45 --- /dev/null +++ b/apex/pyprof/prof/blas.py @@ -0,0 +1,326 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase +import numpy as np + +class Addmm(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod in ["torch", "Tensor",]) + assert (op in ["addmm", "addmm_",]) + + #Get alpha and beta + alpha = 1 + beta = 1 + if any(x['name'] == 'alpha' for x in args): + alpha = list(filter(lambda x : x['name'] == "alpha", args))[0] + alpha = alpha['value'] + + if any(x['name'] == 'beta' for x in args): + beta = list(filter(lambda x : x['name'] == "beta", args))[0] + beta = beta['value'] + + self.alpha = alpha + self.beta = beta + + #Filter out named parameters + args = list(filter(lambda x : x['name'] == '', args)) + + assert (len(args) == 3) + C,A,B = args + m,k1 = A['shape'] + k2,n = B['shape'] + assert (k1 == k2) + t1 = A['dtype'] + t2 = B['dtype'] + t3 = C['dtype'] + assert(t1 == t2 == t3) + + self.A = A + self.B = B + self.C = C + + self.m = m + self.n = n + self.k = k1 + self.type = t1 + self.name = d.name + + return + + def tc(self): + return 1 if "884gemm" in self.name else 0 + + def bytes(self): + m, n, k = self.m, self.n, self.k + return Utility.typeToBytes(self.type) * (m*n + m*k + n*k) + + def flops(self): + return self.m * self.n * self.k * 2 + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def params(self): + p = OrderedDict([('M',self.n),('N',self.m),('K',self.k),('type',self.type)]) + return p + +class Bmm(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "torch") and (op == "bmm") + + #Filter out named params (kwargs) + args = list(filter(lambda x : x['name'] == "", args)) + + assert (len(args) == 2) + A,B = args + b1,m,k1 = A['shape'] + b2,k2,n = B['shape'] + assert (b1 == b2) + assert (k1 == k2) + t1 = A['dtype'] + t2 = B['dtype'] + assert(t1 == t2) + + self.A = A + self.B = B + self.b = b1 + self.m = m + self.n = n + self.k = k1 + self.type = t1 + self.name = d.name + + def tc(self): + return 1 if "884gemm" in self.name else 0 + + def params(self): + #p = OrderedDict([('A', A['shape']), ('B', B['shape']), ('type', t1)]) + p = OrderedDict([('B',self.b), ('M',self.n),('N',self.m),('K',self.k),('type',self.type)]) + return p + + def flops(self): + return self.b * self.m * self.n * self.k * 2 + + def bytes(self): + b, m, n, k = self.b, self.m, self.n, self.k + return Utility.typeToBytes(self.type) * b * (m*n + m*k + n*k) + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + +class Matmul(OperatorLayerBase): + + NON_GEMM = ["kernelPointwiseApply2", "reduce_1Block_kernel", "elementwise_kernel"] + NON_TC = NON_GEMM + ["dot_kernel"] + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + self.name = d.name + self.sub = d.sub + + assert ((mod == "torch") and (op == "matmul")) or ((mod == "Tensor") and (op == "__matmul__")) + assert (len(args) == 2) + + assert any([x in d.name for x in Matmul.NON_TC + ["gemm", "gemv"]]) + + A,B = args + t1 = A['dtype'] + t2 = B['dtype'] + assert(t1 == t2) + + A = A['shape'] + B = B['shape'] + + self.A = A + self.B = B + self.type = t1 + + # batch, MNK + if (len(A) == 1) and (len(B) == 1): + #dot product + assert (A[0] == B[0]) + self.b = (1,) + self.m = 1 + self.n = 1 + self.k = A[0] + + elif (len(A) == 2) and (len(B) == 2): + #gemm + m,k1 = A + k2,n = B + assert(k1 == k2) + self.b = (1,) + self.m = m + self.n = n + self.k = k1 + + elif (len(A) == 1) and (len(B) == 2): + #vector matrix + k1 = A[0] + k2,n = B + assert(k1 == k2) + + self.b = (1,) + self.m = 1 + self.n = n + self.k = k1 + + elif (len(A) == 2) and (len(B) == 1): + #gemv + m,k1 = A + k2 = B[0] + assert (k1 == k2) + + self.b = (1,) + self.m = m + self.n = 1 + self.k = k1 + + elif (len(A) == 1) and (len(B) > 2): + assert (A[0] == B[-2]) + + self.b = B[0:-2] + self.m = 1 + self.n = B[-1] + self.k = B[-2] + + elif (len(B) == 1) and (len(A) > 2): + assert (B[0] == A[-1]) + + self.b = A[0:-2] + self.m = A[-2] + self.n = 1 + self.k = A[-1] + + else: + assert (len(A) >= 2) + assert (len(B) >= 2) + assert (A[-1] == B[-2]) + self.m = A[-2] + self.n = B[-1] + self.k = A[-1] + + aa = np.empty(A[0:-2]) + bb = np.empty(B[0:-2]) + self.b = np.broadcast(aa, bb).shape + + def params(self): + return OrderedDict([('A', self.A), ('B', self.B), ('type', self.type)]) + + def tc(self): + if self.name in Matmul.NON_TC: + return "-" + else: + return 1 if "884gemm" in self.name else 0 + + def bytes(self): + # TODO: check bytes for non-GEMM cases + if self.name in Matmul.NON_GEMM: + return 2 * Utility.typeToBytes(self.type) * Utility.numElems(self.A) #could be B as well + else: + m, n, k = self.m, self.n, self.k + return Utility.typeToBytes(self.type) * (m*n + m*k + n*k) + + def flops(self): + # TODO: calculate actual FLOPs. At least we're not saying it's GEMM FLOPs for now. + if self.name in Matmul.NON_GEMM: + return 0 + else: + return Utility.numElems(self.b) * self.m * self.n * self.k * 2 + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + +class Mm(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "torch") and (op == "mm") + assert (len(args) == 2) + + A,B = args + m,k1 = A['shape'] + k2,n = B['shape'] + assert (k1 == k2) + t1 = A['dtype'] + t2 = B['dtype'] + assert(t1 == t2) + + self.A = A + self.B = B + self.m = m + self.n = n + self.k = k1 + self.type = t1 + self.name = d.name + + return + + def params(self): + p = OrderedDict([('M',self.n),('N',self.m),('K',self.k),('type',self.type)]) + return p + + def tc(self): + return 1 if "884gemm" in self.name else 0 + + def bytes(self): + m, n, k = self.m, self.n, self.k + return Utility.typeToBytes(self.type) * (m*n + m*k + n*k) + + def flops(self): + return self.m * self.n * self.k * 2 + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ diff --git a/apex/pyprof/prof/conv.py b/apex/pyprof/prof/conv.py new file mode 100644 index 000000000..23dbe1051 --- /dev/null +++ b/apex/pyprof/prof/conv.py @@ -0,0 +1,233 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +class Conv(OperatorLayerBase): + + """ + # N = batch size + # C,H,W = input channels, height, width + # K,P,Q = output channels, height, width + # R,S = filter height, width + # g = groups + """ + + #todo: refine winograd and FFT + convAuxList = ["nchwToNhwc", "nhwcToNchw", "OffsetsKernel",] + winoAuxList = ["generateWinogradTilesKernel", "winogradWgradData", "winogradWgradOutput", "winogradWgradDelta"] + fftAuxList = ["compute_gemm_pointers", "flip_filter", "fft2d_r2c_", "fft2d_c2r_", "fft1d_r2c", "fft1d_c2r"] + miscAuxList = ["scaleTensor_kernel",] + + convList = ["_s884cudnn_", "_scudnn_", "2d_grouped_direct_kernel", "cudnn::detail::implicit_convolve_sgemm", "cudnn::detail::dgrad2d_alg1_1", "cudnn::detail::wgrad_alg0_engine", "cudnn::detail::dgrad_engine", "dgrad_1x1_stride_2x2", "spatialDepthwiseConvolutionUpdateOutput"] + winoList = ["winograd3x3Kernel", "_sgemm_"] + fftList = ["fermiPlusCgemmLDS128_batched", "_gcgemm_",] + miscList = [] + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + self.dir = d.dir + self.name = d.name + self.sub = d.sub + + assert (mod == "torch.nn.functional") + assert (op in ["conv1d", "conv2d"]) + length = len(args) + assert (length >= 2) and (length <= 7) + i,w = args[0], args[1] + assert (i['type'] == "tensor") + assert (w['type'] == "tensor") + + #ignore bias + + if (length >= 4) and (args[3]['name'] == ""): + s = args[3] + elif any(x['name'] == 'stride' for x in args): + s = list(filter(lambda x : x['name'] == 'stride', args))[0] + else: + s = {'name': 'stride', 'type': 'int', 'value': 1} + + if (length >= 5) and (args[4]['name'] == ""): + p = args[4] + elif any(x['name'] == 'padding' for x in args): + p = list(filter(lambda x : x['name'] == 'padding', args))[0] + else: + p = {'name': 'padding', 'type': 'int', 'value': 0} + + if (length >= 6) and (args[5]['name'] == ""): + d = args[5] + elif any(x['name'] == 'dilation' for x in args): + d = list(filter(lambda x : x['name'] == 'dilation', args))[0] + else: + d = {'name': 'dilation', 'type': 'int', 'value': 1} + + if (length == 7) and (args[6]['name'] == ""): + g = args[6] + elif any(x['name'] == 'groups' for x in args): + g = list(filter(lambda x : x['name'] == 'groups', args))[0] + else: + g = {'name': 'groups', 'type': 'int', 'value': 1} + + if op == "conv1d": + assert (len(i['shape']) == 3) + assert (len(w['shape']) == 3) + assert (i['dtype'] == w['dtype']) + N, C1, W = i['shape'] + K, C2, S = w['shape'] + assert (C1 == C2) + p = p['value'] if Utility.isscalar(p['type']) else p['value'][0] + s = s['value'] if Utility.isscalar(s['type']) else s['value'][0] + d = d['value'] if Utility.isscalar(d['type']) else d['value'][0] + g = g['value'] + assert (g == 1) + H = 1 + R = 1 + + P = 1 + (H - (((R-1))+1)) + Q = 1 + (W + 2*p - (((S-1)*d)+1))/s + P = int(P) + Q = int(Q) + if (H == 1): + assert (P == 1) + if (W == 1): + assert (Q == 1) + + self.N = N + self.C = C1 + self.H = H + self.W = W + self.K = K + self.P = P + self.Q = Q + self.R = R + self.S = S + self.ph = 0 + self.pw = p + self.U = 1 + self.V = s + self.dh = 1 + self.dw = d + self.g = g + self.type = i['dtype'] + + elif op == "conv2d": + assert (len(i['shape']) == 4) + assert (len(w['shape']) == 4) + assert (i['dtype'] == w['dtype']) + N, C1, H, W = i['shape'] + K, C2, R, S = w['shape'] + + if Utility.isscalar(p['type']): + ph = pw = p['value'] + else: + assert (p['type'] == "tuple") + ph, pw = p['value'] + + if Utility.isscalar(s['type']): + sh = sw = s['value'] + else: + assert (s['type'] == "tuple") + sh, sw = s['value'] + + if Utility.isscalar(d['type']): + dh = dw = d['value'] + else: + assert (d['type'] == "tuple") + dh, dw = d['value'] + + g = g['value'] + assert (g >= 1) + assert (C1 == C2*g) + + P = 1 + (H + 2*ph - (((R-1)*dh)+1))/sh + Q = 1 + (W + 2*pw - (((S-1)*dw)+1))/sw + P = int(P) + Q = int(Q) + if (H == 1): + assert (P == 1) + if (W == 1): + assert (Q == 1) + + self.N = N + self.C = C1 + self.H = H + self.W = W + self.K = K + self.P = P + self.Q = Q + self.R = R + self.S = S + self.ph = ph + self.pw = pw + self.U = sh + self.V = sw + self.dh = dh + self.dw = dw + self.g = g + self.type = i['dtype'] + + else: + assert False + + def params(self): + p = OrderedDict([('N',self.N), ('C',self.C), ('H',self.H), ('W',self.W), ('K',self.K), ('P',self.P), ('Q',self.Q), ('R',self.R), ('S',self.S), ('ph',self.ph), ('pw',self.pw), ('U',self.U), ('V',self.V), ('dh',self.dh), ('dw',self.dw), ('g',self.g), ('type',self.type)]) + return p + + def conv_bytes_flops(self, N, C, H, W, K, P, Q, R, S, g, t): + f = 2*N*K*P*Q*C*R*S/g #for fprop + elems = N*C*H*W + K*C*R*S/g + N*K*P*Q + b = elems * Utility.typeToBytes(t) + return b,f + + def bytes_flops(self): + N,C,H,W,K,P,Q,R,S,ph,pw,U,V,dh,dw,g,t = self.params().values() + + if any(x in self.name for x in Conv.convAuxList+Conv.winoAuxList+Conv.fftAuxList+Conv.miscAuxList): + bytes, flops = [0, 0] + + elif any(x in self.name for x in Conv.convList+Conv.winoList+Conv.fftList+Conv.miscList): + if g == 1: + bytes, flops = self.conv_bytes_flops(N,C,H,W,K,P,Q,R,S,g,t) + else: + if "2d_grouped_direct_kernel" in self.name: #only 1 kernel is called + bytes, flops = self.conv_bytes_flops(N,C,H,W,K,P,Q,R,S,g,t) + elif "spatialDepthwiseConvolutionUpdateOutput" in self.name: #one kernel for separable conv + bytes, flops = self.conv_bytes_flops(N,C,H,W,K,P,Q,R,S,g,t) + else: #a kernel per group is called + bytes, flops = self.conv_bytes_flops(N,C/g,H,W,K/g,P,Q,R,S,1,t) + + elif ("calc_bias_diff" in self.name): #bias gradient + elems = N*K*P*Q + flops = elems + bytes = 2 * elems * Utility.typeToBytes(t) + #params = OrderedDict([('N',N), ('K',K), ('P',P), ('Q',Q), ('type', t)]) + + else: + bytes, flops = [0, 0] + + return bytes, flops + + def bytes(self): + b,_ = self.bytes_flops() + return b + + def flops(self): + _,f = self.bytes_flops() + return f + + def tc(self): + return 1 if "884cudnn" in self.name else "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ diff --git a/apex/pyprof/prof/convert.py b/apex/pyprof/prof/convert.py new file mode 100644 index 000000000..0d6735a81 --- /dev/null +++ b/apex/pyprof/prof/convert.py @@ -0,0 +1,62 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +class Convert(OperatorLayerBase): + """ + Class to handle convert operations. + """ + ops = ["byte", "char", "double", "float", "half", "int", "long", "short", "to"] + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "Tensor") + assert (op in Convert.ops) + assert (len(args) == 1) + + #The argument could be a tensor or scalar + t = args[0] + if t['type'] == "tensor": + shape = t['shape'] + stype = t['dtype'] + else: + shape = (1,) + stype = t['type'] + if self.op_ == "to": + op = stype + + self.shape = shape + self.stype = stype + self.dtype = op + + def params(self): + p = OrderedDict([('T', self.shape), ('stype', self.stype), ('dtype', self.dtype)]) + return p + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def tc(self): + return "-" + + def elems(self): + return Utility.numElems(self.shape) + + def flops(self): + return 0 + + def bytes(self): + b = self.elems() * (Utility.typeToBytes(self.stype) + Utility.typeToBytes(self.dtype)) + return b diff --git a/apex/pyprof/prof/data.py b/apex/pyprof/prof/data.py new file mode 100644 index 000000000..dcd323dea --- /dev/null +++ b/apex/pyprof/prof/data.py @@ -0,0 +1,54 @@ +from .utility import Utility + +class Data(object): + """ + Class to store all the data for every kernel e.g. name, bytes, flops, device, stream etc. + """ + def __init__(self, kernel): + #Available from NVprof + self.tid = kernel['tid'] + self.device = kernel['device'] + self.stream = kernel['stream'] + self.grid = str(kernel['grid']).replace(" ","").replace("(","").replace(")","") + self.block = str(kernel['block']).replace(" ","").replace("(","").replace(")","") + self.name = kernel['kShortName'].replace(" ","_") + self.lName = kernel['kLongName'] + self.sil = kernel['kDuration'] #units ns + + self.index = None + + #Markers + self.argMarker = kernel['marker'] + self.modMarker = kernel['reprMarkers'] + self.seqMarker = kernel['seqMarker'] + + self.layer = kernel['layer'] + self.trace = kernel['trace'] + + self.seqId = kernel['seqId'] + self.altSeqId = kernel['altSeqId'] + + self.dir = kernel['dir'] + self.sub = kernel['subSeqId'] + + self.mod = "na" + self.op = "na" + self.params = {"na":"na"} + self.tc = "na" + self.flops = 0 + self.bytes = 0 + + def setParams(self, params): + #Remove space from params + qaz = "" + for key,value in params.items(): + if "type" not in key: + qaz += "{}={},".format(key,value) + else: + if type(value) is str: + qaz += "{},".format(Utility.typeToString(value)) + else: + qaz += "{}".format(value) + + self.params = qaz.replace(" ", "") + diff --git a/apex/pyprof/prof/dropout.py b/apex/pyprof/prof/dropout.py new file mode 100644 index 000000000..47b352bbf --- /dev/null +++ b/apex/pyprof/prof/dropout.py @@ -0,0 +1,50 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +class Dropout(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "torch.nn.functional") + assert (op == "dropout") + #assert (len(args) == 1) + + self.shape = args[0]['shape'] + self.type = args[0]['dtype'] + self.dir = d.dir + + return + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type)]) + return p + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def tc(self): + return "-" + + def elems(self): + return Utility.numElems(self.shape) + + def bytes(self): + #Ignoring the cost of writing and reading the mask + return Utility.typeToBytes(self.type) * self.elems() * 2 + + def flops(self): + # Note: This is approximate and depends on the RNG + return 5*self.elems() diff --git a/apex/pyprof/prof/embedding.py b/apex/pyprof/prof/embedding.py new file mode 100644 index 000000000..32a1c35ea --- /dev/null +++ b/apex/pyprof/prof/embedding.py @@ -0,0 +1,71 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +class Embedding(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "torch.nn.functional") + assert (op == "embedding") + + self.ishape = args[0]['shape'] + self.itype = args[0]['dtype'] + + self.eshape = args[1]['shape'] + self.etype = args[1]['dtype'] + + assert (len(self.eshape) == 2) + + self.dir = d.dir + self.sub = d.sub + return + + def params(self): + p = OrderedDict([('I', self.ishape), ('itype', self.itype), ('E', self.eshape), ('etype', self.etype)]) + return p + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def tc(self): + return "-" + + def bytes(self): + ishape = self.ishape + itype = self.itype + eshape = self.eshape + etype = self.etype + + ielems = Utility.numElems(ishape) + + b = 0 + if self.dir == "fprop": + #indices + b += ielems * Utility.typeToBytes(itype) + #read and write the embedding matrix + b += ielems * eshape[1] * 2 * Utility.typeToBytes(etype) + else: + #3 times the size of the incoming gradient + b = ielems * eshape[1] * 3 * Utility.typeToBytes(etype) + + if self.sub > 0: + b = 0 + + return b + + def flops(self): + # Note: not implemented yet + return 0 diff --git a/apex/pyprof/prof/index_slice_join_mutate.py b/apex/pyprof/prof/index_slice_join_mutate.py new file mode 100644 index 000000000..1ecbe60d7 --- /dev/null +++ b/apex/pyprof/prof/index_slice_join_mutate.py @@ -0,0 +1,419 @@ +from collections import OrderedDict +from .utility import Utility +import numpy as np +from .base import OperatorLayerBase + +class Cat(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "torch") + assert (op == "cat") + assert (len(args) >= 2) + + t = args[0]['dtype'] + shapes = [] + + for arg in args: + if arg['type'] == "tensor": + assert (arg['dtype'] == t) + shapes.append(arg['shape']) + + self.type = t + self.shapes = shapes + + def params(self): + p = OrderedDict([('T', self.shapes), ('type', self.type)]) + return p + + def flops(self): + return 0 + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def bytes(self): + b = 0 + for s in self.shapes: + b += Utility.numElems(s) + return 2 * b * Utility.typeToBytes(self.type) + +class Reshape(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "Tensor") + assert (op == "reshape") + + #Temporarily commenting three lines + #assert (len(args) == 2) + #t,s = args + #assert s['type'] == "tuple" + + t = args[0] + assert t['type'] == "tensor" + self.type = t['dtype'] + self.shape = t['shape'] + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type)]) + return p + + def flops(self): + return 0 + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def bytes(self): + return 0 + +class Gather(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "Tensor") or (mod == "torch") + assert (op == "gather") + + #Filter out the "out" parameter + args = list(filter(lambda x : x['name'] != 'out', args)) + assert (len(args) == 3) + + #Get input + if (args[0]['name'] == ""): + arg = args[0] + else: + arg = list(filter(lambda x : x['name'] == "input", args))[0] + + assert (arg['type'] == "tensor") + + self.shape = arg['shape'] + self.type = arg['dtype'] + + def params(self): + p = OrderedDict([('T', self.shape),('type', self.type)]) + return p + + def flops(self): + return 0 + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def bytes(self): + return 2 * Utility.numElems(self.shape) * Utility.typeToBytes(self.type) + +class MaskedScatter(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "Tensor") + assert (op == "masked_scatter_") + assert (len(args) == 3) + + dst, mask, src = args + assert (dst['type'] == mask['type'] == src['type'] == "tensor") + assert (mask['dtype'] == "uint8") + assert (dst['dtype'] == src['dtype']) + assert (dst['shape'] == mask['shape']) + + self.shape = dst['shape'] + self.type = dst['dtype'] + self.seqId = d.seqId + + def params(self): + p = OrderedDict([('T', self.shape),('type', self.type)]) + return p + + def flops(self): + return 0 + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def bytes(self): + elems = Utility.numElems(self.shape) + + #src and dst + b = 2 * elems * Utility.typeToBytes(self.type) + + #mask (uint8) + b += elems + + if (self.seqId > 0): + b = 0 + return b + +class Nonzero(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod in ["torch", "Tensor"]) + assert (op == "nonzero") + assert (len(args) == 1) + + arg = args[0] + self.shape = arg['shape'] + self.type = arg['dtype'] + self.seqId = d.seqId + + def params(self): + p = OrderedDict([('T', self.shape),('type', self.type)]) + return p + + def flops(self): + return 0 + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def bytes(self): + elems = Utility.numElems(self.shape) + dim = len(self.shape) + + #input tensor + b = elems * Utility.typeToBytes(self.type) + + #in the worst case, the output is a (elems x dim) tensor of type "long" + b += elems * dim * Utility.typeToBytes("int64") + + if self.seqId > 0: + return 0 + else: + return b + +class IndexSelect(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "Tensor") or (mod == "torch") + assert (op == "index_select") + + #Filter out the "out" parameter + args = list(filter(lambda x : x['name'] != 'out', args)) + assert (len(args) == 3) + + #Get input, dim and index + if (args[0]['name'] == ""): + t = args[0] + else: + t = list(filter(lambda x : x['name'] == "input", args))[0] + + if (args[1]['name'] == ""): + d = args[1] + else: + d = list(filter(lambda x : x['name'] == "dim", args))[0] + + if (args[2]['name'] == ""): + i = args[2] + else: + i = list(filter(lambda x : x['name'] == "index", args))[0] + + assert (t['type'] == i['type'] == "tensor") + assert (d['type'] == "int") + assert (i['dtype'] == "int64") + assert (len(i['shape']) == 1) + + shape = t['shape'] + dim = d['value'] + indices = i['shape'][0] + assert (dim < len(shape)) + + self.shape = shape + self.dim = dim + self.indices = indices + self.type = t['dtype'] + + def params(self): + p = OrderedDict([('T', self.shape),('D', self.dim),('I', self.indices),('type', self.type)]) + return p + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def flops(self): + return 0 + + def bytes(self): + #determine the shape of the output tensor + shape = list(self.shape) + shape[self.dim] = self.indices + + b = 0 + + #time to read the input and write the output + elems = Utility.numElems(shape) + b += 2 * elems * Utility.typeToBytes(self.type) + + #time to read the indices + b += self.indices * Utility.typeToBytes("int64") + + return b + +class MaskedSelect(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + self.sub = d.sub + + assert (mod == "Tensor") or (mod == "torch") + assert (op == "masked_select") + + #Filter out the "out" parameter + args = list(filter(lambda x : x['name'] != 'out', args)) + assert (len(args) == 2) + + #Get input and mask + if (args[0]['name'] == ""): + t = args[0] + else: + t = list(filter(lambda x : x['name'] == "input", args))[0] + + if (args[1]['name'] == ""): + m = args[1] + else: + m = list(filter(lambda x : x['name'] == "mask", args))[0] + + assert (m['dtype'] == "uint8") + + tensor = t['shape'] + mask = m['shape'] + + #check for broadcast condition + if (tensor != mask): + array1 = np.empty(list(tensor)) + array2 = np.empty(list(mask)) + try: + out = np.broadcast(array1, array2).shape + except: + assert False + + self.tshape = tensor + self.mshape = mask + self.type = t['dtype'] + + def params(self): + p = OrderedDict([('T', self.tshape),('M', self.mshape),('type', self.type)]) + return p + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def bytes(self): + tensor = self.tshape + mask = self.mshape + t = self.type + + #in the worst case, #output elements = #input elements + b = 2 * Utility.numElems(tensor) * Utility.typeToBytes(t) + + #mask tensor (assuming uint8) + b += Utility.numElems(mask) + return b + + def flops(self): + return 0 diff --git a/apex/pyprof/prof/linear.py b/apex/pyprof/prof/linear.py new file mode 100644 index 000000000..7c9e13cbe --- /dev/null +++ b/apex/pyprof/prof/linear.py @@ -0,0 +1,188 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +class Linear(OperatorLayerBase): + + ''' + Notes: + If the bias occurs before the GEMM, then its 1 write (bias expansion). + If the bias occurs after, then its 1 read and 1 write. + bias in bprop is a reduction and hence is 1 read. + ''' + + gemmKernels = ["gemm", "gemv", "dot_kernel", "splitKreduce_kernel", "reduce_1Block_kernel"] + biasKernels = ["kernelReduceContigDim", "kernelReduceNoncontigDim_shared", "elementwise_kernel", "reduce_kernel"] + + def setXWBMNK(self, args): + x = None + w = None + b = None + if (len(args) == 2): + x,w = args + elif (len(args) == 3): + x,w,b = args + assert (x['type'] == w['type'] == "tensor") + if (b['type'] == "tensor"): + assert(len(b['shape']) == 1) + elif (b['type'] == "NoneType"): + assert b['value'] is None + b = None + else: + assert False + else: + assert False + + assert(len(w['shape']) == 2) + k1 = x['shape'][-1] + n,k2 = w['shape'] + assert(k1 == k2) + if b is not None: + assert(b['shape'][0] == n) + t1 = x['dtype'] + t2 = w['dtype'] + assert(t1 == t2) + + # X, W, B + self.x = x['shape'] + self.w = w['shape'] + self.b = b['shape'] if b is not None else None + self.type = t1 + + # M, N, K + #n = Utility.numElems(x[0:-1]) + n = self.x[0:-1] + k = self.x[-1] + m,k1 = self.w + assert (k == k1) + + self.m = m + self.n = n + self.k = k + + def tc(self): + if self.op() == "linear": + return 1 if "884gemm" in self.name else 0 + else: + return "-" + + def __init__(self, d): + self.name = d.name + self.dir = d.dir + self.sub = d.sub + + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + assert (mod == "torch.nn.functional") + assert (op == "linear") + + self.setXWBMNK(args) + + if any(x in d.name for x in Linear.gemmKernels): + self.op_ = "linear" + else: + assert (d.name in Linear.biasKernels) + self.op_ = "bias" + + ''' + elif (("kernelPointwiseApply2" in d.name) or ("kernelReduceContigDim" in d.name) or ("kernelReduceNoncontigDim_shared" in d.name)): + #bias expansion was before the gemm + self.op_ = "bias" + + elif ("elementwise_kernel" in d.name): + #Bias addition happens later with a broadcast tensor + self.op_ = "bias" + assert (len(d.argMarker) == 2) + marker = eval(d.argMarker[1]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + assert (mod == "Tensor") + assert (op == "__iadd__") + assert (len(args) == 2) + mn = args[0]['shape'] + b = args[1]['shape'] + assert (len(b) == 1) + + assert (mn == (self.n + (self.m,))) + assert (b == self.b) + + else: + assert False + ''' + + def params(self): + #p = OrderedDict([('X', self.x), ('W', self.w), ('B', self.b), ('type', self.type)]) + + m, n, k, x, w, t = self.m, self.n, self.k, self.x, self.w, self.type + if len(n) == 1: + n = n[0] + + if self.op_ == "linear": + if self.dir == "fprop": + p = OrderedDict([('M', m), ('N', n), ('K', k), ('type', t)]) + elif self.dir == "bprop": + if self.sub == 0: #dgrad (most likely) + p = OrderedDict([('M', k), ('N', n), ('K', m), ('type', t)]) + elif self.sub == 1: #wgrad (most likely) + p = OrderedDict([('M', k), ('N', m), ('K', n), ('type', t)]) + else: + #This happens when there are additional kernels for reduction + p = OrderedDict([('X', x), ('W', w), ('type', t)]) + else: + assert False + + elif self.op_ == "bias": + p = OrderedDict([('M', m), ('N', n), ('type', t)]) + else: + assert False + return p + + def op(self): + return self.op_ + + def bytesFlops(self): + + m = self.m + n = Utility.numElems(self.n) + k = self.k + + if self.op_ == "linear": + if self.dir == "fprop": + f = m * n * k * 2 + b = m*n + m*k + n*k * Utility.typeToBytes(self.type) + elif self.dir == "bprop": + if self.sub == 0: #dgrad (most likely) + f = m * n * k * 2 + b = m*n + m*k + n*k * Utility.typeToBytes(self.type) + elif self.sub == 1: #wgrad (most likely) + f = m * n * k * 2 + b = m*n + m*k + n*k * Utility.typeToBytes(self.type) + else: + #This happens when there are additional kernels for reduction + f = 0 + b = 0 + else: + assert False + + elif self.op_ == "bias": + f = m * n + b = 2 * m * n * Utility.typeToBytes(self.type) + else: + assert False + return b,f + + def bytes(self): + b, f = self.bytesFlops() + return b + + def flops(self): + b, f = self.bytesFlops() + return f + + def mod(self): + return self.mod_ diff --git a/apex/pyprof/prof/loss.py b/apex/pyprof/prof/loss.py new file mode 100644 index 000000000..3fe8a5501 --- /dev/null +++ b/apex/pyprof/prof/loss.py @@ -0,0 +1,84 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +#TODO: Add support for additional loss functions. + +class MSELoss(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "torch.nn.functional") + assert (op == "mse_loss") + assert (len(args) == 3) + + #Get input, target and reduction + if (args[0]['name'] == ""): + x = args[0] + else: + x = list(filter(lambda x : x['name'] == "input", args))[0] + + if (args[1]['name'] == ""): + y = args[1] + else: + y = list(filter(lambda x : x['name'] == "target", args))[0] + + if (args[2]['name'] == ""): + r = args[2] + else: + r = list(filter(lambda x : x['name'] == "reduction", args))[0] + + assert (x['type'] == y['type'] == "tensor") + assert (x['shape'] == y['shape']) + assert (x['dtype'] == y['dtype']) + assert (r['type'] == "str") + assert (r['value'] in ["none", "mean", "sum"]) + + self.shape = x['shape'] + self.type = x['dtype'] + self.red = r['value'] + self.dir = d.dir + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type), ('red', self.red)]) + return p + + def elems(self): + red = self.red + e = Utility.numElems(self.shape) + + if self.dir == "fprop": + if red == "none": + e *= 3 + else: + e *= 2 + else: + if red == "none": + e *= 4 + else: + e *= 3 + return e + + def bytes(self): + return self.elems() * Utility.typeToBytes(self.type) + + def flops(self): + return self.elems() * 2 + 1 + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ diff --git a/apex/pyprof/prof/misc.py b/apex/pyprof/prof/misc.py new file mode 100644 index 000000000..e1c247d46 --- /dev/null +++ b/apex/pyprof/prof/misc.py @@ -0,0 +1,219 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +class Foo(OperatorLayerBase): + """ + An object of Foo is instantiated when we detect an unsupported operator. + """ + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + shapes = [] + types = [] + + for arg in args: + if arg['type'] == "tensor": + shapes.append(arg['shape']) + types.append(arg['dtype']) + + self.shape = shapes + self.type = types + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type)]) + return p + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def flops(self): + return 0 + + def bytes(self): + return 0 + +class Copy(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "Tensor") + assert (op == "copy_") + assert (len(args) == 2) + + dst, src = args + assert (src['type'] == dst['type']) + assert (src['shape'] == dst['shape']) + + self.shape = src['shape'] + self.stype = src['dtype'] + self.dtype = dst['dtype'] + + def params(self): + #The data type might be different + p = OrderedDict([('T', self.shape), ('stype', self.stype), ('dtype', self.dtype)]) + return p + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def flops(self): + return 0 + + def elems(self): + return Utility.numElems(self.shape) + + def bytes(self): + return self.elems() * (Utility.typeToBytes(self.stype) + Utility.typeToBytes(self.dtype)) + +class Clone(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "Tensor") + assert (op == "clone") + assert (len(args) == 1) + t = args[0] + self.shape = t['shape'] + self.type = t['dtype'] + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type)]) + return p + + def flops(self): + return 0 + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def elems(self): + return Utility.numElems(self.shape) + + def bytes(self): + return 2 * self.elems() * Utility.typeToBytes(self.type) + +class Contiguous(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "Tensor") + assert (op == "contiguous") + assert (len(args) == 1) + t = args[0] + self.shape = t['shape'] + self.type = t['dtype'] + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type)]) + return p + + def flops(self): + return 0 + + def bytes(self): + return 2 * Utility.numElems(self.shape) * Utility.typeToBytes(self.type) + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + +class Any(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "Tensor") + assert (op == "any") + assert (len(args) == 1) #could be 2 as well, the second argument is a bool + t = args[0] + + self.shape = t['shape'] + self.type = t['dtype'] + self.sub = d.sub + return + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type)]) + return p + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def tc(self): + return "-" + + def flops(self): + return 0 + + def bytes(self): + return Utility.numElems(self.shape) * Utility.typeToBytes(self.type) diff --git a/apex/pyprof/prof/normalization.py b/apex/pyprof/prof/normalization.py new file mode 100644 index 000000000..c9c5ae0b1 --- /dev/null +++ b/apex/pyprof/prof/normalization.py @@ -0,0 +1,54 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +class BatchNorm(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (op == "batch_norm") + assert (len(args) == 8) + i = args[0] + assert (i['type'] == "tensor") + + self.shape = i['shape'] + self.type = i['dtype'] + self.dir = d.dir + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type)]) + return p + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def elems(self): + return Utility.numElems(self.shape) + + def flops(self): + # Variance algo-dependent, but this is a reasonable value. + return self.elems() * 8 + + def bytes(self): + e = self.elems() + if self.dir == "fprop": + e *= 4 + else: + e *= 5 + + return e * Utility.typeToBytes(self.type) diff --git a/apex/pyprof/prof/optim.py b/apex/pyprof/prof/optim.py new file mode 100644 index 000000000..f2c32759b --- /dev/null +++ b/apex/pyprof/prof/optim.py @@ -0,0 +1,65 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +#TODO: Add support for other optimizers. + +class Adam(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert(op == "adam") + assert (len(args) == 12) or (len(args) == 14) + w, hw, m, v, g = args[0:5] + assert (w['shape'] == m['shape'] == v['shape'] == g['shape']) + assert (hw['shape'] == w['shape']) or (hw['shape'] == (0,)) #hw could be null + assert (w['type'] == m['type'] == v['type'] == g['type'] == hw['type'] == "tensor") + assert (w['dtype'] == m['dtype'] == v['dtype'] == "float32") + + self.w = w + self.g = g + + def params(self): + p = OrderedDict([('T',self.w['shape']), ('wtype',self.w['dtype']), ('gtype',self.g['dtype'])]) + return p + + def flops(self): + return 0 + + def bytes(self): + wshape = self.w['shape'] + wtype = self.w['dtype'] + gtype = self.g['dtype'] + b = 0 + + elems = Utility.numElems(wshape) + + #Get time to stream read/write w, m, v + b += 6 * elems * Utility.typeToBytes(wtype) + + #Get time to read "g" + b += elems * Utility.typeToBytes(gtype) + + if wtype != gtype: #mixed precision + #Get time to write "hw + b += elems * Utility.typeToBytes(gtype) + + return b + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ diff --git a/apex/pyprof/prof/output.py b/apex/pyprof/prof/output.py new file mode 100644 index 000000000..832514cf2 --- /dev/null +++ b/apex/pyprof/prof/output.py @@ -0,0 +1,149 @@ +import errno, os, sys + +class Output(): + """ + This class handles printing of a columed output and a CSV. + """ + + # The table below is organized as + # user_option: [output_header, attribute_in_Data_class, type, min_width_in_columed_output] + table = { + "idx": ["Idx", "index", int, 7], + "seq": ["SeqId", "seqId", str, 7], + "altseq": ["AltSeqId", "altSeqId", str, 7], + "tid": ["TId", "tid", int, 12], + "layer": ["Layer", "layer", str, 10], + "trace": ["Trace", "trace", str, 25], + "dir": ["Direction", "dir", str, 5], + "sub": ["Sub", "sub", int, 3], + "mod": ["Module", "mod", str, 15], + "op": ["Op", "op", str, 15], + "kernel": ["Kernel", "name", str, 0], + "params": ["Params", "params", str, 0], + "sil": ["Sil(ns)", "sil", int, 10], + "tc": ["TC", "tc", str, 2], + "device": ["Device", "device", int, 3], + "stream": ["Stream", "stream", int, 3], + "grid": ["Grid", "grid", str, 12], + "block": ["Block", "block", str, 12], + "flops": ["FLOPs", "flops", int, 12], + "bytes": ["Bytes", "bytes", int, 12] + } + + def __init__(self, args): + self.cols = args.c + self.csv = args.csv + self.col = True if (args.w > 0) else False + self.width = args.w + + w = 0 + for col in self.cols: + assert col in Output.table.keys() + w += Output.table[col][3] + + if ((self.col) and (w > self.width)): + print("Minimum width required to print {} = {}. Exiting.".format(",".join(self.cols), w)) + sys.exit(1) + + remainder = self.width - w + + if ("kernel" in self.cols) and ("params" in self.cols): + Output.table["kernel"][3] = int(remainder/2) + Output.table["params"][3] = int(remainder/2) + elif ("kernel" in self.cols): + Output.table["kernel"][3] = remainder + elif ("params" in self.cols): + Output.table["params"][3] = remainder + + #header format + cadena = "" + for col in self.cols: + _,_,t,w = Output.table[col] + cadena += "%-{}.{}s ".format(w,w) + + self.hFormat = cadena + + #data format + cadena = "" + for col in self.cols: + _,_,t,w = Output.table[col] + if (t == str): + cadena += "%-{}.{}s ".format(w,w) + elif (t == int): + cadena += "%{}d ".format(w) + + self.dFormat = cadena + + def foo(self, cadena, pformat): + if self.csv: + cadena = ",".join(map(lambda x : '"' + str(x) + '"', cadena)) + elif self.col: + cadena = pformat % cadena + else: + cadena = " ".join(map(str,cadena)) + + try: + print(cadena) + except IOError as e: + #gracefully handle pipes + if e.errno == errno.EPIPE: + # Python flushes standard streams on exit; redirect remaining output + # to devnull to avoid another BrokenPipeError at shutdown + + devnull = os.open(os.devnull, os.O_WRONLY) + os.dup2(devnull, sys.stdout.fileno()) + sys.exit(0) + else: + sys.exit(-1) + + def header(self): + cadena = () + for col in self.cols: + h = Output.table[col][0] + cadena = cadena + (h,) + + self.foo(cadena, self.hFormat) + + def data(self, a): + if a.dir == "": + direc = "na" + else: + direc = a.dir + + if a.op == "": + op = "na" + else: + op = a.op + + if a.mod == "": + mod = "na" + else: + mod = a.mod + + cadena = () + for col in self.cols: + attr = Output.table[col][1] + val = getattr(a, attr) + + if col == "layer": + assert(type(val) == list) + val = ":".join(val) + val = "-" if val == "" else val + + if col == "trace": + assert(type(val) == list) + if self.col and len(val): + val = val[-1] + val = val.split("/")[-1] + else: + val = ",".join(val) + val = "-" if val == "" else val + + if col in ["seq", "altseq"]: + assert(type(val) == list) + val = ",".join(map(str,val)) + val = "-" if val == "" else val + + cadena = cadena + (val,) + + self.foo(cadena, self.dFormat) diff --git a/apex/pyprof/prof/pointwise.py b/apex/pyprof/prof/pointwise.py new file mode 100644 index 000000000..3c9afc52a --- /dev/null +++ b/apex/pyprof/prof/pointwise.py @@ -0,0 +1,166 @@ +import numpy as np +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +class Pointwise(OperatorLayerBase): + + ops = [] + ops += ["__abs__", "__neg__", "__invert__"] + ops += ["__add__", "__sub__", "__mul__", "__floordiv__", "__truediv__", "__pow__", "__mod__"] + ops += ["__radd__", "__rsub__", "__rmul__", "__rdiv__", "__rtruediv__", "__rfloordiv__", "__rpow__"] + ops += ["__iadd__", "__isub__", "__imul__", "__itruediv__",] + ops += ["__lt__", "__gt__", "__ge__", "__le__", "__eq__", "__ne__",] + ops += ["lt", "lt_", "gt", "gt_", "ge", "ge_", "le", "le_", "eq", "eq_", "ne", "ne_",] + ops += ["__and__", "__or__", "__xor__", "__lshift__", "__rshift__"] + ops += ["__iand__", "__ior__", "__ixor__", "__ilshift__", "__irshift__"] + ops += ["abs", "abs_", "neg", "neg_"] + ops += ["add", "add_", "div", "div_", "mul", "mul_", "reciprocal", "reciprocal_", "remainder", "remainder_", "sub", "sub_",] + ops += ["addcdiv", "addcdiv_", "addcmul", "addcmul_"] + ops += ["exp", "exp_", "exp1m", "exp1m_", "log", "log_", "log10", "log10_", "log1p", "log1p_", "log2", "log2_", "pow", "pow_", "rsqrt", "rsqrt_", "sqrt", "sqrt_",] + ops += ["ceil", "ceil_", "clamp", "clamp_", "floor", "floor_", "fmod", "fmod_", "frac", "frac_", "round", "round_", "sign", "sign_", "trunc", "trunc_"] + ops += ["acos", "acos_", "asin", "asin_", "atan", "atan_", "atan2", "atan2_", "cos", "cos_", "cosh", "cosh_", "sin", "sin_", "sinh", "sinh_", "tan", "tan_", "sigmoid", "sigmoid_", "tanh", "tanh_"] + ops += ["digamma", "erf", "erf_", "erfc", "erfc_", "erfinv", "erfinv_", "lerp", "lerp_", "mvlgamma",] + + @staticmethod + def foo(d): + return d['name'],d['type'],d['shape'],d['dtype'] + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + self.dir = d.dir + assert (d.dir in ["fprop", "bprop"]) + assert (op in Pointwise.ops) + + #Filter out all named parameters (kwargs). + #This might require revisiting in future. + args = list(filter(lambda x : x['name'] == "", args)) + + #Filter out non tensors + args = list(filter(lambda x : x['type'] == "tensor", args)) + + if (len(args) == 0): + self.shape = [(1,)] + self.type = "float32" #FIX + + elif (len(args) == 1): + in0 = args[0] + _,t0,s0,dt0 = Pointwise.foo(in0) + assert (t0 == "tensor") + self.shape = [s0,] + self.type = dt0 + + elif (len(args) == 2): + in0,in1 = args + _,t0,s0,dt0 = Pointwise.foo(in0) + _,t1,s1,dt1 = Pointwise.foo(in1) + assert (t0 == t1 == "tensor") + assert (dt0 == dt1) + self.shape = [s0,s1] + self.type = dt0 + + elif (len(args) == 3): + in0,in1,in2 = args + _,t0,s0,dt0 = Pointwise.foo(in0) + _,t1,s1,dt1 = Pointwise.foo(in1) + _,t2,s2,dt2 = Pointwise.foo(in2) + assert (t0 == t1 == t2 == "tensor") + assert (dt0 == dt1 == dt2) + self.shape = [s0,s1,s2] + self.type = dt0 + else: + assert False + return + + def params(self): + p = OrderedDict([('T',self.shape), ('type', self.type)]) + return p + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def elems(self): + tensor = self.shape + t = self.type + + if (len(tensor) == 1): + elems = 2 * Utility.numElems(tensor[0]) + elif (len(tensor) == 2): + if (tensor[0] == tensor[1]): # same shape + elems = Utility.numElems(tensor[0]) + if self.dir == "fprop": + elems *= 3 + else: + if (self.op_ in ["add", "__add__", "sub", "__sub__", "__isub__"]): + elems *= 2 + elif (self.op_ in ["__mul__", "__rmul__", "div", "__truediv__"]): + elems *= 3 + else: + assert False + else: #check for broadcast conditions + array1 = np.empty(list(tensor[0])) + array2 = np.empty(list(tensor[1])) + try: + out = np.broadcast(array1, array2).shape + except: + assert False + + elems = Utility.numElems(tensor[0]) + elems += Utility.numElems(tensor[1]) + elems += Utility.numElems(out) + #TODO bprop + elif (len(tensor) == 3): + if (tensor[0] == tensor[1] == tensor[2]): #same shape + elems = Utility.numElems(tensor[0]) + elems *= 4 + else: + assert False + else: + assert False + + return elems + + def bytes(self): + return self.elems() * Utility.typeToBytes(self.type) + + def flops(self): + # Note: some cases may still be missing. + + f = 0 + if self.op_ in ["__abs__", "__neg__", "__add__", "__sub__", "__mul__", + "__radd__", "__rmul__", "__iadd__", "__isub__", "__imul__", "__itruediv__", + "abs", "abs_", "neg", "neg_", "add", "add_", "div", "div_", "mul", "mul_", + "sub", "sub_", "exp", "exp_", "sign", "sign_", "trunc", "trunc_", + "sin", "sin_", "cos", "cos_", "sinh", "sinh_", "cosh", "cosh_", + "sqrt", "sqrt_", "rsqrt", "rsqrt_", "__lt__", "__gt__", "__ge__", "__le__", + "__eq__", "__ne__", "lt", "lt_", "gt", "gt_", "ge", "ge_", "le", "le_", + "eq", "eq_", "ne", "ne_", "ceil", "ceil_", "clamp", "clamp_", "floor", "floor_", + "round", "sign", "sign_", "trunc", "trunc_"]: + # We're counting only one operand, not two (2 operands, 1 op) + f = self.elems() / 2 + elif self.op_ in ["fmod", "fmod_"]: + f = self.elems() + elif self.op_ in ["tanh", "tanh_", "sigmoid", "sigmoid_", "log", "log_", "log2", + "log2_", "log10", "log10_"]: + f = self.elems() * 2 + elif self.op_ in ["asin", "asin_", "acos", "acos_", "atan", "atan_"]: + # no intrinsic, hence slow execution + # surprisingly, asin/acos and atan were all the same (via nvprof measurement) + f = self.elems() * 10 + + return f diff --git a/apex/pyprof/prof/pooling.py b/apex/pyprof/prof/pooling.py new file mode 100644 index 000000000..3f342b4d4 --- /dev/null +++ b/apex/pyprof/prof/pooling.py @@ -0,0 +1,59 @@ +from .collections import OrderedDict +from .utility import Utility + +# Work in progress. + +#poolFuncs = ["max_pool2d_with_indices_forward", "max_pool2d_with_indices"] +class MaxPool2d(object): + + def parse(marker): + + def convert2Tuple(arg): + assert (arg['type'] in ["int", "tuple"]) + if arg['type'] == "int": + return (arg['value'], arg['value']) + else: + return arg['value'] + + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + assert (mod == "torch.nn.functional") + assert (op == "max_pool2d") + assert (len(args) >= 2) + + #input + assert (args[0]['name'] == "") + inp = args[0] + assert (inp['type'] == "tensor") + i = inp['shape'] + t = inp['dtype'] + assert (len(i) == 4) #nchw tensor + + #kernel + if (args[1]['name'] == ""): + k = args[1] + else: + k = list(filter(lambda x : x['name'] == "kernel_size", args))[0] + k = convert2Tuple(k) + + #stride + s = k #default value + if ((len(args) >= 3) and args[2] == ""): + s = args[2] + s = convert2Tuple(s) + elif any(x['name'] == "stride" for x in args): + s = list(filter(lambda x : x['name'] == "stride", args))[0] + s = convert2Tuple(s) + + #padding + p = (0,0) + if ((len(args) >= 4) and args[3] == ""): + p = args[3] + p = convert2Tuple(p) + elif any(x['name'] == "padding" for x in args): + p = list(filter(lambda x : x['name'] == "padding", args))[0] + p = convert2Tuple(p) + + params = OrderedDict([('T', i), ('K', k), ('s',s), ('p',p), ('type', t)]) + return params diff --git a/apex/pyprof/prof/prof.py b/apex/pyprof/prof/prof.py new file mode 100755 index 000000000..a3467e6f3 --- /dev/null +++ b/apex/pyprof/prof/prof.py @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 + +""" +This script reads the output (Python dictionary) created by parse.py. +For every kernel (line) in the input it determines + module / class name e.g. torch.nn.functional + operator name e.g. linear + kernel parameters e.g. GEMM M, N, K, datatype + bytes + flops + tensor core usage + direction (fprop, bprop) + and other things. Please see the tool usage. +""" + +from .usage import parseArgs +from .output import Output +from .utility import Utility +from .pointwise import Pointwise +from .convert import Convert +from .blas import * +from .embedding import Embedding +from .reduction import * +from .dropout import Dropout +from .softmax import * +#from pooling import * # work in progress +from .linear import Linear +from .optim import Adam +from .misc import * +from .conv import Conv +from .activation import Activation +from .index_slice_join_mutate import Cat, Reshape, MaskedScatter, Gather, Nonzero, IndexSelect, MaskedSelect +from .recurrentCell import RNNCell +from .normalization import BatchNorm +from .randomSample import RandPerm +from .loss import MSELoss +from .data import Data + +def findFpropKernel(seq): + #Find the last fprop kernel with the same seqId + #First look at seqId and then at altSeqId + for idx in reversed(range(len(kernels))): + k = kernels[idx] + if (seq in k['seqId']) and (k['dir'] == "fprop"): + return idx + + for idx in reversed(range(len(kernels))): + k = kernels[idx] + if (seq in k['altSeqId']) and (k['dir'] == "fprop"): + return idx + + return -1 + #print("Error: seqId {} not found.".format(seq), file=sys.stderr) + #assert False + +def foo(mod, op, d): + if (op[0] == "linear"): + xx = Linear(d) + + # rnncell, lstmcell, grucell + elif (mod[0] in["LSTMCell", "GRUCell"]) and (op[0] == "forward"): + xx = RNNCell(d) + + elif op[0] in ["conv1d", "conv2d",]: + xx = Conv(d) + + elif (op[0] in Pointwise.ops): + xx = Pointwise(d) + + elif (op[0] in Convert.ops): + xx = Convert(d) + + elif op[0] in ["__matmul__", "matmul"]: + xx = Matmul(d) + + elif op[0] == "embedding": + xx = Embedding(d) + + #reduction + elif op[0] == "sum": + xx = Sum(d) + + elif op[0] == "mean": + xx = Mean(d) + + elif op[0] == "norm": + xx = Norm(d) + + elif op[0] == "dropout": + xx = Dropout(d) + + #Index, Slice, Join, Mutate + elif (op[0] == "cat"): + xx = Cat(d) + + elif (op[0] == "reshape"): + xx = Reshape(d) + + elif (op[0] == "masked_scatter_"): + xx = MaskedScatter(d) + + elif (op[0] == "gather"): + xx = Gather(d) + + elif (op[0] == "nonzero"): + xx = Nonzero(d) + + elif (op[0] == "index_select"): + xx = IndexSelect(d) + + elif (op[0] == "masked_select"): + xx = MaskedSelect(d) + + #blas + elif op[0] in ["addmm", "addmm_"]: + xx = Addmm(d) + + elif op[0] == "mm": + xx = Mm(d) + + elif op[0] == "bmm": + xx = Bmm(d) + + #softmax + elif op[0] == "softmax": + xx = Softmax(d) + + elif op[0] == "log_softmax": + xx = LogSoftmax(d) + + #loss + elif op[0] == "mse_loss": + xx = MSELoss(d) + + #optimizers + elif op[0] == "adam": + xx = Adam(d) + + #normalization + elif op[0] == "batch_norm": + xx = BatchNorm(d) + + #random + elif op[0] == "randperm": + xx = RandPerm(d) + + #misc + elif op[0] == "copy_": + xx = Copy(d) + + elif op[0] == "clone": + xx = Clone(d) + + elif op[0] == "contiguous": + xx = Contiguous(d) + + elif op[0] == "any": + xx = Any(d) + + elif (op[0] in Activation.ops): + xx = Activation(d) + + elif op[0] == "to": + xx = Convert(d) + + else: + xx = Foo(d) + + return xx + +def main(): + #Read cmd line arguments + cmdArgs = parseArgs() + + output = Output(cmdArgs) + output.header() + + idx = -1 + #Read in all the kernel info + for line in cmdArgs.file: + idx += 1 + kernel = eval(line) + assert(kernel) + kernels.append(kernel) + + k = kernel + d = Data(k) + + mod = k['mod'] + op = k['op'] + + flops = 0 + params = {"na":"na"} + tc = "na" + bytes = 0 + + if (d.dir == "bprop"): + d.seqMarker = k['seqMarker'] + seq = k['seqId'] + if len(seq) > 1: + pass + seq = k['seqId'][:1] + assert (len(seq) == 1), seq + #assert (seq[0] != 0) + assert (len(d.seqMarker) > 0) + #If there is no useful marker associated, use the + #sequence number to find the kernel from fprop + if len(d.argMarker) == 0: + index = findFpropKernel(seq[0]) + if index >= 0: + d.argMarker = kernels[index]['marker'] + d.modMarker = kernels[index]['reprMarkers'] + mod = kernels[index]['mod'] + op = kernels[index]['op'] + + d.layer = kernels[index]['layer'] + d.trace = kernels[index]['trace'] + + # Check if marker has our annotations + if len(d.argMarker) and Utility.hasNVTX(d.argMarker[0]): + + xx = foo(mod, op, d) + + bytes = xx.bytes() + flops = xx.flops() + op = xx.op() + params = xx.params() + tc = xx.tc() + + if type(op) is list: + if len(op): + op = op[0] + else: + op = "" + + if type(mod) is list: + if len(mod): + mod = mod[0] + else: + mod = "" + + d.index = idx+1 + + # The following 8 come from operator class functions. + d.setParams(params) + d.tc = tc + d.flops = flops + d.bytes = bytes + d.mod = mod + d.op = op + + output.data(d) + +kernels = [] +if __name__ == '__main__': + main() diff --git a/apex/pyprof/prof/randomSample.py b/apex/pyprof/prof/randomSample.py new file mode 100644 index 000000000..f7521bf34 --- /dev/null +++ b/apex/pyprof/prof/randomSample.py @@ -0,0 +1,43 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +class RandPerm(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "torch") + assert (op == "randperm") + assert (len(args) == 1) + n = args[0] + assert n['type'] == "int" + self.n = n['value'] + + def params(self): + p = OrderedDict([('N', self.n)]) + return p + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def bytes(self): + return self.n * Utility.typeToBytes("int64") + + def flops(self): + # Depends on RNG but this is probably a reasonable assumption. + return self.n * 3 diff --git a/apex/pyprof/prof/recurrentCell.py b/apex/pyprof/prof/recurrentCell.py new file mode 100644 index 000000000..945a7b158 --- /dev/null +++ b/apex/pyprof/prof/recurrentCell.py @@ -0,0 +1,207 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +def hasTileSize(name): + if ("sgemm" in name) or ("884gemm" in name) or ("hgemm" in name): + return True + else: + return False + +def ctaTile(name): + name = name.split("_") + name = list(filter(lambda x : "x" in x, name)) + name = list(filter(lambda x : "slice" not in x, name)) + assert(len(name) == 1) + name = name[0].split("x") + assert(len(name) == 2) + name = list(map(int, name)) + return name[0], name[1] + +class RNNCell(OperatorLayerBase): + """ + This class supports RNNCell, LSTMCell and GRUCell. + """ + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + self.name = d.name + self.dir = d.dir + self.sub = d.sub + self.grid = d.grid + + assert (op == "forward") + assert (mod in ["LSTMCell", "GRUCell", "RNNCell"]) + assert (len(args) in [2,3]) + + x,h = args[0],args[1] + b1,ii = x['shape'] + b2,hh = h['shape'] + assert b1 == b2 + assert x['dtype'] == h['dtype'] + t = x['dtype'] + + self.cell = mod + self.inp = ii + self.hid = hh + self.b = b1 + self.type = t + + self.multiple = 1 + if self.cell == "LSTMCell": + self.multiple = 4 + elif self.cell == "GRUCell": + self.multiple = 3 + + self.gemm = None + self.m = None + self.n = None + self.k = None + self.elems = 0 + + self.bar() + + def params(self): + if self.gemm is None: + p = OrderedDict([('cell', self.cell), ('X', self.inp), ('H', self.hid), ('B', self.b), ('type', self.type)]) + else: + assert self.m is not None + assert self.n is not None + assert self.k is not None + p = OrderedDict([('gemm', self.gemm), ('M', self.m), ('N', self.n), ('K', self.k), ('type', self.type)]) + return p + + def tc(self): + if "gemm" in self.name: + return 1 if "884gemm" in self.name else 0 + else: + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def bytes(self): + if self.gemm is not None: + m, n, k, t = self.m, self.n, self.k, self.type + b = (m*k + k*n + m*n) * Utility.typeToBytes(t) + elif self.elems != 0: + b = self.elems * Utility.typeToBytes(self.type) + else: + b = 0 + return b + + def flops(self): + if self.gemm is not None: + m, n, k = self.m, self.n, self.k + f = 2*m*n*k + elif self.elems != 0: + f = 0 #TODO + else: + f = 0 + return f + + def bar(self): + cell = self.cell + X = self.inp + H = self.hid + B = self.b + t = self.type + subseqId = self.sub + direc = self.dir + name = self.name + grid = self.grid + multiple = self.multiple + + if direc == "fprop": + subseqId = subseqId % 3 + if subseqId == 0: #layer gemm + self.gemm = "layer" + self.m = multiple*H + self.n = B + self.k = X + elif subseqId == 1: #recurrent gemm + self.gemm = "recur" + self.m = multiple*H + self.n = B + self.k = H + else: + layerGemmElems = multiple*H*B + recurGemmElems = multiple*H*B + cElems = H*B + hElems = H*B + totElems = layerGemmElems + recurGemmElems + 2*cElems + hElems + self.elems = totElems + + else: + if ("gemm" in name) and hasTileSize(name): #gemm + #Get cta tile size + tileX, tileY = ctaTile(name) + #Get grid dimensions + grid = grid.split(",") + gridX,gridY,gridZ = map(lambda x : int(x), grid) + + gemmM = tileX * gridX + gemmN = tileY * gridY + + if name[-3:] == "_nn": # dgrad + if (gemmM == H): # recurrent dgrad + #Ideally gemmN = B, but we have a limited set of tile sizes. + gemmN = B + gemmK = multiple*H + + self.gemm = "recur" + self.m = gemmM + self.n = gemmN + self.k = gemmK + + elif (gemmM == X): # layer dgrad + #assert(gemmN % B == 0) + gemmK = multiple*H + + self.gemm = "layer" + self.m = gemmM + self.n = gemmN + self.k = gemmK + + else: + pass + + elif name[-3:] == "_nt": #wgrad + if (gemmM == H): #recurrent wgrad + assert (gemmN == multiple*H) + gemmK = B + + self.gemm = "recur" + self.m = gemmM + self.n = gemmN + self.k = gemmK + + elif (gemmM == X): #layer wgrad + assert (gemmN == multiple*H) + gemmK = B + + self.gemm = "layer" + self.m = gemmM + self.n = gemmN + self.k = gemmK + + else: + pass + else: + pass + else: + pass + + return diff --git a/apex/pyprof/prof/reduction.py b/apex/pyprof/prof/reduction.py new file mode 100644 index 000000000..af2703523 --- /dev/null +++ b/apex/pyprof/prof/reduction.py @@ -0,0 +1,150 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +class Mean(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod in ["torch", "Tensor"]) + assert (op == "mean") + + #Filter out named parameters + args = list(filter(lambda x : x['name'] == '', args)) + + assert (len(args) <= 2) + i = args[0] + + self.shape = i['shape'] + self.type = i['dtype'] + self.dir = d.dir + self.sub = d.sub + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type)]) + return p + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def elems(self): + return Utility.numElems(self.shape) + + def bytes(self): + if self.sub == 0: + return self.elems() * Utility.typeToBytes(self.type) + else: + return 0 + + def flops(self): + if self.sub == 0: + return self.elems() + 1 + else: + return 0 + +class Sum(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod in ["torch", "Tensor"]) + assert (op == "sum") + assert (len(args) >= 1) + + #Get input + if (args[0]['name'] == ""): + i = args[0] + else: + i = list(filter(lambda x : x['name'] == "input", args))[0] + + self.shape = i['shape'] + self.type = i['dtype'] + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type)]) + return p + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def elems(self): + return Utility.numElems(self.shape) + + def flops(self): + # Note: This is incorrect, need to calculate actual flops (say via nvprof) + return self.elems() + + def bytes(self): + return self.elems() * Utility.typeToBytes(self.type) + +class Norm(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod in ["torch", "Tensor"]) + assert (op == "norm") + #assert (len(args) == 1) + i = args[0] + self.shape = i['shape'] + self.type = i['dtype'] + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type)]) + return p + + def elems(self): + return Utility.numElems(self.shape) + + def bytes(self): + return self.elems() * Utility.typeToBytes(self.type) + + def flops(self): + # square and add plus sqrt + return 2 * self.elems() + 1 + + def tc(self): + return "-" + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ diff --git a/apex/pyprof/prof/softmax.py b/apex/pyprof/prof/softmax.py new file mode 100644 index 000000000..4271a8d94 --- /dev/null +++ b/apex/pyprof/prof/softmax.py @@ -0,0 +1,115 @@ +from collections import OrderedDict +from .utility import Utility +from .base import OperatorLayerBase + +class Softmax(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "torch.nn.functional") + assert (op == "softmax") + + #Filter out named parameters + args = list(filter(lambda x : x['name'] == '', args)) + + assert (len(args) <= 2) + self.shape = args[0]['shape'] + self.type = args[0]['dtype'] + self.dir = d.dir + + return + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def tc(self): + return "-" + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type)]) + return p + + def elems(self): + return Utility.numElems(self.shape) + + def flops(self): + # Note: exp, sum-reduce, divide + #flops = elems * 3 + return 0 + + def bytes(self): + b = self.elems() * Utility.typeToBytes(self.type) + b *= 3 if self.dir == "fprop" else 5 #verify + return b + +class LogSoftmax(OperatorLayerBase): + + def __init__(self, d): + marker = eval(d.argMarker[0]) + mod = marker['mod'] + op = marker['op'] + args = marker['args'] + + self.marker = marker + self.mod_ = mod + self.op_ = op + self.args = args + + assert (mod == "torch.nn.functional") + assert (op == "log_softmax") + + #Filter out named parameters + args = list(filter(lambda x : x['name'] == '', args)) + + assert (len(args) <= 2) + + #Get input + if (args[0]['name'] == ""): + i = args[0] + else: + i = list(filter(lambda x : x['name'] == "input", args))[0] + + t = i['dtype'] + + self.shape = i['shape'] + self.type = i['dtype'] + self.dir = d.dir + return + + def op(self): + return self.op_ + + def mod(self): + return self.mod_ + + def tc(self): + return "-" + + def params(self): + p = OrderedDict([('T', self.shape), ('type', self.type)]) + return p + + def elems(self): + return Utility.numElems(self.shape) + + def flops(self): + # Note: exp, sum-reduce, divide, log + #flops = elems * 4 + return 0 + + def bytes(self): + b = self.elems() * Utility.typeToBytes(self.type) + b *= 3 if self.dir == "fprop" else 5 #verify + return b diff --git a/apex/pyprof/prof/usage.py b/apex/pyprof/prof/usage.py new file mode 100644 index 000000000..3a299d565 --- /dev/null +++ b/apex/pyprof/prof/usage.py @@ -0,0 +1,73 @@ +import sys +import argparse + +def parseArgs(): + """ + Print usage and parse arguments. + """ + + def check_cols(value): + valid = ["idx", "seq", "altseq", "tid", "layer", "trace", "dir", "sub", "mod", "op", "kernel", "params", "sil", "tc", "device", "stream", "grid", "block", "flops", "bytes"] + cols = value.split(",") + for col in cols: + if col not in valid: + raise argparse.ArgumentTypeError("{} is not a valid column name. Valid column names are {}.".format(col, ",".join(valid))) + return cols + + def openFile(f): + try: + d = open(f, "r") + return d + except IOError: + print("Error opening file {}. Exiting.".format(f), file=sys.stderr) + sys.exit(1) + + parser = argparse.ArgumentParser(prog=sys.argv[0], description="PyTorch Profiler", formatter_class=argparse.RawTextHelpFormatter) + parser.add_argument("file", + nargs='?', + type=str, + default=None, + help="Output of parse.py (Python dictionary).") + + parser.add_argument("-c", + type=check_cols, + default="idx,dir,sub,mod,op,kernel,params,sil", + help='''Comma seperated names of columns to print. +idx: Index +seq: PyTorch Sequence Id +altseq: PyTorch Alternate Sequence Id +tid: Thread Id +layer: User annotated NVTX string (can be nested) +trace: Function Call Trace +dir: Direction +sub: Sub Sequence Id +mod: Module +op: Operattion +kernel: Kernel Name +params: Parameters +sil: Silicon Time (in ns) +tc: Tensor Core Usage +device: GPU Device Id +stream: Stream Id +grid: Grid Dimensions +block: Block Dimensions +flops: Floating point ops (FMA = 2 FLOPs) +bytes: Number of bytes in and out of DRAM +e.g. -c idx,kernel,sil''') + + group = parser.add_mutually_exclusive_group() + group.add_argument("--csv", + action="store_true", + default=False, + help="Print a CSV output.") + group.add_argument("-w", + type=int, + default=0, + help="Width of columnated output.") + + args = parser.parse_args() + if args.file is None: + args.file = sys.stdin + else: + args.file = openFile(args.file) + return args diff --git a/apex/pyprof/prof/utility.py b/apex/pyprof/prof/utility.py new file mode 100644 index 000000000..450a41961 --- /dev/null +++ b/apex/pyprof/prof/utility.py @@ -0,0 +1,58 @@ +from functools import reduce + +class Utility(object): + + @staticmethod + def numElems(shape): + assert (type(shape) == tuple) + return reduce(lambda x,y: x*y, shape, 1) + + @staticmethod + def typeToBytes(t): + if (t in ["uint8", "int8", "byte", "char"]): + return 1 + elif (t in ["float16", "half", "int16", "short"]): + return 2 + elif (t in ["float32", "float", "int32", "int"]): + return 4 + elif (t in ["int64", "long", "float64", "double"]): + return 8 + assert False + + @staticmethod + def typeToString(t): + if (t in ["uint8", "byte", "char"]): + return "uint8" + elif (t in ["int8",]): + return "int8" + elif (t in ["int16", "short",]): + return "int16" + elif (t in ["float16", "half"]): + return "fp16" + elif (t in ["float32", "float"]): + return "fp32" + elif (t in ["int32", "int",]): + return "int32" + elif (t in ["int64", "long"]): + return "int64" + elif (t in ["float64", "double",]): + return "fp64" + assert False + + @staticmethod + def hasNVTX(marker): + if type(marker) is str: + try: + marker = eval(marker) + except: + return False + + if type(marker) is dict: + keys = marker.keys() + return ("mod" in keys) and ("op" in keys) and ("args" in keys) + else: + return False + + @staticmethod + def isscalar(t): + return (t in ["float", "int"]) diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 000000000..f417ef5c5 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +cxxfilt>=0.2.0 +tqdm>=4.28.1 +numpy>=1.15.3 +PyYAML>=5.1 +pytest>=3.5.1 diff --git a/setup.py b/setup.py index 3deb46d8e..529153a9f 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,9 @@ from setuptools import setup, find_packages import subprocess +from pip._internal import main as pipmain import sys +import warnings if not torch.cuda.is_available(): print("\nWarning: Torch did not find available GPUs on this system.\n", @@ -19,6 +21,17 @@ cmdclass = {} ext_modules = [] +if "--pyprof" in sys.argv: + with open('requirements.txt') as f: + required_packages = f.read().splitlines() + pipmain(["install"] + required_packages) + try: + sys.argv.remove("--pyprof") + except: + pass +else: + warnings.warn("Option --pyprof not specified. Not installing PyProf dependencies!") + if "--cpp_ext" in sys.argv or "--cuda_ext" in sys.argv: if TORCH_MAJOR == 0: raise RuntimeError("--cpp_ext requires Pytorch 1.0 or later, " diff --git a/tests/L0/run_pyprof_nvtx/__init__.py b/tests/L0/run_pyprof_nvtx/__init__.py new file mode 100644 index 000000000..7cd4c1062 --- /dev/null +++ b/tests/L0/run_pyprof_nvtx/__init__.py @@ -0,0 +1 @@ +import test_pyprof_nvtx.TestPyProfNvtx as TestPyProfNvtx diff --git a/tests/L0/run_pyprof_nvtx/test_pyprof_nvtx.py b/tests/L0/run_pyprof_nvtx/test_pyprof_nvtx.py new file mode 100644 index 000000000..6f2c8d1e2 --- /dev/null +++ b/tests/L0/run_pyprof_nvtx/test_pyprof_nvtx.py @@ -0,0 +1,526 @@ +import inspect +import os +import torch +import torch.nn.functional as F +import unittest + +from apex import pyprof +pyprof.nvtx.init() + +# TODO: add tests for: +# F.bilinear, F.l1_loss, F.multilabel_soft_margin_loss, F.multi_margin_loss + +class TestPyProfNvtx(unittest.TestCase): + + def __init__(self, testName, dtype=torch.float16): + super().__init__(testName) + self.dtype = dtype + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_conv1d(self): + # Data and weight tensors + tensor1d_in_conv = torch.randn(32, 3, 224, device='cuda', dtype=self.dtype) + tensor1d_in_conv_grouped = torch.randn(32, 6, 224, device='cuda', dtype=self.dtype) + conv1d_filter = torch.randn(16, 3, 3, device='cuda', dtype=self.dtype) + conv1d_bias = torch.ones(16, device='cuda', dtype=self.dtype) + # Vanilla conv1d + conv1d_out_vanilla = F.conv1d(tensor1d_in_conv, conv1d_filter) + # conv1d with bias + conv1d_out_with_bias = F.conv1d(tensor1d_in_conv, conv1d_filter, bias=conv1d_bias) + # conv1d - stride > 1 + conv1d_out_strided = F.conv1d(tensor1d_in_conv, conv1d_filter, stride=2) + # conv1d - dilation > 1 + conv1d_out_dilated = F.conv1d(tensor1d_in_conv, conv1d_filter, dilation=2) + # conv1d - groups > 1 + conv1d_out_grouped = F.conv1d(tensor1d_in_conv_grouped, conv1d_filter, groups=2) + # conv1d - padding with zeros + conv1d_out_padding_zeros = F.conv1d(tensor1d_in_conv, conv1d_filter, padding=6) + + def test_conv2d(self): + # Data and weight tensors + tensor2d_in_conv = torch.randn(32, 3, 224, 224, device='cuda', dtype=self.dtype) + tensor2d_in_conv_grouped = torch.randn(32, 6, 224, 224, device='cuda', dtype=self.dtype) + conv2d_filter = torch.randn(16, 3, 3, 3, device='cuda', dtype=self.dtype) + conv2d_bias = torch.ones(16, device='cuda', dtype=self.dtype) + # Vanilla conv2d + conv2d_out_vanilla = F.conv2d(tensor2d_in_conv, conv2d_filter) + # conv2d with bias + conv2d_with_bias = F.conv2d(tensor2d_in_conv, conv2d_filter, bias=conv2d_bias) + # conv2d - stride > 1 + conv2d_out_strided = F.conv2d(tensor2d_in_conv, conv2d_filter, stride=2) + # conv2d - dilation > 1 + conv2d_out_dilated = F.conv2d(tensor2d_in_conv, conv2d_filter, dilation=2) + # conv2d - groups > 1 + conv2d_out_grouped = F.conv2d(tensor2d_in_conv_grouped, conv2d_filter, groups=2) + # conv2d - padding with zeros + conv2d_out_padding_zeros = F.conv2d(tensor2d_in_conv, conv2d_filter, padding=6) + + + def test_conv3d(self): + # Data and weight tensors + tensor3d_in_conv = torch.randn(32, 3, 16, 224, 224, device='cuda', dtype=self.dtype) + tensor3d_in_conv_grouped = torch.randn(32, 6, 16, 224, 224, device='cuda', dtype=self.dtype) + conv3d_filter = torch.randn(16, 3, 3, 3, 3, device='cuda', dtype=self.dtype) + conv3d_bias = torch.ones(16, device='cuda', dtype=self.dtype) + # Vanilla conv3d + conv3d_out_vanilla = F.conv3d(tensor3d_in_conv, conv3d_filter) + # conv3d - stride > 1 + conv3d_out_strided = F.conv3d(tensor3d_in_conv, conv3d_filter, stride=2) + # conv3d - dilation > 1 + conv3d_out_dilated = F.conv3d(tensor3d_in_conv, conv3d_filter, dilation=2) + # conv3d - groups > 1 + conv3d_out_grouped = F.conv3d(tensor3d_in_conv_grouped, conv3d_filter, groups=2) + # conv3d - padding with zeros + conv3d_out_padding_zeros = F.conv3d(tensor3d_in_conv, conv3d_filter, padding=6) + + def test_conv_transpose1d(self): + # Data and weight tensors + conv_transpose1d_tensor = torch.randn(64, 16, 64, device='cuda', dtype=self.dtype) + conv_transpose1d_filter = torch.randn(16, 32, 3, device='cuda', dtype=self.dtype) + conv_transpose1d_bias = torch.randn(32, device='cuda', dtype=self.dtype) + # Conv transpose runs + conv_transpose1d_out = F.conv_transpose1d(conv_transpose1d_tensor, conv_transpose1d_filter) + conv_transpose1d_out_biased = F.conv_transpose1d(conv_transpose1d_tensor, conv_transpose1d_filter, bias=conv_transpose1d_bias) + conv_transpose1d_out_strided = F.conv_transpose1d(conv_transpose1d_tensor, conv_transpose1d_filter, stride=2) + conv_transpose1d_out_padded = F.conv_transpose1d(conv_transpose1d_tensor, conv_transpose1d_filter, padding=3) + conv_transpose1d_out2_padded = F.conv_transpose1d(conv_transpose1d_tensor, conv_transpose1d_filter, output_padding=2, dilation=3) + conv_transpose1d_out_grouped = F.conv_transpose1d(conv_transpose1d_tensor, conv_transpose1d_filter, groups=2) + conv_transpose1d_out_dilated = F.conv_transpose1d(conv_transpose1d_tensor, conv_transpose1d_filter, dilation=2) + + + def test_conv_transpose2d(self): + # Data and weight tensors + conv_transpose2d_tensor = torch.randn(64, 8, 5, 5, device='cuda', dtype=self.dtype) + conv_transpose2d_filter = torch.randn(8, 16, 3, 3, device='cuda', dtype=self.dtype) + conv_transpose2d_bias = torch.randn(16, device='cuda', dtype=self.dtype) + # Conv transpose runs + conv_transpose2d_out = F.conv_transpose2d(conv_transpose2d_tensor, conv_transpose2d_filter) + conv_transpose2d_out_biased = F.conv_transpose2d(conv_transpose2d_tensor, conv_transpose2d_filter, bias=conv_transpose2d_bias) + conv_transpose2d_out_strided = F.conv_transpose2d(conv_transpose2d_tensor, conv_transpose2d_filter, stride=2) + conv_transpose2d_out_padded = F.conv_transpose2d(conv_transpose2d_tensor, conv_transpose2d_filter, padding=3) + conv_transpose2d_out2_padded = F.conv_transpose2d(conv_transpose2d_tensor, conv_transpose2d_filter, output_padding=2, dilation=3) + conv_transpose2d_out_grouped = F.conv_transpose2d(conv_transpose2d_tensor, conv_transpose2d_filter, groups=2) + conv_transpose2d_out_dilated = F.conv_transpose2d(conv_transpose2d_tensor, conv_transpose2d_filter, dilation=2) + + def test_conv_transpose3d(self): + # Data and weight tensors + conv_transpose3d_tensor = torch.randn(20, 16, 50, 10, 20, device='cuda', dtype=self.dtype) + conv_transpose3d_filter = torch.randn(16, 33, 3, 3, 3, device='cuda', dtype=self.dtype) + conv_transpose3d_bias = torch.randn(33, device='cuda', dtype=self.dtype) + # Conv transpose runs + conv_transpose3d_out = F.conv_transpose3d(conv_transpose3d_tensor, conv_transpose3d_filter) + conv_transpose3d_out_biased = F.conv_transpose3d(conv_transpose3d_tensor, conv_transpose3d_filter, bias=conv_transpose3d_bias) + conv_transpose3d_out_strided = F.conv_transpose3d(conv_transpose3d_tensor, conv_transpose3d_filter, stride=2) + conv_transpose3d_out_padded = F.conv_transpose3d(conv_transpose3d_tensor, conv_transpose3d_filter, padding=3) + conv_transpose3d_out2_padded = F.conv_transpose3d(conv_transpose3d_tensor, conv_transpose3d_filter, output_padding=2, dilation=3) + conv_transpose3d_out_grouped = F.conv_transpose3d(conv_transpose3d_tensor, conv_transpose3d_filter, groups=2) + conv_transpose3d_out_dilated = F.conv_transpose3d(conv_transpose3d_tensor, conv_transpose3d_filter, dilation=2) + + def test_unfold(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + kernel_size = (4, 5) + inp_unf_dilated = F.unfold(inp, kernel_size, dilation=2) + inp_unf_padded = F.unfold(inp, kernel_size, padding=2) + inp_unf_strided = F.unfold(inp, kernel_size, stride=2) + + def test_fold(self): + inp = torch.randn(3, 20, 20, device='cuda', dtype=self.dtype) + inp_folded = F.fold(inp, (4, 5), (1, 1)) + + def test_avg_pool1d(self): + inp = torch.randn(1, 1, 28, device='cuda', dtype=self.dtype) + out = F.avg_pool1d(inp, kernel_size=5, stride=2, padding=2, ceil_mode=True, count_include_pad=False) + + def test_avg_pool2d(self): + inp = torch.randn(1, 3, 224, 224, device='cuda', dtype=self.dtype) + out = F.avg_pool2d(inp, kernel_size=5, stride=2, padding=2, ceil_mode=True, count_include_pad=False) + + def test_avg_pool3d(self): + inp = torch.randn(1, 3, 16, 224, 224, device='cuda', dtype=self.dtype) + out = F.avg_pool3d(inp, kernel_size=5, stride=2, padding=2, ceil_mode=True, count_include_pad=False) + + def test_adaptive_avg_pool1d(self): + inp = torch.randn(1, 1, 28, device='cuda', dtype=self.dtype) + out = F.adaptive_avg_pool1d(inp, output_size=5) + + def test_adaptive_avg_pool2d(self): + inp = torch.randn(1, 16, 32, 32, device='cuda', dtype=self.dtype) + out = F.adaptive_avg_pool2d(inp, output_size=5) + + def test_adaptive_avg_pool3d(self): + inp = torch.randn(1, 16, 16, 32, 32, device='cuda', dtype=self.dtype) + out = F.adaptive_avg_pool3d(inp, output_size=5) + + def test_max_pool1d(self): + inp = torch.randn(1, 16, 32, device='cuda', dtype=self.dtype) + out = F.max_pool1d(inp, kernel_size=5, stride=2, padding=2, return_indices=True, ceil_mode=True) + + def test_max_pool2d(self): + inp = torch.randn(1, 16, 32, 32, device='cuda', dtype=self.dtype) + out = F.max_pool2d(inp, kernel_size=5, stride=2, padding=2, return_indices=True, ceil_mode=True) + + def test_max_pool3d(self): + inp = torch.randn(1, 16, 16, 32, 32, device='cuda', dtype=self.dtype) + out = F.max_pool3d(inp, kernel_size=5, stride=2, padding=2, return_indices=True, ceil_mode=True) + + def test_adaptive_max_pool1d(self): + inp = torch.randn(1, 16, 28, device='cuda', dtype=self.dtype) + out = F.adaptive_max_pool1d(inp, output_size=5, return_indices=True) + + def test_adaptive_max_pool2d(self): + inp = torch.randn(1, 16, 32, 32, device='cuda', dtype=self.dtype) + out = F.adaptive_max_pool2d(inp, output_size=5, return_indices=True) + + def test_adaptive_max_pool3d(self): + inp = torch.randn(1, 16, 16, 32, 32, device='cuda', dtype=self.dtype) + out = F.adaptive_max_pool3d(inp, output_size=5, return_indices=True) + + def test_max_unpool1d(self): + inp = torch.randn(1, 16, 32, device='cuda', dtype=self.dtype) + output, indices = F.max_pool1d(inp, kernel_size=5, stride=2, padding=2, return_indices=True, ceil_mode=True) + output = F.max_unpool1d(output, indices, kernel_size=2, stride=2, padding=2) + + def test_max_unpool2d(self): + inp = torch.randn(1, 16, 32, 32, device='cuda', dtype=self.dtype) + output, indices = F.max_pool2d(inp, kernel_size=5, stride=2, padding=2, return_indices=True, ceil_mode=True) + output = F.max_unpool2d(output, indices, kernel_size=2, stride=2, padding=2) + + def test_max_unpool3d(self): + inp = torch.randn(1, 16, 8, 32, 32, device='cuda', dtype=self.dtype) + output, indices = F.max_pool3d(inp, kernel_size=5, stride=2, padding=2, return_indices=True, ceil_mode=True) + output = F.max_unpool3d(output, indices, kernel_size=2, stride=2, padding=2) + + def test_lp_pool1d(self): + inp = torch.randn(1, 32, 64, device='cuda', dtype=self.dtype) + output = F.lp_pool1d(inp, 2, 3, stride=2, ceil_mode=True) + + def test_lp_pool2d(self): + #torch.nn.LPPool2d(norm_type, kernel_size, stride=None, ceil_mode=False) + inp = torch.randn(1, 32, 64, 64, device='cuda', dtype=self.dtype) + output = F.lp_pool2d(inp, 2, 3, stride=2, ceil_mode=True) + + def test_threshold(self): + inp = torch.randn(1, 8, 32, 32, device='cuda', dtype=self.dtype) + output = F.threshold(inp, 6, 6, inplace=False) + + def test_threshold_(self): + inp = torch.randn(1, 8, 32, 32, device='cuda', dtype=self.dtype) + output = F.threshold_(inp, 6, 6) + + def test_relu(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.relu(inp, inplace=False) + + def test_relu_(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.relu_(inp) + + def test_hardtanh(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.hardtanh(inp, min_val=-1., max_val=1., inplace=False) + + def test_hardtanh_(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.hardtanh_(inp, min_val=-1., max_val=1.) + + def test_relu6(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.relu6(inp, inplace=False) + + def test_elu(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.elu(inp, alpha=1.0, inplace=False) + + def test_elu_(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.elu_(inp, alpha=1.0) + + def test_selu(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.selu(inp) + + def test_celu(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.celu(inp, alpha=1.0, inplace=False) + + def test_leaky_relu(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.leaky_relu(inp, negative_slope=0.01, inplace=False) + + def test_leaky_relu_(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.leaky_relu_(inp, negative_slope=0.01) + + def test_prelu(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + weight = torch.randn(1, device='cuda', dtype=self.dtype) + output = F.prelu(inp, weight) + + def test_rrelu(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.rrelu(inp, lower=1./8, upper=1./3, training=False, inplace=False) + + def test_rrelu_(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.rrelu(inp, lower=1./8, upper=1./3, training=False) + + def test_glu(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.glu(inp, dim=-1) + + def test_logsigmoid(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.logsigmoid(inp) + + def test_hardshrink(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.hardshrink(inp, lambd=0.5) + + def test_tanhshrink(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.tanhshrink(inp) + + def test_softsign(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.softsign(inp) + + def test_softplus(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.softplus(inp, beta=1, threshold=20) + + def test_softmin(self): + inp = torch.randn(16, 1024, device='cuda', dtype=self.dtype) + output = F.softmin(inp, dim=1, _stacklevel=3, dtype=self.dtype) + + def test_softmax(self): + inp = torch.randn(16, 1024, device='cuda', dtype=self.dtype) + output = F.softmax(inp, dim=1, _stacklevel=3, dtype=self.dtype) + + def test_softshrink(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.softshrink(inp, lambd=0.5) + + def test_gumbel_softmax(self): + inp = torch.randn(16, 1024, device='cuda', dtype=self.dtype) + output = F.gumbel_softmax(inp, tau=1, hard=False, eps=1e-10, dim=-1) + + def test_log_softmax(self): + inp = torch.randn(16, 1024, device='cuda', dtype=self.dtype) + output = F.log_softmax(inp, dim=-1, _stacklevel=3) + + def test_tanh(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = torch.tanh(inp) + + def test_sigmoid(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = torch.sigmoid(inp) + + def test_batch_norm(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + # running_mean, running_var + running_mean = torch.randn(3, device='cuda', dtype=self.dtype) + running_var = torch.randn(3, device='cuda', dtype=self.dtype) + output = F.batch_norm(inp, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05) + + def test_instance_norm(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + running_mean = torch.randn(3, device='cuda', dtype=self.dtype) + running_var = torch.randn(3, device='cuda', dtype=self.dtype) + output = F.instance_norm(inp, running_mean=running_mean, running_var=running_var, weight=None, bias=None, use_input_stats=True, momentum=0.1, eps=1e-05) + + def test_layer_norm(self): + inp = torch.randn(1, 3, 32, 32, device='cuda', dtype=self.dtype) + output = F.layer_norm(inp, inp.size()[1:], weight=None, bias=None, eps=1e-05) + + def test_local_response_norm(self): + inp = torch.randn(16, 8, 64, 64, device='cuda', dtype=self.dtype) + output = F.local_response_norm(inp, 2, alpha=0.0001, beta=0.75, k=1.0) + + def test_normalize(self): + inp = torch.randn(16, 8, 64, 64, device='cuda', dtype=self.dtype) + output = F.normalize(inp, p=2, dim=1, eps=1e-12, out=None) + + def test_linear(self): + inp = torch.randn(32, 64, 128, device='cuda', dtype=self.dtype) + weight = torch.randn(256, 128, device='cuda', dtype=self.dtype) + output = F.linear(inp, weight, bias=None) + + def test_dropout(self): + inp = torch.randn(16, 8, 64, 64, device='cuda', dtype=self.dtype) + output = F.dropout(inp, p=0.5, training=True, inplace=False) + + def test_alpha_dropout(self): + inp = torch.randn(16, 8, 64, 64, device='cuda', dtype=self.dtype) + output = F.alpha_dropout(inp, p=0.5, training=True, inplace=False) + + def test_dropout2d(self): + inp = torch.randn(16, 8, 64, 64, device='cuda', dtype=self.dtype) + output = F.dropout2d(inp, p=0.5, training=True, inplace=False) + + def test_dropout3d(self): + inp = torch.randn(16, 8, 32, 64, 64, device='cuda', dtype=self.dtype) + output = F.dropout3d(inp, p=0.5, training=True, inplace=False) + + def test_embedding(self): + pre_embed_dim = 1024 + post_embed_dim = 32 + inp = torch.randint(0, pre_embed_dim, (128, 16), device='cuda') + weight = torch.randn(pre_embed_dim, post_embed_dim, device='cuda', dtype=self.dtype) + output = F.embedding(inp, weight, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False) + + def test_embedding_bag(self): + pre_embed_dim = 1024 + post_embed_dim = 32 + inp = torch.randint(0, pre_embed_dim, (128, 16), device='cuda') + weight = torch.randn(pre_embed_dim, post_embed_dim, device='cuda', dtype=self.dtype) + output = F.embedding_bag(inp, weight, offsets=None, max_norm=None, norm_type=2, + scale_grad_by_freq=False, mode='mean', sparse=False) + + def test_one_hot(self): + num_classes = 10 + inp = torch.randint(0, num_classes, (128, 16), device='cuda') + output = F.one_hot(inp, num_classes=10) + + def test_pairwise_distance(self): + inp1 = torch.randn(1024, 128, device='cuda', dtype=self.dtype) + inp2 = torch.randn(1024, 128, device='cuda', dtype=self.dtype) + output = F.pairwise_distance(inp1, inp2, p=2.0, eps=1e-06, keepdim=False) + + def test_cosine_similarity(self): + inp1 = torch.randn(1024, 128, device='cuda', dtype=self.dtype) + inp2 = torch.randn(1024, 128, device='cuda', dtype=self.dtype) + output = F.cosine_similarity(inp1, inp2, dim=1, eps=1e-8) + + def test_pdist(self): + # pdist is not implemented for fp16 + inp = torch.randn(128, 128, device='cuda', dtype=torch.float32) + output = F.pdist(inp, p=2) + + def test_binary_cross_entropy(self): + # binary_cross_entropy is not implemented for fp16 + inp = torch.randn(32, 128, device='cuda', dtype=torch.float32, requires_grad=True) + target = torch.randn(32, 128, device='cuda', dtype=torch.float32, requires_grad=False) + output = F.binary_cross_entropy(torch.sigmoid(inp), target) + + def test_binary_cross_entropy_with_logits(self): + inp = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + target = torch.empty_like(inp).random_(2) + output = F.binary_cross_entropy_with_logits(inp, target) + + def test_poisson_nll_loss(self): + inp = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + target = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=False) + output = F.poisson_nll_loss(inp, target, log_input=True, full=False, + size_average=None, eps=1e-08, reduce=None, reduction='mean') + + def test_cosine_embedding_loss(self): + inp1 = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + inp2 = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + target = torch.randn(32, device='cuda', dtype=self.dtype, requires_grad=False) + output = F.cosine_embedding_loss(inp1, inp2, target, margin=0, + size_average=None, reduce=None, reduction='mean') + + def test_cross_entropy(self): + inp = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + target = torch.randint(0, 100, (32,), device='cuda', dtype=torch.long, requires_grad=False) + output = F.cross_entropy(inp, target, weight=None, size_average=None, + ignore_index=-100, reduce=None, reduction='mean') + + def test_ctc_loss(self): + # force fp32 because _th_normal_ (used by next line is not supported for fp16) + log_probs = torch.randn(50, 16, 20, device='cuda', dtype=torch.float32).log_softmax(2).detach().requires_grad_() + targets = torch.randint(1, 20, (16, 30), device='cuda', dtype=torch.long) + input_lengths = torch.full((16,), 50, dtype=torch.long) + target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) + loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) + + def test_hinge_embedding_loss(self): + inp = torch.randn(128, 32, device='cuda', dtype=self.dtype) + target = torch.randint(0, 1, (32,), device='cuda') - 1 + output = F.hinge_embedding_loss(inp, target, margin=1.0, size_average=None, reduce=None, reduction='mean') + + def test_kl_div(self): + inp = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + target = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + output = F.kl_div(inp, target, size_average=None, reduce=None, reduction='batchmean') + + def test_mse_loss(self): + inp = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + target = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + output = F.mse_loss(inp, target, size_average=None, reduce=None, reduction='mean') + + def test_margin_ranking_loss(self): + inp1 = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + inp2 = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + target = (torch.randint(0, 1, (128,), device='cuda') - 1).type_as(inp1) + output = F.margin_ranking_loss(inp1, inp2, target, margin=0, size_average=None, reduce=None, reduction='mean') + + def test_multilabel_margin_loss(self): + inp = torch.randn(1024, device='cuda', dtype=self.dtype, requires_grad=True) + target = torch.randint(0, 10, (1024,), dtype=torch.long, device='cuda') + output = F.multilabel_margin_loss(inp, target, size_average=None, reduce=None, reduction='mean') + + def test_nll_loss(self): + inp = torch.randn(64, 128, device='cuda', dtype=self.dtype, requires_grad=True) + target = torch.randint(0, 10, (64,), device='cuda', dtype=torch.long) + output = F.nll_loss(inp, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean') + + def test_smooth_l1_loss(self): + inp = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + target = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=False) + output = F.smooth_l1_loss(inp, target, size_average=None, reduce=None, reduction='mean') + + def test_soft_margin_loss(self): + inp = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + target = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=False) + output = F.soft_margin_loss(inp, target, size_average=None, reduce=None, reduction='mean') + + def test_triplet_margin_loss(self): + inp1 = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + inp2 = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + inp3 = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True) + output = F.triplet_margin_loss(inp1, inp2, inp3, margin=1.0, p=2, + eps=1e-06, swap=False, size_average=None, reduce=None, reduction='mean') + + def test_pixel_shuffle(self): + inp = torch.randn(16, 8, 64, 64, device='cuda', dtype=self.dtype) + output = torch.nn.functional.pixel_shuffle(inp, 2) + + def test_pad(self): + inp = torch.randn(16, 8, 64, 64, device='cuda', dtype=self.dtype) + pad = (3, 3) + output = F.pad(inp, pad, mode='constant', value=0) + + def test_interpolate(self): + inp = torch.randn(16, 8, 64, 64, device='cuda', dtype=self.dtype) + output = F.interpolate(inp, size=None, scale_factor=2, mode='nearest', align_corners=None) + + def test_grid_sample(self): + inp = torch.randn(16, 8, 64, 64, device='cuda', dtype=self.dtype) + grid = torch.randn(16, 32, 32, 2, device='cuda', dtype=self.dtype) + output = F.grid_sample(inp, grid, mode='bilinear', padding_mode='zeros') + + def test_affine_grid(self): + theta = torch.randn(32, 2, 3, device='cuda', dtype=self.dtype) + size = (32, 8, 32, 32) + output = F.affine_grid(theta, size) + + +def run_tests(precision): + dummy = TestPyProfNvtx('test_affine_grid', None) + test_cases = list(filter(lambda x: 'test_' in x, map(lambda x: x[0], inspect.getmembers(dummy, predicate=inspect.ismethod)))) + print("Running tests for {}".format(precision)) + suite = unittest.TestSuite() + for test_case in test_cases: + suite.addTest(TestPyProfNvtx(test_case, precision)) + unittest.TextTestRunner().run(suite) + +if __name__ == '__main__': + run_tests(torch.float32) + run_tests(torch.float16) diff --git a/tests/L0/run_test.py b/tests/L0/run_test.py index 93c7529b1..3ded0cf6c 100644 --- a/tests/L0/run_test.py +++ b/tests/L0/run_test.py @@ -1,7 +1,7 @@ import unittest import sys -test_dirs = ["run_amp", "run_fp16util", "run_mixed_adam", "run_fused_layer_norm"] +test_dirs = ["run_amp", "run_fp16util", "run_mixed_adam", "run_fused_layer_norm", "run_pyprof_nvtx"] runner = unittest.TextTestRunner(verbosity=2)