[OPENMP]Add call to __kmpc_push_target_tripcount() function.

Each time we create a target region with an inner teams distribute
region, we can better estimate the number of teams required to execute
the target region. The function __kmpc_push_target_tripcount() is used for
this purpose; it accepts the device_id and the number of iterations
performed by the associated loop.

llvm-svn: 350571
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 8176093..20eb0b2 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -673,6 +673,9 @@
   //
   // Offloading related calls
   //
+  // Call to void __kmpc_push_target_tripcount(int64_t device_id, kmp_uint64
+  // size);
+  OMPRTL__kmpc_push_target_tripcount,
   // Call to int32_t __tgt_target(int64_t device_id, void *host_ptr, int32_t
   // arg_num, void** args_base, void **args, size_t *arg_sizes, int64_t
   // *arg_types);
@@ -2163,6 +2166,15 @@
         FnTy, /*Name=*/"__kmpc_task_reduction_get_th_data");
     break;
   }
+  case OMPRTL__kmpc_push_target_tripcount: {
+    // Build void __kmpc_push_target_tripcount(int64_t device_id, kmp_uint64
+    // size);
+    llvm::Type *TypeParams[] = {CGM.Int64Ty, CGM.Int64Ty};
+    llvm::FunctionType *FnTy =
+        llvm::FunctionType::get(CGM.VoidTy, TypeParams, /*isVarArg=*/false);
+    RTLFn = CGM.CreateRuntimeFunction(FnTy, "__kmpc_push_target_tripcount");
+    break;
+  }
   case OMPRTL__tgt_target: {
     // Build int32_t __tgt_target(int64_t device_id, void *host_ptr, int32_t
     // arg_num, void** args_base, void **args, size_t *arg_sizes, int64_t
@@ -8053,6 +8065,183 @@
   }
 }
 
