=========================
Compiling CUDA with clang
=========================

.. contents::
  :local:

Introduction
============

This document describes how to compile CUDA code with clang, and gives some
details about LLVM and clang's CUDA implementations.

This document assumes a basic familiarity with CUDA. Information about CUDA
programming can be found in the
`CUDA programming guide
<http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html>`_.

Compiling CUDA Code
===================

Prerequisites
-------------

CUDA is supported in llvm 3.9, but it's still in active development, so we
recommend you `compile clang/LLVM from HEAD
<http://llvm.org/docs/GettingStarted.html>`_.

Before you build CUDA code, you'll need to have installed the appropriate
driver for your nvidia GPU and the CUDA SDK. See `NVIDIA's CUDA installation
guide <https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html>`_
for details. Note that clang `does not support
<https://llvm.org/bugs/show_bug.cgi?id=26966>`_ the CUDA toolkit as installed
by many Linux package managers; you probably need to install nvidia's package.

You will need CUDA 7.0 or 7.5 to compile with clang. CUDA 8 support is in the
works.

Invoking clang
--------------

Invoking clang for CUDA compilation works similarly to compiling regular C++.
You just need to be aware of a few additional flags.

You can use `this <https://gist.github.com/855e277884eb6b388cd2f00d956c2fd4>`_
program as a toy example. Save it as ``axpy.cu``. (Clang detects that you're
compiling CUDA code by noticing that your filename ends with ``.cu``.
Alternatively, you can pass ``-x cuda``.)
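
If the gist is unavailable, a minimal program along the same lines looks
roughly like this (a sketch, not necessarily the gist's exact contents; the
values of ``a``, ``x``, and ``y`` are chosen to match the sample output below):

.. code-block:: c++

  #include <iostream>

  __global__ void axpy(float a, float* x, float* y) {
    y[threadIdx.x] = a * x[threadIdx.x] + y[threadIdx.x];
  }

  int main() {
    const int kDataLen = 4;

    float a = 2.0f;
    float host_x[kDataLen] = {1.0f, 2.0f, 3.0f, 4.0f};
    float host_y[kDataLen] = {0.0f, 0.0f, 0.0f, 0.0f};

    // Copy input data to the device.
    float* device_x;
    float* device_y;
    cudaMalloc(&device_x, kDataLen * sizeof(float));
    cudaMalloc(&device_y, kDataLen * sizeof(float));
    cudaMemcpy(device_x, host_x, kDataLen * sizeof(float),
               cudaMemcpyHostToDevice);
    cudaMemcpy(device_y, host_y, kDataLen * sizeof(float),
               cudaMemcpyHostToDevice);

    // Launch the kernel with one block of kDataLen threads.
    axpy<<<1, kDataLen>>>(a, device_x, device_y);

    // Copy the results back to the host and print them.
    cudaDeviceSynchronize();
    cudaMemcpy(host_y, device_y, kDataLen * sizeof(float),
               cudaMemcpyDeviceToHost);
    for (int i = 0; i < kDataLen; ++i) {
      std::cout << "y[" << i << "] = " << host_y[i] << "\n";
    }

    cudaFree(device_x);
    cudaFree(device_y);
    return 0;
  }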

To build and run, run the following commands, filling in the parts in angle
brackets as described below:

.. code-block:: console

  $ clang++ axpy.cu -o axpy --cuda-gpu-arch=<GPU arch> \
      -L<CUDA install path>/<lib64 or lib> \
      -lcudart_static -ldl -lrt -pthread
  $ ./axpy
  y[0] = 2
  y[1] = 4
  y[2] = 6
  y[3] = 8

* ``<CUDA install path>`` -- the directory where you installed CUDA SDK.
  Typically, ``/usr/local/cuda``.

  Pass e.g. ``-L/usr/local/cuda/lib64`` if compiling in 64-bit mode; otherwise,
  pass e.g. ``-L/usr/local/cuda/lib``. (In CUDA, the device code and host code
  always have the same pointer widths, so if you're compiling 64-bit code for
  the host, you're also compiling 64-bit code for the device.)

* ``<GPU arch>`` -- the `compute capability
  <https://developer.nvidia.com/cuda-gpus>`_ of your GPU. For example, if you
  want to run your program on a GPU with compute capability of 3.5, specify
  ``--cuda-gpu-arch=sm_35``.

  Note: You cannot pass ``compute_XX`` as an argument to ``--cuda-gpu-arch``;
  only ``sm_XX`` is currently supported. However, clang always includes PTX in
  its binaries, so e.g. a binary compiled with ``--cuda-gpu-arch=sm_30`` would
  be forwards-compatible with e.g. ``sm_35`` GPUs.

  You can pass ``--cuda-gpu-arch`` multiple times to compile for multiple archs.

The ``-L`` and ``-l`` flags only need to be passed when linking. When compiling,
you may also need to pass ``--cuda-path=/path/to/cuda`` if you didn't install
the CUDA SDK into ``/usr/local/cuda``, ``/usr/local/cuda-7.0``, or
``/usr/local/cuda-7.5``.
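
For example, a complete 64-bit invocation, assuming a GPU with compute
capability 3.5 and the SDK in its default ``/usr/local/cuda`` location, might
look like this:

.. code-block:: console

  $ clang++ axpy.cu -o axpy --cuda-gpu-arch=sm_35 \
      -L/usr/local/cuda/lib64 \
      -lcudart_static -ldl -lrt -pthread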

Flags that control numerical code
---------------------------------

If you're using GPUs, you probably care about making numerical code run fast.
GPU hardware allows for more control over numerical operations than most CPUs,
but this results in more compiler options for you to juggle.

Flags you may wish to tweak include:

* ``-ffp-contract={on,off,fast}`` (defaults to ``fast`` on host and device when
  compiling CUDA) Controls whether the compiler emits fused multiply-add
  operations.

  * ``off``: never emit fma operations, and prevent ptxas from fusing multiply
    and add instructions.
  * ``on``: fuse multiplies and adds within a single statement, but never
    across statements (C11 semantics). Prevent ptxas from fusing other
    multiplies and adds.
  * ``fast``: fuse multiplies and adds wherever profitable, even across
    statements. Doesn't prevent ptxas from fusing additional multiplies and
    adds.

  Fused multiply-add instructions can be much faster than the unfused
  equivalents, but because the intermediate result in an fma is not rounded,
  this flag can affect numerical code (see the sketch after this list).

* ``-fcuda-flush-denormals-to-zero`` (default: off) When this is enabled,
  floating point operations may flush `denormal
  <https://en.wikipedia.org/wiki/Denormal_number>`_ inputs and/or outputs to 0.
  Operations on denormal numbers are often much slower than the same operations
  on normal numbers.

* ``-fcuda-approx-transcendentals`` (default: off) When this is enabled, the
  compiler may emit calls to faster, approximate versions of transcendental
  functions, instead of using the slower, fully IEEE-compliant versions. For
  example, this flag allows clang to emit the ptx ``sin.approx.f32``
  instruction.

  This is implied by ``-ffast-math``.
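
To illustrate the ``-ffp-contract`` settings, here is a sketch (the exact
instructions emitted are up to the compiler, so this is indicative only):

.. code-block:: c++

  __device__ float fma_example(float a, float b, float c) {
    // A multiply and add in one statement may become a single fma
    // instruction under -ffp-contract=on or -ffp-contract=fast.
    float x = a * b + c;

    // A multiply and add split across statements may be fused only under
    // -ffp-contract=fast.
    float mul = a * b;
    float y = mul + c;

    return x + y;
  }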

Standard library support
========================

In clang and nvcc, most of the C++ standard library is not supported on the
device side.

``math.h`` and ``cmath``
------------------------

In clang, ``math.h`` and ``cmath`` are available and `pass
<https://github.com/llvm-mirror/test-suite/blob/master/External/CUDA/math_h.cu>`_
`tests
<https://github.com/llvm-mirror/test-suite/blob/master/External/CUDA/cmath.cu>`_
adapted from libc++'s test suite.

In nvcc, ``math.h`` and ``cmath`` are mostly available. Versions of ``::foof``
in namespace std (e.g. ``std::sinf``) are not available, and where the standard
calls for overloads that take integral arguments, these are usually not
available.

.. code-block:: c++

  #include <math.h>
  #include <cmath>

  // clang is OK with everything in this function.
  __device__ void test() {
    std::sin(0.); // nvcc - ok
    std::sin(0); // nvcc - error, because no std::sin(int) overload is available.
    sin(0); // nvcc - same as above.

    sinf(0.); // nvcc - ok
    std::sinf(0.); // nvcc - no such function
  }

``std::complex``
----------------

nvcc does not officially support ``std::complex``. It's an error to use
``std::complex`` in ``__device__`` code, but it often works in ``__host__
__device__`` code due to nvcc's interpretation of the "wrong-side rule" (see
below). However, we have heard from implementers that it's possible to get
into situations where nvcc will omit a call to an ``std::complex`` function,
especially when compiling without optimizations.

clang does not yet support ``std::complex``. Because we interpret the
"wrong-side rule" more strictly than nvcc, ``std::complex`` doesn't work in
``__device__`` or ``__host__ __device__`` code.

In the meantime, you can get limited ``std::complex`` support in clang by
building your code for C++14. In clang, all ``constexpr`` functions are always
implicitly ``__host__ __device__`` (this corresponds to nvcc's
``--expt-relaxed-constexpr`` flag). In C++14, many ``std::complex`` functions
are ``constexpr``, so you can use these with clang. (nvcc does not currently
support C++14.)
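
For instance, the following sketch should compile with clang in C++14 mode
(``-std=c++14``), assuming your standard library implements the C++14
``constexpr`` additions to ``<complex>``:

.. code-block:: c++

  #include <complex>

  __device__ double f() {
    // The constructor and the real()/imag() accessors are constexpr in
    // C++14, so clang treats them as implicitly __host__ __device__.
    std::complex<double> z(1.0, 2.0);
    return z.real() + z.imag();
  }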

Detecting clang vs NVCC from code
=================================

Although clang's CUDA implementation is largely compatible with NVCC's, you may
still want to detect when you're compiling CUDA code specifically with clang.

This is tricky, because NVCC may invoke clang as part of its own compilation
process! For example, NVCC uses the host compiler's preprocessor when
compiling for device code, and that host compiler may in fact be clang.

When clang is actually compiling CUDA code -- rather than being used as a
subtool of NVCC's -- it defines the ``__CUDA__`` macro. ``__CUDA_ARCH__`` is
defined only in device mode (but will be defined if NVCC is using clang as a
preprocessor). So you can use the following incantations to detect clang CUDA
compilation, in host and device modes:

.. code-block:: c++

  #if defined(__clang__) && defined(__CUDA__) && !defined(__CUDA_ARCH__)
  // clang compiling CUDA code, host mode.
  #endif

  #if defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
  // clang compiling CUDA code, device mode.
  #endif

Both clang and nvcc define ``__CUDACC__`` during CUDA compilation. You can
detect NVCC specifically by looking for ``__NVCC__``.
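
For example:

.. code-block:: c++

  #if defined(__CUDACC__)
  // Compiling CUDA, with either clang or nvcc.
  #endif

  #if defined(__NVCC__)
  // Compiling CUDA with nvcc in particular.
  #endif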

Dialect Differences Between clang and nvcc
==========================================

There is no formal CUDA spec, and clang and nvcc speak slightly different
dialects of the language. Below, we describe some of the differences.

This section is painful; hopefully you can skip this section and live your life
blissfully unaware.

Compilation Models
------------------

Most of the differences between clang and nvcc stem from the different
compilation models used by clang and nvcc. nvcc uses *split compilation*,
which works roughly as follows:

* Run a preprocessor over the input ``.cu`` file to split it into two source
  files: ``H``, containing source code for the host, and ``D``, containing
  source code for the device.

* For each GPU architecture ``arch`` that we're compiling for, do:

  * Compile ``D`` using nvcc proper. The result of this is a ``ptx`` file,
    ``P_arch``.

  * Optionally, invoke ``ptxas``, the PTX assembler, to generate a file,
    ``S_arch``, containing GPU machine code (SASS) for ``arch``.

* Invoke ``fatbin`` to combine all ``P_arch`` and ``S_arch`` files into a
  single "fat binary" file, ``F``.

* Compile ``H`` using an external host compiler (gcc, clang, or whatever you
  like). ``F`` is packaged up into a header file which is force-included into
  ``H``; nvcc generates code that calls into this header to e.g. launch
  kernels.

clang uses *merged parsing*. This is similar to split compilation, except all
of the host and device code is present and must be semantically correct in both
compilation steps.

* For each GPU architecture ``arch`` that we're compiling for, do:

  * Compile the input ``.cu`` file for device, using clang. ``__host__`` code
    is parsed and must be semantically correct, even though we're not
    generating code for the host at this time.

    The output of this step is a ``ptx`` file ``P_arch``.

  * Invoke ``ptxas`` to generate a SASS file, ``S_arch``. Note that, unlike
    nvcc, clang always generates SASS code.

* Invoke ``fatbin`` to combine all ``P_arch`` and ``S_arch`` files into a
  single fat binary file, ``F``.

* Compile the input ``.cu`` file for host, using clang. ``__device__`` code
  is parsed and must be semantically correct, even though we're not generating
  code for the device at this time.

  ``F`` is passed to this compilation, and clang includes it in a special ELF
  section, where it can be found by tools like ``cuobjdump``.
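
You can look at each of these compilations in isolation using clang's
``--cuda-device-only`` and ``--cuda-host-only`` flags (assuming your clang is
new enough to have them), which run only the device-side or host-side half of
the compilation. For example:

.. code-block:: console

  $ # Device-side compilation only; with -S this emits PTX.
  $ clang++ axpy.cu --cuda-device-only --cuda-gpu-arch=sm_35 -S -o axpy.ptx
  $ # Host-side compilation only.
  $ clang++ axpy.cu --cuda-host-only -S -o axpy.s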

(You may ask at this point, why does clang need to parse the input file
multiple times? Why not parse it just once, and then use the AST to generate
code for the host and each device architecture?

Unfortunately this can't work because we have to define different macros during
host compilation and during device compilation for each GPU architecture.)

clang's approach allows it to be highly robust to C++ edge cases, as it doesn't
need to decide at an early stage which declarations to keep and which to throw
away. But it has some consequences you should be aware of.

Overloading Based on ``__host__`` and ``__device__`` Attributes
---------------------------------------------------------------

Let "H", "D", and "HD" stand for "``__host__`` functions", "``__device__``
functions", and "``__host__ __device__`` functions", respectively. Functions
with no attributes behave the same as H.

nvcc does not allow you to create H and D functions with the same signature:

.. code-block:: c++

  // nvcc: error - function "foo" has already been defined
  __host__ void foo() {}
  __device__ void foo() {}

However, nvcc allows you to "overload" H and D functions with different
signatures:

.. code-block:: c++

  // nvcc: no error
  __host__ void foo(int) {}
  __device__ void foo() {}

In clang, the ``__host__`` and ``__device__`` attributes are part of a
function's signature, and so it's legal to have H and D functions with
(otherwise) the same signature:

.. code-block:: c++

  // clang: no error
  __host__ void foo() {}
  __device__ void foo() {}

HD functions cannot be overloaded by H or D functions with the same signature:

.. code-block:: c++

  // nvcc: error - function "foo" has already been defined
  // clang: error - redefinition of 'foo'
  __host__ __device__ void foo() {}
  __device__ void foo() {}

  // nvcc: no error
  // clang: no error
  __host__ __device__ void bar(int) {}
  __device__ void bar() {}

When resolving an overloaded function, clang considers the host/device
attributes of the caller and callee. These are used as a tiebreaker during
overload resolution. See `IdentifyCUDAPreference
<http://clang.llvm.org/doxygen/SemaCUDA_8cpp.html>`_ for the full set of rules,
but at a high level they are:

* D functions prefer to call other Ds. HDs are given lower priority.

* Similarly, H functions prefer to call other Hs, or ``__global__`` functions
  (with equal priority). HDs are given lower priority.

* HD functions prefer to call other HDs.

  When compiling for device, HDs will call Ds with lower priority than HD, and
  will call Hs with still lower priority. If it's forced to call an H, the
  program is malformed if we emit code for this HD function; we call this the
  "wrong-side rule" (see the example below).

  The rules are symmetrical when compiling for host.

Some examples:

.. code-block:: c++

  __host__ void foo();
  __device__ void foo();

  __host__ void bar();
  __host__ __device__ void bar();

  __host__ void test_host() {
    foo();  // calls H overload
    bar();  // calls H overload
  }

  __device__ void test_device() {
    foo();  // calls D overload
    bar();  // calls HD overload
  }

  __host__ __device__ void test_hd() {
    foo();  // calls H overload when compiling for host, otherwise D overload
    bar();  // always calls HD overload
  }

Wrong-side rule example:

.. code-block:: c++

  __host__ void host_only();

  // We don't codegen inline functions unless they're referenced by a
  // non-inline function. inline_hd1() is called only from the host side, so
  // it does not generate an error. inline_hd2() is called from the device
  // side, so it generates an error.
  inline __host__ __device__ void inline_hd1() { host_only(); }  // no error
  inline __host__ __device__ void inline_hd2() { host_only(); }  // error

  __host__ void host_fn() { inline_hd1(); }
  __device__ void device_fn() { inline_hd2(); }

  // This function is not inline, so it's always codegen'ed on both the host
  // and the device. Therefore, it generates an error.
  __host__ __device__ void not_inline_hd() { host_only(); }

For the purposes of the wrong-side rule, templated functions also behave like
``inline`` functions: they aren't codegen'ed unless they're instantiated
(usually as part of the process of invoking them).
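
For example, a sketch in the same vein as the example above (``template_hd``
and its callers are illustrative names):

.. code-block:: c++

  __host__ void host_only();

  // Like an inline function, this HD template is only codegen'ed for the
  // sides on which it's instantiated.
  template <typename T>
  __host__ __device__ void template_hd() { host_only(); }

  __host__ void host_caller() { template_hd<int>(); }      // no error
  __device__ void device_caller() { template_hd<float>(); }  // error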

clang's behavior with respect to the wrong-side rule matches nvcc's, except
nvcc only emits a warning for ``not_inline_hd``; device code is allowed to call
``not_inline_hd``. In its generated code, nvcc may omit ``not_inline_hd``'s
call to ``host_only`` entirely, or it may try to generate code for
``host_only`` on the device. What you get seems to depend on whether or not
the compiler chooses to inline ``host_only``.

Member functions, including constructors, may be overloaded using H and D
attributes. However, destructors cannot be overloaded.
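
For example (a sketch with a hypothetical ``Widget`` class):

.. code-block:: c++

  struct Widget {
    __host__ Widget() { /* host-side initialization */ }
    __device__ Widget() { /* device-side initialization */ }

    // In contrast, you cannot write separate __host__ and __device__
    // destructors; define a single one (here, an HD destructor).
    __host__ __device__ ~Widget() {}
  };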

Using a Different Class on Host/Device
--------------------------------------

Occasionally you may want to have a class with different host/device versions.

If all of the class's members are the same on the host and device, you can just
provide overloads for the class's member functions.

However, if you want your class to have different members on host/device, you
won't be able to provide working H and D overloads in both classes. In this
case, clang is likely to be unhappy with you.

.. code-block:: c++

  #ifdef __CUDA_ARCH__
  struct S {
    __device__ void foo() { /* use device_only */ }
    int device_only;
  };
  #else
  struct S {
    __host__ void foo() { /* use host_only */ }
    double host_only;
  };

  __device__ void test() {
    S s;
    // clang generates an error here, because during host compilation, we
    // have ifdef'ed away the __device__ overload of S::foo(). The __device__
    // overload must be present *even during host compilation*.
    s.foo();
  }
  #endif

We posit that you don't really want to have classes with different members on H
and D. For example, if you were to pass one of these as a parameter to a
kernel, it would have a different layout on H and D, so would not work
properly.

To make code like this compatible with clang, we recommend you separate it out
into two classes. If you need to write code that works on both host and
device, consider writing an overloaded wrapper function that returns different
types on host and device.

.. code-block:: c++

  struct HostS { ... };
  struct DeviceS { ... };

  __host__ HostS MakeStruct() { return HostS(); }
  __device__ DeviceS MakeStruct() { return DeviceS(); }

  // Now host and device code can call MakeStruct().

Unfortunately, this idiom isn't compatible with nvcc, because it doesn't allow
you to overload based on the H/D attributes. Here's an idiom that works with
both clang and nvcc:

.. code-block:: c++

  struct HostS { ... };
  struct DeviceS { ... };

  #ifdef __NVCC__
    #ifndef __CUDA_ARCH__
      __host__ HostS MakeStruct() { return HostS(); }
    #else
      __device__ DeviceS MakeStruct() { return DeviceS(); }
    #endif
  #else
    __host__ HostS MakeStruct() { return HostS(); }
    __device__ DeviceS MakeStruct() { return DeviceS(); }
  #endif

  // Now host and device code can call MakeStruct().

Hopefully you don't have to do this sort of thing often.

Optimizations
=============

Modern CPUs and GPUs are architecturally quite different, so code that's fast
on a CPU isn't necessarily fast on a GPU. We've made a number of changes to
LLVM to make it generate good GPU code. Among these changes are:

* `Straight-line scalar optimizations <https://goo.gl/4Rb9As>`_ -- These
  reduce redundancy within straight-line code.

* `Aggressive speculative execution
  <http://llvm.org/docs/doxygen/html/SpeculativeExecution_8cpp_source.html>`_
  -- This is mainly for promoting straight-line scalar optimizations, which are
  most effective on code along dominator paths.

* `Memory space inference
  <http://llvm.org/doxygen/NVPTXInferAddressSpaces_8cpp_source.html>`_ --
  In PTX, we can operate on pointers that are in a particular "address space"
  (global, shared, constant, or local), or we can operate on pointers in the
  "generic" address space, which can point to anything. Operations in a
  non-generic address space are faster, but pointers in CUDA are not explicitly
  annotated with their address space, so it's up to LLVM to infer it where
  possible.

* `Bypassing 64-bit divides
  <http://llvm.org/docs/doxygen/html/BypassSlowDivision_8cpp_source.html>`_ --
  This was an existing optimization that we enabled for the PTX backend.

  64-bit integer divides are much slower than 32-bit ones on NVIDIA GPUs.
  Many of the 64-bit divides in our benchmarks have a divisor and dividend
  which fit in 32-bits at runtime. This optimization provides a fast path for
  this common case.

* Aggressive loop unrolling and function inlining -- Loop unrolling and
  function inlining need to be more aggressive for GPUs than for CPUs because
  control flow transfer on a GPU is more expensive. More aggressive unrolling
  and inlining also promote other optimizations, such as constant propagation
  and SROA, which sometimes speed up code by over 10x.

  (Programmers can force unrolling and inlining using clang's `loop unrolling
  pragmas
  <http://clang.llvm.org/docs/AttributeReference.html#pragma-unroll-pragma-nounroll>`_
  and ``__attribute__((always_inline))``; see the sketch below.)
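
For example, a sketch of forcing both by hand (``scale`` and ``saxpy`` are
illustrative names):

.. code-block:: c++

  __device__ __attribute__((always_inline)) float scale(float a, float x) {
    return a * x;
  }

  __global__ void saxpy(int n, float a, float* x, float* y) {
    // Ask clang to unroll this loop by a factor of 4.
  #pragma unroll 4
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
      y[i] = scale(a, x[i]) + y[i];
    }
  }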

Publication
===========

The team at Google published a paper in CGO 2016 detailing the optimizations
they'd made to clang/LLVM. Note that "gpucc" is no longer a meaningful name:
the relevant tools are now just vanilla clang/LLVM.

| `gpucc: An Open-Source GPGPU Compiler <http://dl.acm.org/citation.cfm?id=2854041>`_
| Jingyue Wu, Artem Belevich, Eli Bendersky, Mark Heffernan, Chris Leary, Jacques Pienaar, Bjarke Roune, Rob Springer, Xuetian Weng, Robert Hundt
| *Proceedings of the 2016 International Symposium on Code Generation and Optimization (CGO 2016)*
|
| `Slides from the CGO talk <http://wujingyue.com/docs/gpucc-talk.pdf>`_
|
| `Tutorial given at CGO <http://wujingyue.com/docs/gpucc-tutorial.pdf>`_

Obtaining Help
==============

To obtain help on LLVM in general and its CUDA support, see `the LLVM
community <http://llvm.org/docs/#mailing-lists>`_.