===================================
Compiling CUDA C/C++ with LLVM
===================================

.. contents::
   :local:
 
Introduction
============

This document is both a user guide to compiling CUDA C/C++ with LLVM and a
description of the internals involved. It is aimed at users who want to compile
CUDA with LLVM and at developers who want to improve LLVM for GPUs. It assumes
a basic familiarity with CUDA. Information about CUDA programming can be found
in the `CUDA programming guide
<http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html>`_.
 
How to Build LLVM with CUDA Support
===================================

CUDA support is still in development and works best in the trunk version of
LLVM. Below is a quick summary of how to download and build the trunk version.
Consult the `Getting Started <http://llvm.org/docs/GettingStarted.html>`_ page
for more details on setting up LLVM.

#. Check out LLVM

   .. code-block:: console

     $ cd where-you-want-llvm-to-live
     $ svn co http://llvm.org/svn/llvm-project/llvm/trunk llvm

#. Check out Clang

   .. code-block:: console

     $ cd where-you-want-llvm-to-live
     $ cd llvm/tools
     $ svn co http://llvm.org/svn/llvm-project/cfe/trunk clang

#. Configure and build LLVM and Clang

   .. code-block:: console

     $ cd where-you-want-llvm-to-live
     $ mkdir build
     $ cd build
     $ cmake [options] ../llvm
     $ make
 
How to Compile CUDA C/C++ with LLVM
===================================

We assume you have installed the CUDA driver and runtime. Consult the `NVIDIA
CUDA Installation Guide
<https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html>`_ if
you have not.

Suppose you want to compile and run the following CUDA program (``axpy.cu``),
which multiplies a ``float`` array by a ``float`` scalar (AXPY).

.. code-block:: c++

  #include <iostream>

  __global__ void axpy(float a, float* x, float* y) {
    y[threadIdx.x] = a * x[threadIdx.x];
  }

  int main(int argc, char* argv[]) {
    const int kDataLen = 4;

    float a = 2.0f;
    float host_x[kDataLen] = {1.0f, 2.0f, 3.0f, 4.0f};
    float host_y[kDataLen];

    // Copy input data to device.
    float* device_x;
    float* device_y;
    cudaMalloc(&device_x, kDataLen * sizeof(float));
    cudaMalloc(&device_y, kDataLen * sizeof(float));
    cudaMemcpy(device_x, host_x, kDataLen * sizeof(float),
               cudaMemcpyHostToDevice);

    // Launch the kernel.
    axpy<<<1, kDataLen>>>(a, device_x, device_y);

    // Copy output data to host.
    cudaDeviceSynchronize();
    cudaMemcpy(host_y, device_y, kDataLen * sizeof(float),
               cudaMemcpyDeviceToHost);

    // Print the results.
    for (int i = 0; i < kDataLen; ++i) {
      std::cout << "y[" << i << "] = " << host_y[i] << "\n";
    }

    cudaDeviceReset();
    return 0;
  }

The command line for compilation is similar to what you would use for C++.

.. code-block:: console

  $ clang++ axpy.cu -o axpy --cuda-gpu-arch=<GPU arch> \
      -L<CUDA install path>/<lib64 or lib> \
      -lcudart_static -ldl -lrt -pthread
  $ ./axpy
  y[0] = 2
  y[1] = 4
  y[2] = 6
  y[3] = 8

``<CUDA install path>`` is the root directory where you installed the CUDA SDK,
typically ``/usr/local/cuda``. ``<GPU arch>`` is `the compute capability of
your GPU <https://developer.nvidia.com/cuda-gpus>`_. For example, if you want
to run your program on a GPU with compute capability 3.5, you should specify
``--cuda-gpu-arch=sm_35``.
 
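 
As a small illustration of how a compute capability maps to the flag value, the
following host-only C++ sketch builds the ``--cuda-gpu-arch`` argument from a
major and minor version. The helper name ``CudaGpuArchFlag`` is hypothetical
and is not part of Clang or CUDA; the sketch assumes only the
``sm_<major><minor>`` naming seen in the ``sm_35`` example.

.. code-block:: c++

  #include <cassert>
  #include <string>

  // Hypothetical helper, not a Clang or CUDA API: builds the flag for a
  // compute capability given as major.minor, e.g. 3.5 becomes sm_35.
  std::string CudaGpuArchFlag(int major, int minor) {
    assert(major >= 0 && minor >= 0);
    return "--cuda-gpu-arch=sm_" + std::to_string(major) +
           std::to_string(minor);
  }

For instance, ``CudaGpuArchFlag(3, 5)`` yields ``--cuda-gpu-arch=sm_35``. On a
machine with the CUDA runtime installed, the two numbers could be read from the
``major`` and ``minor`` fields of the ``cudaDeviceProp`` structure filled in by
``cudaGetDeviceProperties``.
 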
Optimizations
=============

CPUs and GPUs have different design philosophies and architectures. For
example, a typical CPU is superscalar and has branch prediction and
out-of-order execution, whereas a typical GPU has none of these. Because of
these differences, an optimization pipeline well tuned for CPUs may not be
suitable for GPUs.

LLVM performs several general and CUDA-specific optimizations for GPUs. The
list below shows some of the more important ones. Most of them have been
upstreamed to ``lib/Transforms/Scalar`` and ``lib/Target/NVPTX``. A few have
not been upstreamed because LLVM lacks a customizable target-independent
optimization pipeline.

* **Straight-line scalar optimizations**. These optimizations reduce redundancy
  in straight-line code. Details can be found in the `design document for
  straight-line scalar optimizations <https://goo.gl/4Rb9As>`_.

* **Inferring memory spaces**. `This optimization
  <http://www.llvm.org/docs/doxygen/html/NVPTXFavorNonGenericAddrSpaces_8cpp_source.html>`_
  infers the memory space of an address so that the backend can emit faster
  special loads and stores from it. Details can be found in the `design
  document for memory space inference <https://goo.gl/5wH2Ct>`_.

* **Aggressive loop unrolling and function inlining**. Loop unrolling and
  function inlining need to be more aggressive for GPUs than for CPUs because
  control flow transfer on GPUs is more expensive. They also promote other
  optimizations, such as constant propagation and SROA, which sometimes speed
  up code by over 10x. An empirical inline threshold for GPUs is 1100. This
  configuration has yet to be upstreamed with a target-specific optimization
  pipeline. LLVM also provides `loop unrolling pragmas
  <http://clang.llvm.org/docs/AttributeReference.html#pragma-unroll-pragma-nounroll>`_
  and ``__attribute__((always_inline))`` for programmers to force unrolling and
  inlining.

* **Aggressive speculative execution**. `This transformation
  <http://llvm.org/docs/doxygen/html/SpeculativeExecution_8cpp_source.html>`_ is
  mainly for promoting straight-line scalar optimizations, which are most
  effective on code along dominator paths.

* **Memory-space alias analysis**. `This alias analysis
  <http://reviews.llvm.org/D12414>`_ infers that two pointers in different
  special memory spaces do not alias. It has yet to be integrated into the new
  alias analysis infrastructure; the new infrastructure does not run
  target-specific alias analysis.

* **Bypassing 64-bit divides**. `An existing optimization
  <http://llvm.org/docs/doxygen/html/BypassSlowDivision_8cpp_source.html>`_ is
  enabled in the NVPTX backend. 64-bit integer divides are much slower than
  32-bit ones on NVIDIA GPUs due to lack of a divide unit. Many of the 64-bit
  divides in our benchmarks have a divisor and dividend which fit in 32 bits at
  runtime. This optimization provides a fast path for this common case.
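 
The fast path behind bypassing 64-bit divides can be sketched in host-only C++
as follows. This illustrates the idea only and is not the code of the LLVM
pass: the function name ``DivideWithBypass`` is hypothetical, and the sketch
assumes unsigned operands and a nonzero divisor.

.. code-block:: c++

  #include <cassert>
  #include <cstdint>

  // Hypothetical illustration of the bypass: use a 32-bit divide when both
  // operands fit in 32 bits, and fall back to the 64-bit divide otherwise.
  std::uint64_t DivideWithBypass(std::uint64_t dividend,
                                 std::uint64_t divisor) {
    assert(divisor != 0);
    // OR the operands together: if the upper 32 bits of the result are all
    // zero, then both operands fit in 32 bits.
    if (((dividend | divisor) >> 32) == 0) {
      // Fast path: narrow 32-bit divide.
      return static_cast<std::uint32_t>(dividend) /
             static_cast<std::uint32_t>(divisor);
    }
    // Slow path: full 64-bit divide.
    return dividend / divisor;
  }

Both paths return the same quotient; only the width of the divide differs. The
extra comparison and branch pay off only when the narrow divide is much cheaper
than the wide one, which is the situation described above for NVIDIA GPUs.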