+/// Returns true if \p E is evaluatable (a constant) or has no non-trivial
+/// function calls, and in either case has no (possible) side effects.
+static bool isTrivial(ASTContext &Ctx, const Expr * E) {
+  // Constant expressions and expressions with only trivial calls can be
+  // skipped, but only when they cannot have side effects.
+  return (E->isEvaluatable(Ctx, Expr::SE_AllowUndefinedBehavior) ||
+          !E->hasNonTrivialCall(Ctx)) &&
+         !E->HasSideEffects(Ctx, /*IncludePossibleEffects=*/true);
+}
+
+/// If \p Body is a \a CompoundStmt with a single non-trivial child, returns
+/// that child; in every other case returns \p Body itself.
+static const Stmt *getSingleCompoundChild(ASTContext &Ctx, const Stmt *Body) {
+  if (const auto *C = dyn_cast<CompoundStmt>(Body)) {
+    const Stmt *Child = nullptr;
+    for (const Stmt *S : C->body()) {
+      if (const auto *E = dyn_cast<Expr>(S)) {
+        if (isTrivial(Ctx, E))
+          continue;
+      }
+      // These statement kinds do not count as children and are skipped.
+      if (isa<AsmStmt>(S) || isa<NullStmt>(S) || isa<OMPFlushDirective>(S) ||
+          isa<OMPBarrierDirective>(S) || isa<OMPTaskyieldDirective>(S))
+        continue;
+      // A DeclStmt is skipped when every declaration in it passes the check.
+      if (const auto *DS = dyn_cast<DeclStmt>(S)) {
+        if (llvm::all_of(DS->decls(), [&Ctx](const Decl *D) {
+              if (isa<EmptyDecl>(D) || isa<DeclContext>(D) ||
+                  isa<TypeDecl>(D) || isa<PragmaCommentDecl>(D) ||
+                  isa<PragmaDetectMismatchDecl>(D) || isa<UsingDecl>(D) ||
+                  isa<UsingDirectiveDecl>(D) ||
+                  isa<OMPDeclareReductionDecl>(D) ||
+                  isa<OMPThreadPrivateDecl>(D))
+                return true;
+              const auto *VD = dyn_cast<VarDecl>(D);
+              if (!VD) // Any other non-variable declaration is not skippable.
+                return false;
+              return VD->isConstexpr() ||
+                     ((VD->getType().isTrivialType(Ctx) ||
+                       VD->getType()->isReferenceType()) &&
+                      (!VD->hasInit() || isTrivial(Ctx, VD->getInit())));
+            }))
+          continue;
+      }
+      // A second non-skipped child: there is no single child to return.
+      if (Child)
+        return Body;
+      Child = S;
+    }
+    if (Child)
+      return Child;
+  }
+  return Body;
+}
+
+/// Returns the distribute directive nested in target region \p D, or nullptr.
+static const OMPExecutableDirective *
+getNestedDistributeDirective(ASTContext &Ctx, const OMPExecutableDirective &D) {
+  const auto *CS = D.getInnermostCapturedStmt();
+  const auto *Body =
+      CS->getCapturedStmt()->IgnoreContainers(/*IgnoreCaptured=*/true);
+  const Stmt *ChildStmt = getSingleCompoundChild(Ctx, Body);
+
+  if (const auto *NestedDir = dyn_cast<OMPExecutableDirective>(ChildStmt)) {
+    OpenMPDirectiveKind DKind = NestedDir->getDirectiveKind();
+    switch (D.getDirectiveKind()) {
+    case OMPD_target:
+      if (isOpenMPDistributeDirective(DKind))
+        return NestedDir;
+      if (DKind == OMPD_teams) { // Look one level deeper, inside 'teams'.
+        Body = NestedDir->getInnermostCapturedStmt()->IgnoreContainers(
+            /*IgnoreCaptured=*/true);
+        if (!Body)
+          return nullptr;
+        ChildStmt = getSingleCompoundChild(Ctx, Body);
+        if (const auto *NND = dyn_cast<OMPExecutableDirective>(ChildStmt)) {
+          DKind = NND->getDirectiveKind();
+          if (isOpenMPDistributeDirective(DKind))
+            return NND;
+        }
+      }
+      return nullptr;
+    case OMPD_target_teams:
+      if (isOpenMPDistributeDirective(DKind))
+        return NestedDir;
+      return nullptr;
+    case OMPD_target_parallel:
+    case OMPD_target_simd:
+    case OMPD_target_parallel_for:
+    case OMPD_target_parallel_for_simd:
+      return nullptr; // No nested distribute is searched for in these kinds.
+    case OMPD_target_teams_distribute:
+    case OMPD_target_teams_distribute_simd:
+    case OMPD_target_teams_distribute_parallel_for:
+    case OMPD_target_teams_distribute_parallel_for_simd:
+    case OMPD_parallel:
+    case OMPD_for:
+    case OMPD_parallel_for:
+    case OMPD_parallel_sections:
+    case OMPD_for_simd:
+    case OMPD_parallel_for_simd:
+    case OMPD_cancel:
+    case OMPD_cancellation_point:
+    case OMPD_ordered:
+    case OMPD_threadprivate:
+    case OMPD_task:
+    case OMPD_simd:
+    case OMPD_sections:
+    case OMPD_section:
+    case OMPD_single:
+    case OMPD_master:
+    case OMPD_critical:
+    case OMPD_taskyield:
+    case OMPD_barrier:
+    case OMPD_taskwait:
+    case OMPD_taskgroup:
+    case OMPD_atomic:
+    case OMPD_flush:
+    case OMPD_teams:
+    case OMPD_target_data:
+    case OMPD_target_exit_data:
+    case OMPD_target_enter_data:
+    case OMPD_distribute:
+    case OMPD_distribute_simd:
+    case OMPD_distribute_parallel_for:
+    case OMPD_distribute_parallel_for_simd:
+    case OMPD_teams_distribute:
+    case OMPD_teams_distribute_simd:
+    case OMPD_teams_distribute_parallel_for:
+    case OMPD_teams_distribute_parallel_for_simd:
+    case OMPD_target_update:
+    case OMPD_declare_simd:
+    case OMPD_declare_target:
+    case OMPD_end_declare_target:
+    case OMPD_declare_reduction:
+    case OMPD_taskloop:
+    case OMPD_taskloop_simd:
+    case OMPD_requires:
+    case OMPD_unknown:
+      llvm_unreachable("Unexpected directive.");
+    }
+  }
+
+  return nullptr; // The single child is not an OpenMP executable directive.
+}
+
+void CGOpenMPRuntime::emitTargetNumIterationsCall(
+    CodeGenFunction &CGF, const OMPExecutableDirective &D, const Expr *Device,
+    const llvm::function_ref<llvm::Value *(
+        CodeGenFunction &CGF, const OMPLoopDirective &D)> &SizeEmitter) {
+  OpenMPDirectiveKind Kind = D.getDirectiveKind();
+  const OMPExecutableDirective *TD = &D;
+  // Unless D itself is a teams distribute kind, look for a nested distribute.
+  if (!isOpenMPDistributeDirective(Kind) || !isOpenMPTeamsDirective(Kind))
+    TD = getNestedDistributeDirective(CGM.getContext(), D);
+  if (!TD) // No distribute directive found: nothing to push.
+    return;
+  const auto *LD = cast<OMPLoopDirective>(TD);
+  auto &&CodeGen = [LD, &Device, &SizeEmitter, this](CodeGenFunction &CGF,
+                                                     PrePostActionTy &) {
+    llvm::Value *NumIterations = SizeEmitter(CGF, *LD);
+
+    // Device ID as i64: \p Device if given, otherwise OMP_DEVICEID_UNDEF.
+    llvm::Value *DeviceID;
+    if (Device)
+      DeviceID = CGF.Builder.CreateIntCast(CGF.EmitScalarExpr(Device),
+                                           CGF.Int64Ty, /*isSigned=*/true);
+    else
+      DeviceID = CGF.Builder.getInt64(OMP_DEVICEID_UNDEF);
+
+    llvm::Value *Args[] = {DeviceID, NumIterations};
+    CGF.EmitRuntimeCall(
+        createRuntimeFunction(OMPRTL__kmpc_push_target_tripcount), Args);
+  };
+  emitInlinedDirective(CGF, OMPD_unknown, CodeGen);
+}
+
 void CGOpenMPRuntime::emitTargetCall(CodeGenFunction &CGF,
                                      const OMPExecutableDirective &D,
                                      llvm::Value *OutlinedFn,
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h
index d9ac5df..933c8af 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -1367,6 +1367,15 @@
                                           bool IsOffloadEntry,
                                           const RegionCodeGenTy &CodeGen);
 
