//===--- opencl_example.cpp - Example of using Acxxel with OpenCL ---------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
///
/// This file is an example of using OpenCL with Acxxel.
///
//===----------------------------------------------------------------------===//

#include "acxxel.h"

#include <array>
#include <cstdio>
#include <cstdlib>
#include <cstring>

/// OpenCL C source for the saxpy kernel used by this example.
///
/// Each work item reads its index from get_global_id(0) and, when that index
/// is below N, overwrites X[I] with A * X[I] + Y[I]. The bounds check lets the
/// kernel be launched with more work items than elements.
static const char *SaxpyKernelSource = R"(
__kernel void saxpyKernel(float A, __global float *X, __global float *Y, int N) {
  int I = get_global_id(0);
  if (I < N)
    X[I] = A * X[I] + Y[I];
}
)";

28template <size_t N>
29void saxpy(float A, std::array<float, N> &X, const std::array<float, N> &Y) {
30 acxxel::Platform *OpenCL = acxxel::getOpenCLPlatform().getValue();
31 acxxel::Stream Stream = OpenCL->createStream().takeValue();
32 auto DeviceX = OpenCL->mallocD<float>(N).takeValue();
33 auto DeviceY = OpenCL->mallocD<float>(N).takeValue();
34 Stream.syncCopyHToD(X, DeviceX).syncCopyHToD(Y, DeviceY);
35 acxxel::Program Program =
36 OpenCL
37 ->createProgramFromSource(acxxel::Span<const char>(
38 SaxpyKernelSource, std::strlen(SaxpyKernelSource)))
39 .takeValue();
40 acxxel::Kernel Kernel = Program.createKernel("saxpyKernel").takeValue();
41 float *RawX = static_cast<float *>(DeviceX);
42 float *RawY = static_cast<float *>(DeviceY);
43 int IntLength = N;
44 void *Arguments[] = {&A, &RawX, &RawY, &IntLength};
45 size_t ArgumentSizes[] = {sizeof(float), sizeof(float *), sizeof(float *),
46 sizeof(int)};
47 acxxel::Status Status =
48 Stream.asyncKernelLaunch(Kernel, N, Arguments, ArgumentSizes)
49 .syncCopyDToH(DeviceX, X)
50 .sync();
51 if (Status.isError()) {
52 std::fprintf(stderr, "Error during saxpy: %s\n",
53 Status.getMessage().c_str());
54 std::exit(EXIT_FAILURE);
55 }
56}
57
58int main() {
59 float A = 2.f;
Jason Henline492c5a12016-12-19 21:34:07 +000060 std::array<float, 3> X{{0.f, 1.f, 2.f}};
61 std::array<float, 3> Y{{3.f, 4.f, 5.f}};
62 std::array<float, 3> Expected{{3.f, 6.f, 9.f}};
Jason Henlineac232dd2016-10-25 20:18:56 +000063 saxpy(A, X, Y);
64 for (int I = 0; I < 3; ++I)
65 if (X[I] != Expected[I]) {
66 std::fprintf(stderr, "Mismatch at position %d, %f != %f\n", I, X[I],
67 Expected[I]);
68 std::exit(EXIT_FAILURE);
69 }
70}