Adjust the launch config for gather ops. For the NCF model, this yields a ~20% gain, because the grid size increases from 80 to 160 on Volta.
PiperOrigin-RevId: 312373706
Change-Id: I2413d301ec170e6e90eeae025e4bb17fccd5abbb
diff --git a/tensorflow/core/kernels/gather_functor_gpu.cu.h b/tensorflow/core/kernels/gather_functor_gpu.cu.h
index 1cadee4..b2dd438 100644
--- a/tensorflow/core/kernels/gather_functor_gpu.cu.h
+++ b/tensorflow/core/kernels/gather_functor_gpu.cu.h
@@ -92,13 +92,18 @@
const int64 indices_size = indices.size();
const int64 slice_size = params.dimension(2);
- GpuLaunchConfig config = GetGpuLaunchConfig(out_size, d);
if (is_axis_zero) {
+ GpuLaunchConfig config = GetGpuLaunchConfig(
+ out_size, d, &GatherOpKernel<T, Index, true>,
+ /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
TF_CHECK_OK(GpuLaunchKernel(
GatherOpKernel<T, Index, true>, config.block_count,
config.thread_per_block, 0, d.stream(), params.data(), indices.data(),
out.data(), gather_dim_size, indices_size, slice_size, out_size));
} else {
+ GpuLaunchConfig config = GetGpuLaunchConfig(
+ out_size, d, &GatherOpKernel<T, Index, false>,
+ /*dynamic_shared_memory_size=*/0, /*block_size_limit=*/0);
TF_CHECK_OK(GpuLaunchKernel(
GatherOpKernel<T, Index, false>, config.block_count,
config.thread_per_block, 0, d.stream(), params.data(), indices.data(),