+  /// Emit code that pushes the loop trip count to the runtime when target
+  /// directive \a D is, or has nested in it, a teams-distribute loop directive.
+  /// \param SizeEmitter Emits the int64 value for the number of iterations of
+  /// the associated loop.
+  virtual void emitTargetNumIterationsCall(
+      CodeGenFunction &CGF, const OMPExecutableDirective &D, const Expr *Device,
+      const llvm::function_ref<llvm::Value *(
+          CodeGenFunction &CGF, const OMPLoopDirective &D)> &SizeEmitter);
+
   /// Emit the target offloading code associated with \a D. The emitted
   /// code attempts offloading the execution to the device, an the event of
   /// a failure it executes the host version outlined in \a OutlinedFn.
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
index 403aefb..7046ab3 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeNVPTX.cpp
@@ -705,8 +705,8 @@
                                           : CGOpenMPRuntimeNVPTX::Generic;
 }
 
-// Checks if the expression is constant or does not have non-trivial function
-// calls.
+/// Checks if the expression is constant or does not have non-trivial function
+/// calls.
 static bool isTrivial(ASTContext &Ctx, const Expr * E) {
   // We can skip constant expressions.
   // We can skip expressions with trivial calls or simple expressions.
diff --git a/clang/lib/CodeGen/CGStmtOpenMP.cpp b/clang/lib/CodeGen/CGStmtOpenMP.cpp
index ee38593..650f4e8 100644
--- a/clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ b/clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -4071,6 +4071,16 @@
   CGM.getOpenMPRuntime().emitTargetOutlinedFunction(S, ParentName, Fn, FnID,
                                                     IsOffloadEntry, CodeGen);
   OMPLexicalScope Scope(CGF, S, OMPD_task);
+  auto &&SizeEmitter = [](CodeGenFunction &CGF, const OMPLoopDirective &D) {
+    OMPLoopScope(CGF, D); // NOTE(review): unnamed temporary; scope ends here?
+    // Emit the iteration count, then cast it (as unsigned) to i64.
+    llvm::Value *NumIterations = CGF.EmitScalarExpr(D.getNumIterations());
+    NumIterations = CGF.Builder.CreateIntCast(NumIterations, CGF.Int64Ty,
+                                              /*IsSigned=*/false);
+    return NumIterations;
+  };
+  CGM.getOpenMPRuntime().emitTargetNumIterationsCall(CGF, S, Device,
+                                                     SizeEmitter);
   CGM.getOpenMPRuntime().emitTargetCall(CGF, S, Fn, FnID, IfCond, Device);
 }