Support for local RS zero initialization.

Change-Id: I785cfc6ee53abb6c88ab5bdba5e7c8c16b8409de
diff --git a/slang_rs_backend.cpp b/slang_rs_backend.cpp
index 9d8684a..a46dced 100644
--- a/slang_rs_backend.cpp
+++ b/slang_rs_backend.cpp
@@ -16,6 +16,7 @@
 
 #include "slang_rs_backend.h"
 
+#include <stack>
 #include <vector>
 #include <string>
 
@@ -26,14 +27,16 @@
 #include "llvm/Function.h"
 #include "llvm/DerivedTypes.h"
 
-#include "llvm/System/Path.h"
-
 #include "llvm/Support/IRBuilder.h"
 
 #include "llvm/ADT/Twine.h"
 #include "llvm/ADT/StringExtras.h"
 
 #include "clang/AST/DeclGroup.h"
+#include "clang/AST/Expr.h"
+#include "clang/AST/OperationKinds.h"
+#include "clang/AST/Stmt.h"
+#include "clang/AST/StmtVisitor.h"
 
 #include "slang_rs.h"
 #include "slang_rs_context.h"
@@ -77,25 +80,249 @@
       clang::FunctionDecl *FD = dyn_cast<clang::FunctionDecl>(*I);
       if (FD == NULL)
         continue;
-      if (FD->getName().startswith("rs")) {  // Check prefix
-        clang::FullSourceLoc FSL(FD->getLocStart(), mSourceMgr);
-        clang::PresumedLoc PLoc = mSourceMgr.getPresumedLoc(FSL);
-        llvm::sys::Path HeaderFilename(PLoc.getFilename());
-
-        // Skip if that function declared in the RS default header.
-        if (SlangRS::IsRSHeaderFile(HeaderFilename.getLast().data()))
-          continue;
-        mDiags.Report(FSL, mDiags.getCustomDiagID(clang::Diagnostic::Error,
-                      "invalid function name prefix, \"rs\" is reserved: '%0'"))
+      if (!FD->getName().startswith("rs"))  // Check prefix
+        continue;
+      if (!SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
+        mDiags.Report(clang::FullSourceLoc(FD->getLocation(), mSourceMgr),
+                      mDiags.getCustomDiagID(clang::Diagnostic::Error,
+                                             "invalid function name prefix, "
+                                             "\"rs\" is reserved: '%0'"))
             << FD->getName();
-      }
     }
   }
 
   Backend::HandleTopLevelDecl(D);
   return;
 }
+///////////////////////////////////////////////////////////////////////////////
 
