blob: 9b1be8252b6deb9296af0cc82c57df616d17dcf9 [file] [log] [blame]
Mike Kleinf9ae6702018-06-20 14:05:05 -04001/*
2 * Copyright 2018 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
// This is a simple OpenCL Hello World that tests whether you have a functioning OpenCL setup.
9
Mike Klein8a1f15d2019-02-11 11:59:41 -050010#include "cl.hpp"
Mike Kleinf9ae6702018-06-20 14:05:05 -040011#include <initializer_list>
Mike Klein8a1f15d2019-02-11 11:59:41 -050012#include <stdio.h>
13#include <stdlib.h>
14#include <string>
15#include <vector>
Mike Kleinf9ae6702018-06-20 14:05:05 -040016
Mike Klein8a1f15d2019-02-11 11:59:41 -050017static inline void assert_cl(cl_int rc, const char* file, int line) {
18 if (rc != CL_SUCCESS) {
19 fprintf(stderr, "%s:%d, got OpenCL error code %d\n", file,line,rc);
20 exit(1);
21 }
Mike Kleinf9ae6702018-06-20 14:05:05 -040022}
Mike Klein8a1f15d2019-02-11 11:59:41 -050023#define cl_ok(err) assert_cl(err, __FILE__, __LINE__)
Mike Kleinf9ae6702018-06-20 14:05:05 -040024
25int main(int argc, char** argv) {
26 // Find any OpenCL platform+device with these substrings.
27 const char* platform_match = argc > 1 ? argv[1] : "";
28 const char* device_match = argc > 2 ? argv[2] : "";
29
Mike Klein8a1f15d2019-02-11 11:59:41 -050030 cl::Platform platform;
31 {
32 std::vector<cl::Platform> platforms;
33 cl_ok(cl::Platform::get(&platforms));
Mike Kleinf9ae6702018-06-20 14:05:05 -040034
Mike Klein8a1f15d2019-02-11 11:59:41 -050035 bool found = false;
36 for (cl::Platform p : platforms) {
37 std::string name;
38 cl_ok(p.getInfo(CL_PLATFORM_NAME, &name));
Mike Kleinf9ae6702018-06-20 14:05:05 -040039
Mike Klein8a1f15d2019-02-11 11:59:41 -050040 fprintf(stdout, "Available platform %s\n", name.c_str());
Mike Kleinf9ae6702018-06-20 14:05:05 -040041
Mike Klein8a1f15d2019-02-11 11:59:41 -050042 if (name.find(platform_match) != std::string::npos) {
43 platform = p;
44 found = true;
45 }
46 }
47 if (!found) {
48 fprintf(stderr, "No platform containing '%s' found.\n", platform_match);
49 exit(1);
50 }
51 }
Mike Kleinf9ae6702018-06-20 14:05:05 -040052
Mike Klein8a1f15d2019-02-11 11:59:41 -050053 cl::Device device;
54 {
55 std::vector<cl::Device> devices;
56 cl_ok(platform.getDevices(CL_DEVICE_TYPE_ALL, &devices));
Mike Kleinf9ae6702018-06-20 14:05:05 -040057
Mike Klein8a1f15d2019-02-11 11:59:41 -050058 bool found = false;
59 for (cl::Device d : devices) {
60 std::string name,
61 version,
62 driver;
63 cl_ok(d.getInfo(CL_DEVICE_NAME, &name));
64 cl_ok(d.getInfo(CL_DEVICE_VERSION, &version));
65 cl_ok(d.getInfo(CL_DRIVER_VERSION, &driver));
Mike Kleinf9ae6702018-06-20 14:05:05 -040066
Mike Klein8a1f15d2019-02-11 11:59:41 -050067 fprintf(stdout, "Available device %s%s, driver version %s\n"
68 , version.c_str(), name.c_str(), driver.c_str());
69
70 if (name.find(device_match) != std::string::npos) {
71 device = d;
72 found = true;
73 }
74 }
75 if (!found) {
76 fprintf(stderr, "No device containing '%s' found.\n", device_match);
77 exit(2);
78 }
79 }
Mike Kleinf9ae6702018-06-20 14:05:05 -040080
81 std::string name,
82 vendor,
83 extensions;
84 cl_ok(device.getInfo(CL_DEVICE_NAME, &name));
85 cl_ok(device.getInfo(CL_DEVICE_VENDOR, &vendor));
86 cl_ok(device.getInfo(CL_DEVICE_EXTENSIONS, &extensions));
87
Mike Klein8a1f15d2019-02-11 11:59:41 -050088 fprintf(stdout, "Using %s, vendor %s, extensions:\n%s\n",
89 name.c_str(), vendor.c_str(), extensions.c_str());
Mike Kleinf9ae6702018-06-20 14:05:05 -040090
91 std::vector<cl::Device> devices = { device };
92
93 // Some APIs can't return their cl_int error but might still fail,
94 // so they take a pointer. cl_ok() is really handy here too.
95 cl_int ok;
96 cl::Context ctx(devices,
97 nullptr/*optional cl_context_properties*/,
98 nullptr/*optional error reporting callback*/,
99 nullptr/*context arguement for error reporting callback*/,
100 &ok);
101 cl_ok(ok);
102
103 cl::Program program(ctx,
104 "__kernel void mul(__global const float* a, "
105 " __global const float* b, "
106 " __global float* dst) {"
107 " int i = get_global_id(0); "
108 " dst[i] = a[i] * b[i]; "
109 "} ",
110 /*and build now*/true,
111 &ok);
112 cl_ok(ok);
113
114 std::vector<float> a,b,p;
115 for (int i = 0; i < 1000; i++) {
116 a.push_back(+i);
117 b.push_back(-i);
118 p.push_back( 0);
119 }
120
121 cl::Buffer A(ctx, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR , sizeof(float)*a.size(), a.data()),
122 B(ctx, CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR , sizeof(float)*b.size(), b.data()),
123 P(ctx, CL_MEM_WRITE_ONLY| CL_MEM_HOST_READ_ONLY, sizeof(float)*p.size());
124
125 cl::Kernel mul(program, "mul", &ok);
126 cl_ok(ok);
127 cl_ok(mul.setArg(0, A));
128 cl_ok(mul.setArg(1, B));
129 cl_ok(mul.setArg(2, P));
130
131 cl::CommandQueue queue(ctx, device);
132
133 cl_ok(queue.enqueueNDRangeKernel(mul, cl::NDRange(0) /*offset*/
134 , cl::NDRange(1000) /*size*/));
135
136 cl_ok(queue.enqueueReadBuffer(P, true/*block until read is done*/
137 , 0 /*offset in bytes*/
138 , sizeof(float)*p.size() /*size in bytes*/
139 , p.data()));
140
141 for (int i = 0; i < 1000; i++) {
142 if (p[i] != a[i]*b[i]) {
143 return 1;
144 }
145 }
146
Mike Klein8a1f15d2019-02-11 11:59:41 -0500147 fprintf(stdout, "OpenCL sez: %g x %g = %g\n", a[42], b[42], p[42]);
Mike Kleinf9ae6702018-06-20 14:05:05 -0400148 return 0;
149}