Implement multi-thread CPU GEMM for BLAS Intrinsics
- Multi-thread GEMM utilizes the existing RS thread pool on top of
Eigen.
- Large matrix-matrix multiplication is decomposed into multiple
tiled matrix-matrix multiplications. Each thread iterates over
the unfinished work items.
- The tiling applies to ONLY ONE dimension of each input matrix,
and whether to tile X or Y depends on the transpose of the matrix.
- The performance increase is proportional to the number of
available CPU cores, for sufficiently large matrices.
Test: CTS test (rsblas) passes on Angler, Fugu and new devices.
Performance test with RsBlasBenchmark and RsNeuralNet demo
on Angler, Ryu, Seed, Shamu, Volantis, Fugu and new devices,
showing roughly 70% (Volantis, 2 cores) ~ 400+% (Angler, 8 cores) perf gain.
Change-Id: If96f4119fd34d5d9d98a2542801495e7ffe577ae
(cherry picked from commit 41ab8faaf0d90238d42d8e2bbb7177467c10b4f6)
diff --git a/cpu_ref/rsCpuIntrinsicBLAS.cpp b/cpu_ref/rsCpuIntrinsicBLAS.cpp
index 2a5ab21..4b08634 100644
--- a/cpu_ref/rsCpuIntrinsicBLAS.cpp
+++ b/cpu_ref/rsCpuIntrinsicBLAS.cpp
@@ -33,7 +33,6 @@
const void * usr,
uint32_t usrLen,
const RsScriptCall *sc) override;
-
void populateScript(Script *) override;
~RsdCpuScriptIntrinsicBLAS() override;
RsdCpuScriptIntrinsicBLAS(RsdCpuReferenceImpl *ctx, const Script *s);
@@ -82,10 +81,158 @@
*C = ain[2]->mHal.drvState.lod[0].mallocPtr;
*ldc = (int)(ain[2]->mHal.drvState.lod[0].stride/size);
}
-
-
}
+// Routine to set up the LaunchStruct for the GEMM callback.
+static void setupGEMM(MTLaunchStructForEachBlas *mtls, const Allocation **ain, RsBlasCall* call,
+ RsdCpuReferenceImpl *ctx) {
+ uint32_t mm, nn, kk;
+ mm = call->M;
+ nn = call->N;
+ kk = call->K;
+
+ memset(mtls, 0, sizeof(MTLaunchStructForEachBlas));
+ mtls->rs = ctx;
+ mtls->sc = call;
+ mtls->dimPtr = &mtls->fep.dim;
+ mtls->fep.dim.x = nn;
+ mtls->fep.dim.y = mm;
+ mtls->fep.dim.z = kk;
+ if (ain) {
+ memcpy(mtls->ains, ain, 3 * sizeof(ain[0]));
+ }
+ uint32_t elementBytes = 4;
+ if (ain[0]) {
+ elementBytes = ain[0]->getType()->getElement()->getSizeBytes();
+ }
+ const uint32_t MIN_SIZE_TO_TILE = 64 * 1024 / elementBytes;
+ const uint32_t MAX_WORK_PER_THREAD = 512 / elementBytes;
+ const uint32_t THREAD_COUNT = ctx->getThreadCount();
+ uint32_t tileSizeN = 0;
+ uint32_t tileSizeM = 0;
+
+ // Do not tile the matrix if:
+ // 1. It is too small compared to the other matrix.
+ // 2. It is too small compared to MIN_SIZE_TO_TILE.
+ if (nn * kk > MIN_SIZE_TO_TILE && nn * THREAD_COUNT > mm) {
+ tileSizeN = rsMin(nn / THREAD_COUNT, MAX_WORK_PER_THREAD);
+ }
+ if (mm * kk > MIN_SIZE_TO_TILE && mm * THREAD_COUNT > nn) {
+ tileSizeM = rsMin(mm / THREAD_COUNT, MAX_WORK_PER_THREAD);
+ }
+ mtls->numTileM = 1;
+ mtls->numTileN = 1;
+ mtls->tileSizeM = mm;
+ mtls->tileSizeN = nn;
+
+ // If tiling is needed, compute the number of slices for A & B.
+ mtls->isThreadable = (tileSizeM > 0 || tileSizeN > 0);
+ if (tileSizeM) {
+ mtls->numTileM += (mm - 1) / tileSizeM;
+ mtls->tileSizeM = tileSizeM;
+ }
+ if (tileSizeN) {
+ mtls->numTileN += (nn - 1) / tileSizeN;
+ mtls->tileSizeN = tileSizeN;
+ }
+
+ mtls->mSliceNum = 0;
+}
+
+// Generic GEMM callback routine.
+template <typename T_data, typename T_param, typename Func>
+static void walk_tiled_gemm(Func blasFunc, T_param alpha, T_param beta, int vecSize,
+ RsBlasCall* call, const MTLaunchStructForEachBlas *mtls) {
+ // Set up BLAS enum args.
+ enum CBLAS_TRANSPOSE TransA = (enum CBLAS_TRANSPOSE)call->transA;
+ enum CBLAS_TRANSPOSE TransB = (enum CBLAS_TRANSPOSE)call->transB;
+
+ void *A = nullptr;
+ void *B = nullptr;
+ void *C = nullptr;
+
+ int lda = 0, ldb = 0, ldc = 0;
+
+ const Allocation *ain[RS_KERNEL_INPUT_LIMIT];
+ ain[0] = mtls->ains[0];
+ ain[1] = mtls->ains[1];
+ ain[2] = mtls->ains[2];
+
+ initABC(ain, sizeof(T_data) * vecSize, &A, &B, &C, &lda, &ldb, &ldc);
+
+ // Determine the stride of the tiled matrices.
+ int mStride = (TransA == CblasNoTrans) ? lda : 1;
+ int nStride = (TransB == CblasNoTrans) ? 1 : ldb;
+ while (1) {
+ uint32_t slice = (uint32_t)__sync_fetch_and_add(&mtls->mSliceNum, 1);
+
+ uint32_t mStart = (slice % mtls->numTileM) * mtls->tileSizeM;
+ uint32_t mEnd = mStart + mtls->tileSizeM;
+ mEnd = rsMin(mEnd, (uint32_t)call->M);
+ if (mEnd <= mStart) {
+ return;
+ }
+
+ uint32_t nStart = (slice / mtls->numTileM) * mtls->tileSizeN;
+ uint32_t nEnd = nStart + mtls->tileSizeN;
+ nEnd = rsMin(nEnd, (uint32_t)call->N);
+ if (nEnd <= nStart) {
+ return;
+ }
+
+ blasFunc(CblasRowMajor, TransA, TransB,
+ mEnd - mStart, nEnd - nStart, call->K, alpha,
+ (T_data *)A + mStart * mStride * vecSize, lda,
+ (T_data *)B + nStart * nStride * vecSize, ldb, beta,
+ (T_data *)C + (mStart * ldc + nStart) * vecSize, ldc);
+ }
+}
+
+// SGEMM callback
+static void walk_2d_sgemm(void *usr, uint32_t idx) {
+ const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr;
+ RsBlasCall* call = (RsBlasCall*) mtls->sc;
+
+ float alpha = call->alpha.f;
+ float beta = call->beta.f;
+
+ walk_tiled_gemm<float, float, FnPtr_cblas_sgemm>(cblas_sgemm, alpha, beta, 1, call, mtls);
+}
+
+// DGEMM callback
+static void walk_2d_dgemm(void *usr, uint32_t idx) {
+ const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr;
+ RsBlasCall* call = (RsBlasCall*) mtls->sc;
+
+ double alpha = call->alpha.d;
+ double beta = call->beta.d;
+
+ walk_tiled_gemm<double, double, FnPtr_cblas_dgemm>(cblas_dgemm, alpha, beta, 1, call, mtls);
+}
+
+// CGEMM callback
+static void walk_2d_cgemm(void *usr, uint32_t idx) {
+ const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr;
+ RsBlasCall* call = (RsBlasCall*) mtls->sc;
+
+ void * alpha = (void *)&call->alpha.c;
+ void * beta = (void *)&call->beta.c;
+
+ walk_tiled_gemm<float, void *, FnPtr_cblas_cgemm>(cblas_cgemm, alpha, beta, 2, call, mtls);
+}
+
+// ZGEMM callback
+static void walk_2d_zgemm(void *usr, uint32_t idx) {
+ const MTLaunchStructForEachBlas *mtls = (const MTLaunchStructForEachBlas *)usr;
+ RsBlasCall* call = (RsBlasCall*) mtls->sc;
+
+ void * alpha = (void *)&call->alpha.z;
+ void * beta = (void *)&call->beta.z;
+
+ walk_tiled_gemm<double, void *, FnPtr_cblas_zgemm>(cblas_zgemm, alpha, beta, 2, call, mtls);
+}
+
+
void RsdCpuScriptIntrinsicBLAS::invokeForEach(uint32_t slot,
const Allocation ** ain,
uint32_t inLen,
@@ -109,6 +256,8 @@
int lda = 0, ldb = 0, ldc = 0;
+ MTLaunchStructForEachBlas mtls;
+
#ifdef RS_COMPATIBILITY_LIB
// Allow BNNM even without libblas
if (call->func != RsBlas_bnnm && !isBlasLibInitialized) {
@@ -486,9 +635,14 @@
// Level 3 BLAS
case (RsBlas_sgemm):
- initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
- cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f,
- (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
+ setupGEMM(&mtls, ain, call, mCtx);
+ if (mtls.isThreadable) {
+ mCtx->launchThreads(walk_2d_sgemm, &mtls);
+ } else {
+ initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
+ cblas_sgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.f,
+ (float*)A, lda, (float*)B, ldb, call->beta.f, (float*)C, ldc);
+ }
break;
case (RsBlas_ssymm):
initABC(ain, sizeof(float), &A, &B, &C, &lda, &ldb, &ldc);
@@ -518,9 +672,14 @@
case (RsBlas_dgemm):
- initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
- cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d,
- (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
+ setupGEMM(&mtls, ain, call, mCtx);
+ if (mtls.isThreadable) {
+ mCtx->launchThreads(walk_2d_dgemm, &mtls);
+ } else {
+ initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
+ cblas_dgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, call->alpha.d,
+ (double*)A, lda, (double*)B, ldb, call->beta.d, (double*)C, ldc);
+ }
break;
case (RsBlas_dsymm):
initABC(ain, sizeof(double), &A, &B, &C, &lda, &ldb, &ldc);
@@ -549,9 +708,14 @@
break;
case (RsBlas_cgemm):
- initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
- cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c,
- A, lda, B, ldb, (void*)&call->beta.c, C, ldc);
+ setupGEMM(&mtls, ain, call, mCtx);
+ if (mtls.isThreadable) {
+ mCtx->launchThreads(walk_2d_cgemm, &mtls);
+ } else {
+ initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
+ cblas_cgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.c,
+ A, lda, B, ldb, (void*)&call->beta.c, C, ldc);
+ }
break;
case (RsBlas_csymm):
initABC(ain, sizeof(float)*2, &A, &B, &C, &lda, &ldb, &ldc);
@@ -580,9 +744,14 @@
break;
case (RsBlas_zgemm):
- initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
- cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z,
- A, lda, B, ldb, (void*)&call->beta.z, C, ldc);
+ setupGEMM(&mtls, ain, call, mCtx);
+ if (mtls.isThreadable) {
+ mCtx->launchThreads(walk_2d_zgemm, &mtls);
+ } else {
+ initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);
+ cblas_zgemm(CblasRowMajor, TransA, TransB, call->M, call->N, call->K, (void*)&call->alpha.z,
+ A, lda, B, ldb, (void*)&call->beta.z, C, ldc);
+ }
break;
case (RsBlas_zsymm):
initABC(ain, sizeof(double)*2, &A, &B, &C, &lda, &ldb, &ldc);