=========================
Compiling CUDA with clang
=========================

.. contents::
   :local:

Introduction
============

This document describes how to compile CUDA code with clang, and gives some
details about LLVM and clang's CUDA implementations.

This document assumes a basic familiarity with CUDA. Information about CUDA
programming can be found in the
`CUDA programming guide
<http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html>`_.

Compiling CUDA Code
===================

Prerequisites
-------------

CUDA is supported in LLVM 3.9, but it's still in active development, so we
recommend you `compile clang/LLVM from HEAD
<http://llvm.org/docs/GettingStarted.html>`_.

Before you build CUDA code, you'll need to have installed the appropriate
driver for your NVIDIA GPU and the CUDA SDK. See `NVIDIA's CUDA installation
guide <https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html>`_
for details. Note that clang `does not support
<https://llvm.org/bugs/show_bug.cgi?id=26966>`_ the CUDA toolkit as installed
by many Linux package managers; you probably need to install NVIDIA's package.

You will need CUDA 7.0 or 7.5 to compile with clang. CUDA 8 support is in the
works.

Invoking clang
--------------

Invoking clang for CUDA compilation works similarly to compiling regular C++.
You just need to be aware of a few additional flags.

You can use `this <https://gist.github.com/855e277884eb6b388cd2f00d956c2fd4>`_
program as a toy example. Save it as ``axpy.cu``. To build and run it, use the
following commands:

.. code-block:: console

  $ clang++ axpy.cu -o axpy --cuda-gpu-arch=<GPU arch> \
      -L<CUDA install path>/<lib64 or lib>             \
      -lcudart_static -ldl -lrt -pthread
  $ ./axpy
  y[0] = 2
  y[1] = 4
  y[2] = 6
  y[3] = 8

* clang detects that you're compiling CUDA by noticing that your source file ends
  with ``.cu``. (Alternatively, you can pass ``-x cuda``.)

* ``<CUDA install path>`` is the root directory where you installed the CUDA SDK,
  typically ``/usr/local/cuda``.

  Pass e.g. ``/usr/local/cuda/lib64`` if compiling in 64-bit mode; otherwise,
  pass ``/usr/local/cuda/lib``. (In CUDA, the device code and host code always
  have the same pointer widths, so if you're compiling 64-bit code for the
  host, you're also compiling 64-bit code for the device.)

* ``<GPU arch>`` is `the compute capability of your GPU
  <https://developer.nvidia.com/cuda-gpus>`_. For example, if you want to run
  your program on a GPU with compute capability 3.5, you should specify
  ``--cuda-gpu-arch=sm_35``.

  Note: You cannot pass ``compute_XX`` as an argument to ``--cuda-gpu-arch``;
  only ``sm_XX`` is currently supported. However, clang always includes PTX in
  its binaries, so e.g. a binary compiled with ``--cuda-gpu-arch=sm_30`` would be
  forwards-compatible with e.g. ``sm_35`` GPUs.

  You can pass ``--cuda-gpu-arch`` multiple times to compile for multiple
  architectures.

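To make the mapping from compute capability to flag concrete, the hypothetical
helper below (``cuda_gpu_arch_flags`` is not part of clang or the CUDA SDK)
turns a list of compute capabilities, given as (major, minor) pairs, into
``--cuda-gpu-arch`` flags. It emits one flag per architecture, always in the
``sm_XX`` spelling, and assumes single-digit minor versions as in the ``3.5``
to ``sm_35`` example above.

.. code-block:: c++

  #include <string>
  #include <utility>
  #include <vector>

  // Hypothetical helper: builds one --cuda-gpu-arch=sm_XY flag per compute
  // capability (major, minor), e.g. {3, 5} -> "--cuda-gpu-arch=sm_35".
  // Only the sm_XX spelling is produced, since compute_XX is not accepted.
  // Assumes the minor version is a single digit.
  std::string cuda_gpu_arch_flags(
      const std::vector<std::pair<int, int>>& capabilities) {
    std::string flags;
    for (const auto& cap : capabilities) {
      if (!flags.empty()) flags += ' ';  // separate repeated flags
      flags += "--cuda-gpu-arch=sm_" + std::to_string(cap.first) +
               std::to_string(cap.second);
    }
    return flags;
  }

For example, ``{{3, 0}, {3, 5}}`` yields
``--cuda-gpu-arch=sm_30 --cuda-gpu-arch=sm_35``, which could be appended to the
``clang++`` command line above to build for both architectures.
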
Flags that control numerical code
---------------------------------

If you're using GPUs, you probably care about making numerical code run fast.
GPU hardware allows for more control over numerical operations than most CPUs,
but this results in more compiler options for you to juggle.

Flags you may wish to tweak include:

* ``-ffp-contract={on,off,fast}`` (defaults to ``fast`` on host and device when
  compiling CUDA) Controls whether the compiler emits fused multiply-add
  operations.

  * ``off``: never emit fma operations, and prevent ptxas from fusing multiply
    and add instructions.
  * ``on``: fuse multiplies and adds within a single statement, but never
    across statements (C11 semantics). Prevent ptxas from fusing other
    multiplies and adds.
  * ``fast``: fuse multiplies and adds wherever profitable, even across
    statements. Doesn't prevent ptxas from fusing additional multiplies and
    adds.

  Fused multiply-add instructions can be much faster than the unfused
  equivalents, but because the intermediate result in an fma is not rounded,
  this flag can affect numerical code.

* ``-fcuda-flush-denormals-to-zero`` (default: off) When this is enabled,
  floating-point operations may flush `denormal
  <https://en.wikipedia.org/wiki/Denormal_number>`_ inputs and/or outputs to 0.
  Operations on denormal numbers are often much slower than the same operations
  on normal numbers.

* ``-fcuda-approx-transcendentals`` (default: off) When this is enabled, the
  compiler may emit calls to faster, approximate versions of transcendental
  functions, instead of using the slower, fully IEEE-compliant versions. For
  example, this flag allows clang to emit the PTX ``sin.approx.f32``
  instruction.

  This is implied by ``-ffast-math``.

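The rounding difference behind ``-ffp-contract``, and what flushing a denormal
means, can both be seen in plain host C++. This is a sketch that assumes nothing
about clang's implementation: ``unfused_mul_add``, ``fused_mul_add`` and
``flush_to_zero`` are hypothetical names. The fused version uses ``std::fma``,
which rounds only once; the unfused version stores the product in a
``volatile`` temporary so the host compiler cannot contract the two operations
itself. ``flush_to_zero`` only emulates in software the effect of flushing a
denormal input; it is not how the GPU does it.

.. code-block:: c++

  #include <cmath>
  #include <limits>

  // a * b + c with the product rounded to double before the add, as happens
  // when the multiply and add are not fused.
  double unfused_mul_add(double a, double b, double c) {
    volatile double product = a * b;  // volatile forces a rounded store
    return product + c;
  }

  // a * b + c with a single rounding at the end, as an fma instruction does.
  double fused_mul_add(double a, double b, double c) {
    return std::fma(a, b, c);
  }

  // Hypothetical software emulation of flush-to-zero on an input value.
  // Denormals are the nonzero values smaller in magnitude than the smallest
  // normal float, std::numeric_limits<float>::min(); they become a zero of
  // the same sign. All other values (including NaN and inf) pass through.
  float flush_to_zero(float x) {
    if (x != 0.0f && std::fabs(x) < std::numeric_limits<float>::min()) {
      return std::copysign(0.0f, x);
    }
    return x;
  }

With ``a = b = 1 + 2^-27`` and ``c = -(1 + 2^-26)``, the exact product is
``1 + 2^-26 + 2^-54``. Rounding it to double drops the ``2^-54`` term, so the
unfused version returns exactly 0, while the fused version returns ``2^-54``.
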
Detecting clang vs NVCC from code
=================================

Although clang's CUDA implementation is largely compatible with NVCC's, you may
still want to detect when you're compiling CUDA code specifically with clang.

This is tricky, because NVCC may invoke clang as part of its own compilation
process! For example, NVCC uses the host compiler's preprocessor when
compiling for device code, and that host compiler may in fact be clang.

When clang is actually compiling CUDA code -- rather than being used as a
subtool of NVCC's -- it defines the ``__CUDA__`` macro. ``__CUDA_ARCH__`` is
defined only in device mode (but will be defined if NVCC is using clang as a
preprocessor). So you can use the following incantations to detect clang CUDA
compilation, in host and device modes:

.. code-block:: c++

  #if defined(__clang__) && defined(__CUDA__) && !defined(__CUDA_ARCH__)
  // clang compiling CUDA code, host mode.
  #endif

  #if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
  // clang compiling CUDA code, device mode.
  #endif

Both clang and NVCC define ``__CUDACC__`` during CUDA compilation. You can
detect NVCC specifically by looking for ``__NVCC__``.

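The checks above can be gathered into a single function. This is a minimal
sketch: the ``CudaCompilation`` enum and ``cuda_compilation`` function are
hypothetical names, not part of clang or the CUDA headers. When NVCC uses clang
as its host compiler, ``__clang__`` is defined but ``__CUDA__`` is not, so the
function falls through to the ``__NVCC__`` check.

.. code-block:: c++

  // Hypothetical enum naming the cases described above.
  enum class CudaCompilation {
    kClangDevice,  // clang compiling CUDA code, device mode
    kClangHost,    // clang compiling CUDA code, host mode
    kNvcc,         // NVCC, possibly using clang as its host compiler
    kNone,         // not a CUDA compilation recognized here
  };

  // Host and device code are compiled in separate passes, and each pass sees
  // its own set of predefined macros, so the answer is per pass.
  inline CudaCompilation cuda_compilation() {
  #if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
    return CudaCompilation::kClangDevice;
  #elif defined(__clang__) && defined(__CUDA__)
    return CudaCompilation::kClangHost;
  #elif defined(__NVCC__)
    return CudaCompilation::kNvcc;
  #else
    return CudaCompilation::kNone;
  #endif
  }

In an ordinary, non-CUDA C++ build, ``cuda_compilation()`` returns
``CudaCompilation::kNone``.
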
Optimizations
=============

CPUs and GPUs have different design philosophies and architectures. For
example, a typical CPU has branch prediction, out-of-order execution, and is
superscalar, whereas a typical GPU has none of these. Due to such differences,
an optimization pipeline well-tuned for CPUs may not be suitable for GPUs.

LLVM performs several general and CUDA-specific optimizations for GPUs. The
list below shows some of the more important optimizations for GPUs. Most of
them have been upstreamed to ``lib/Transforms/Scalar`` and
``lib/Target/NVPTX``. A few of them have not been upstreamed due to lack of a
customizable target-independent optimization pipeline.

* **Straight-line scalar optimizations**. These optimizations reduce redundancy
  in straight-line code. Details can be found in the `design document for
  straight-line scalar optimizations <https://goo.gl/4Rb9As>`_.

* **Inferring memory spaces**. `This optimization
  <https://github.com/llvm-mirror/llvm/blob/master/lib/Target/NVPTX/NVPTXInferAddressSpaces.cpp>`_
  infers the memory space of an address so that the backend can emit faster
  special loads and stores from it.

* **Aggressive loop unrolling and function inlining**. Loop unrolling and
  function inlining need to be more aggressive for GPUs than for CPUs because
  control flow transfer on a GPU is more expensive. They also promote other
  optimizations, such as constant propagation and SROA, which sometimes speed
  up code by over 10x. An empirical inline threshold for GPUs is 1100. This
  configuration has yet to be upstreamed with a target-specific optimization
  pipeline. LLVM also provides `loop unrolling pragmas
  <http://clang.llvm.org/docs/AttributeReference.html#pragma-unroll-pragma-nounroll>`_
  and ``__attribute__((always_inline))`` for programmers to force unrolling and
  inlining.

* **Aggressive speculative execution**. `This transformation
  <http://llvm.org/docs/doxygen/html/SpeculativeExecution_8cpp_source.html>`_ is
  mainly for promoting straight-line scalar optimizations, which are most
  effective on code along dominator paths.

* **Memory-space alias analysis**. `This alias analysis
  <http://reviews.llvm.org/D12414>`_ infers that two pointers in different
  special memory spaces do not alias. It has yet to be integrated into the new
  alias analysis infrastructure; the new infrastructure does not run
  target-specific alias analysis.

* **Bypassing 64-bit divides**. `An existing optimization
  <http://llvm.org/docs/doxygen/html/BypassSlowDivision_8cpp_source.html>`_ is
  enabled in the NVPTX backend. 64-bit integer divides are much slower than
  32-bit ones on NVIDIA GPUs due to lack of a divide unit. Many of the 64-bit
  divides in our benchmarks have a divisor and dividend which fit in 32 bits at
  runtime. This optimization provides a fast path for this common case.

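The idea behind the divide bypass can be sketched at the source level. This is
not the code of LLVM's ``BypassSlowDivision`` pass, which works on LLVM IR
rather than C++; ``udiv64_with_bypass`` is a hypothetical name, and the sketch
covers only the unsigned quotient. It assumes, as the text above says, that a
32-bit divide is much cheaper than a 64-bit one on the target.

.. code-block:: c++

  #include <cstdint>

  // Illustration only: if both operands fit in 32 bits at runtime, a 32-bit
  // unsigned divide gives the same quotient as the 64-bit divide.
  std::uint64_t udiv64_with_bypass(std::uint64_t dividend,
                                   std::uint64_t divisor) {
    // OR the operands so a single test checks both: the upper 32 bits of the
    // result are zero only if they are zero in both operands.
    if (((dividend | divisor) >> 32) == 0) {
      return static_cast<std::uint32_t>(dividend) /
             static_cast<std::uint32_t>(divisor);  // fast path
    }
    return dividend / divisor;  // slow path: full 64-bit divide
  }

The extra check costs an OR, a shift and a branch, which pays off only when, as
in the benchmarks mentioned above, most operands really do fit in 32 bits.
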
Publication
===========

| `gpucc: An Open-Source GPGPU Compiler <http://dl.acm.org/citation.cfm?id=2854041>`_
| Jingyue Wu, Artem Belevich, Eli Bendersky, Mark Heffernan, Chris Leary, Jacques Pienaar, Bjarke Roune, Rob Springer, Xuetian Weng, Robert Hundt
| *Proceedings of the 2016 International Symposium on Code Generation and Optimization (CGO 2016)*
| `Slides for the CGO talk <http://wujingyue.com/docs/gpucc-talk.pdf>`_

Tutorial
========

`CGO 2016 gpucc tutorial <http://wujingyue.com/docs/gpucc-tutorial.pdf>`_

Obtaining Help
==============

To obtain help on LLVM in general and its CUDA support, see `the LLVM
community <http://llvm.org/docs/#mailing-lists>`_.