+namespace {
+
+  // Statement visitor that gives local RS object variables a zero initializer
+  // and records them per compound statement (reference counting is still TODO).
+  class RSObjectRefCounting : public clang::StmtVisitor<RSObjectRefCounting> {
+   private:
+    class Scope {
+     private:
+      clang::CompoundStmt *mCS;        // Compound statement ({ ... }) of scope
+      std::vector<clang::Decl*> mRSO;  // Declared RS object in this scope
+
+     public:
+      explicit Scope(clang::CompoundStmt *CS) : mCS(CS) {}
+
+      inline void addRSObject(clang::Decl* D) { mRSO.push_back(D); }
+    };
+    std::stack<Scope*> mScopeStack;
+
+    inline Scope *getCurrentScope() { return mScopeStack.top(); }
+
+    // Return false if the process was terminated early. I.e., Type of variable
+    // in VD is not an RS object type.
+    static bool InitializeRSObject(clang::VarDecl *VD);
+    static clang::Expr *CreateZeroInitializerForRSObject(
+        RSExportPrimitiveType::DataType DT,
+        clang::ASTContext &C,
+        const clang::SourceLocation &Loc);
+
+   public:
+    void VisitChildren(clang::Stmt *S) {
+      for (clang::Stmt::child_iterator I = S->child_begin(), E = S->child_end();
+           I != E;
+           ++I)
+        if (clang::Stmt *Child = *I)
+          Visit(Child);
+    }
+    void VisitStmt(clang::Stmt *S) { VisitChildren(S); }
+
+    void VisitDeclStmt(clang::DeclStmt *DS);
+    void VisitCompoundStmt(clang::CompoundStmt *CS);
+    void VisitBinAssign(clang::BinaryOperator *AS);
+
+    // We believe that RS objects never are involved in CompoundAssignOperator.
+    // I.e., rs_allocation foo; foo += bar;
+  };
+}
+
+// Gives VD a zero initializer when it declares an RS object without one.
+// Returns false when the type of VD is not an RS object type.
+bool RSObjectRefCounting::InitializeRSObject(clang::VarDecl *VD) {
+  const clang::Type *VarTy = RSExportType::GetTypeOfDecl(VD);
+  const RSExportPrimitiveType::DataType Kind =
+      RSExportPrimitiveType::GetRSObjectType(VarTy);
+  if (Kind == RSExportPrimitiveType::DataTypeUnknown)
+    return false;
+
+  // TODO: Update the reference count of RS object in initializer.
+  // This can potentially be done as part of the assignment pass.
+  if (VD->hasInit())
+    return true;
+
+  clang::ASTContext &Ctx = VD->getASTContext();
+  clang::Expr *Zero =
+      CreateZeroInitializerForRSObject(Kind, Ctx, VD->getLocation());
+  if (Zero == NULL)
+    return true;  // No initializer could be built; still an RS object.
+
+  // Tag the expression with the canonical type of the variable, then attach.
+  Zero->setType(VarTy->getCanonicalTypeInternal());
+  VD->setInit(Zero);
+  return true;
+}
+
+// Builds the initializer list expression that zero-initializes an RS object
+// of kind DT, placed at Loc. Yields NULL for a kind that is not an RS object
+// when the assert below is compiled out.
+clang::Expr *RSObjectRefCounting::CreateZeroInitializerForRSObject(
+    RSExportPrimitiveType::DataType DT,
+    clang::ASTContext &C,
+    const clang::SourceLocation &Loc) {
+  clang::Expr *Res = NULL;
+  switch (DT) {
+    case RSExportPrimitiveType::DataTypeRSElement:
+    case RSExportPrimitiveType::DataTypeRSType:
+    case RSExportPrimitiveType::DataTypeRSAllocation:
+    case RSExportPrimitiveType::DataTypeRSSampler:
+    case RSExportPrimitiveType::DataTypeRSScript:
+    case RSExportPrimitiveType::DataTypeRSMesh:
+    case RSExportPrimitiveType::DataTypeRSProgramFragment:
+    case RSExportPrimitiveType::DataTypeRSProgramVertex:
+    case RSExportPrimitiveType::DataTypeRSProgramRaster:
+    case RSExportPrimitiveType::DataTypeRSProgramStore:
+    case RSExportPrimitiveType::DataTypeRSFont: {
+      //    (ImplicitCastExpr 'nullptr_t'
+      //      (IntegerLiteral 0)))
+      llvm::APInt Zero(C.getTypeSize(C.IntTy), 0);
+      clang::Expr *Int0 = clang::IntegerLiteral::Create(C, Zero, C.IntTy, Loc);
+      clang::Expr *CastToNull =
+          clang::ImplicitCastExpr::Create(C,
+                                          C.NullPtrTy,
+                                          clang::CK_IntegralToPointer,
+                                          Int0,
+                                          NULL,
+                                          clang::VK_RValue);
+
+      Res = new (C) clang::InitListExpr(C, Loc, &CastToNull, 1, Loc);
+      break;
+    }
+    case RSExportPrimitiveType::DataTypeRSMatrix2x2:
+    case RSExportPrimitiveType::DataTypeRSMatrix3x3:
+    case RSExportPrimitiveType::DataTypeRSMatrix4x4: {
+      // RS matrix is not completely an RS object. They hold data by themselves.
+      // (InitListExpr rs_matrix2x2
+      //   (InitListExpr float[4]
+      //     (FloatingLiteral 0)
+      //     (FloatingLiteral 0)
+      //     (FloatingLiteral 0)
+      //     (FloatingLiteral 0)))
+      clang::QualType FloatTy = C.FloatTy;
+      // Constructor sets value to 0.0f by default
+      llvm::APFloat Val(C.getFloatTypeSemantics(FloatTy));
+      clang::FloatingLiteral *Float0Val =
+          clang::FloatingLiteral::Create(C,
+                                         Val,
+                                         /* isExact = */true,
+                                         FloatTy,
+                                         Loc);
+
+      const unsigned N =
+          (DT == RSExportPrimitiveType::DataTypeRSMatrix2x2) ? 2 :
+          (DT == RSExportPrimitiveType::DataTypeRSMatrix3x3) ? 3 : 4;
+
+      // Directly allocate 16 elements instead of dynamically allocating N*N
+      clang::Expr *InitVals[16];
+      for (unsigned i = 0; i < sizeof(InitVals) / sizeof(InitVals[0]); i++)
+        InitVals[i] = Float0Val;
+      clang::Expr *InitExpr =
+          new (C) clang::InitListExpr(C, Loc, InitVals, N * N, Loc);
+      // The array type must hold N * N floats, matching the list built above.
+      InitExpr->setType(C.getConstantArrayType(FloatTy,
+                                               llvm::APInt(32, N * N),
+                                               clang::ArrayType::Normal,
+                                               /* EltTypeQuals = */0));
+
+      Res = new (C) clang::InitListExpr(C, Loc, &InitExpr, 1, Loc);
+      break;
+    }
+    case RSExportPrimitiveType::DataTypeUnknown:
+    case RSExportPrimitiveType::DataTypeFloat16:
+    case RSExportPrimitiveType::DataTypeFloat32:
+    case RSExportPrimitiveType::DataTypeFloat64:
+    case RSExportPrimitiveType::DataTypeSigned8:
+    case RSExportPrimitiveType::DataTypeSigned16:
+    case RSExportPrimitiveType::DataTypeSigned32:
+    case RSExportPrimitiveType::DataTypeSigned64:
+    case RSExportPrimitiveType::DataTypeUnsigned8:
+    case RSExportPrimitiveType::DataTypeUnsigned16:
+    case RSExportPrimitiveType::DataTypeUnsigned32:
+    case RSExportPrimitiveType::DataTypeUnsigned64:
+    case RSExportPrimitiveType::DataTypeBoolean:
+    case RSExportPrimitiveType::DataTypeUnsigned565:
+    case RSExportPrimitiveType::DataTypeUnsigned5551:
+    case RSExportPrimitiveType::DataTypeUnsigned4444:
+    case RSExportPrimitiveType::DataTypeMax: {
+      assert(false && "Not RS object type!");
+    }
+    // No default case will enable compiler detecting the missing cases
+  }
+
+  return Res;
+}
+
+// Zero-initializes RS object variables declared by DS and tracks them.
+void RSObjectRefCounting::VisitDeclStmt(clang::DeclStmt *DS) {
+  clang::DeclStmt::decl_iterator It = DS->decl_begin();
+  clang::DeclStmt::decl_iterator End = DS->decl_end();
+  for (; It != End; ++It) {
+    // Only declarations whose exact kind is Var are of interest here.
+    if ((*It)->getKind() != clang::Decl::Var)
+      continue;
+    clang::VarDecl *Var = static_cast<clang::VarDecl*>(*It);
+    if (InitializeRSObject(Var))
+      getCurrentScope()->addRSObject(Var);
+  }
+}
+
+// Visits the statements of a non-empty CS inside a scope of their own.
+void RSObjectRefCounting::VisitCompoundStmt(clang::CompoundStmt *CS) {
+  if (CS->body_empty())
+    return;
+
+  // The scope lives on the stack, so it needs no manual delete. The stack of
+  // scopes only borrows its address for the duration of the visit.
+  Scope S(CS);
+  mScopeStack.push(&S);
+
+  VisitChildren(CS);
+
+  // TODO: Update reference count of the RS object referenced by the
+  //       getCurrentScope().
+  assert((getCurrentScope() == &S) && "Corrupted scope stack!");
+  mScopeStack.pop();
+}
+
+// Currently a no-op for assignments; the operands are not traversed either.
+// TODO: Update reference count
+void RSObjectRefCounting::VisitBinAssign(clang::BinaryOperator *AS) {
+}
+
+// Runs the RS object visitor over the body of each top-level function that
+// has one and that SlangRS::IsFunctionInRSHeaderFile does not claim.
+void RSBackend::HandleTranslationUnitPre(clang::ASTContext& C) {
+  RSObjectRefCounting RSObjectRefCounter;
+  clang::TranslationUnitDecl *TU = C.getTranslationUnitDecl();
+
+  for (clang::DeclContext::decl_iterator It = TU->decls_begin(),
+          End = TU->decls_end(); It != End; ++It) {
+    const clang::Decl::Kind K = It->getKind();
+    if ((K < clang::Decl::firstFunction) || (K > clang::Decl::lastFunction))
+      continue;  // Kind lies outside the function declaration range.
+    clang::FunctionDecl *FD = static_cast<clang::FunctionDecl*>(*It);
+    if (FD->hasBody() && !SlangRS::IsFunctionInRSHeaderFile(FD, mSourceMgr))
+      RSObjectRefCounter.Visit(FD->getBody());
+  }
+}
+
+///////////////////////////////////////////////////////////////////////////////
 void RSBackend::HandleTranslationUnitPost(llvm::Module *M) {
   mContext->processExport();