Refactor the logic for printf argument type checking into analyze_printf::ArgTypeResult.
Implement printf argument type checking for '%s'.

Fixes <rdar://problem/3065808>.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@96310 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Sema/SemaChecking.cpp b/lib/Sema/SemaChecking.cpp
index f9466ca..b62cd19 100644
--- a/lib/Sema/SemaChecking.cpp
+++ b/lib/Sema/SemaChecking.cpp
@@ -754,7 +754,7 @@
   if (!TheCall->getArg(0)->isIntegerConstantExpr(Result, Context))
     return Diag(TheCall->getLocStart(), diag::err_expr_not_ice)
       << TheCall->getArg(0)->getSourceRange();
-  
+
   return false;
 }
 
@@ -930,7 +930,7 @@
   for (NonNullAttr::iterator i = NonNull->begin(), e = NonNull->end();
        i != e; ++i) {
     const Expr *ArgExpr = TheCall->getArg(*i);
-    if (ArgExpr->isNullPointerConstant(Context, 
+    if (ArgExpr->isNullPointerConstant(Context,
                                        Expr::NPC_ValueDependentIsNotNull))
       Diag(TheCall->getCallee()->getLocStart(), diag::warn_null_arg)
         << ArgExpr->getSourceRange();
@@ -1049,7 +1049,7 @@
   const bool HasVAListArg;
   const CallExpr *TheCall;
   unsigned FormatIdx;
-public:  
+public:
   CheckPrintfHandler(Sema &s, const StringLiteral *fexpr,
                      const Expr *origFormatExpr,
                      unsigned numDataArgs, bool isObjCLiteral,
@@ -1060,19 +1060,19 @@
       IsObjCLiteral(isObjCLiteral), Beg(beg),
       HasVAListArg(hasVAListArg),
       TheCall(theCall), FormatIdx(formatIdx) {}
-  
+
   void DoneProcessing();
-     
+
   void HandleIncompleteFormatSpecifier(const char *startSpecifier,
                                        unsigned specifierLen);
-  
+
   void
   HandleInvalidConversionSpecifier(const analyze_printf::FormatSpecifier &FS,
                                    const char *startSpecifier,
                                    unsigned specifierLen);
-  
+
   void HandleNullChar(const char *nullCharacter);
-  
+
   bool HandleFormatSpecifier(const analyze_printf::FormatSpecifier &FS,
                              const char *startSpecifier,
                              unsigned specifierLen);
@@ -1081,16 +1081,14 @@
   SourceRange getFormatSpecifierRange(const char *startSpecifier,
                                       unsigned specifierLen);
   SourceLocation getLocationOfByte(const char *x);
-  
+
   bool HandleAmount(const analyze_printf::OptionalAmount &Amt,
                     unsigned MissingArgDiag, unsigned BadTypeDiag,
           const char *startSpecifier, unsigned specifierLen);
   void HandleFlags(const analyze_printf::FormatSpecifier &FS,
                    llvm::StringRef flag, llvm::StringRef cspec,
                    const char *startSpecifier, unsigned specifierLen);
-  
-  bool MatchType(QualType A, QualType B, bool ignoreSign);
-  
+
   const Expr *getDataArg(unsigned i) const;
 };
 }
@@ -1106,12 +1104,12 @@
 }
 
 SourceLocation CheckPrintfHandler::getLocationOfByte(const char *x) {
-  return S.getLocationOfStringLiteralByte(FExpr, x - Beg);  
+  return S.getLocationOfStringLiteralByte(FExpr, x - Beg);
 }
 
 void CheckPrintfHandler::
 HandleIncompleteFormatSpecifier(const char *startSpecifier,
-                                unsigned specifierLen) {  
+                                unsigned specifierLen) {
   SourceLocation Loc = getLocationOfByte(startSpecifier);
   S.Diag(Loc, diag::warn_printf_incomplete_specifier)
     << getFormatSpecifierRange(startSpecifier, specifierLen);
@@ -1121,14 +1119,14 @@
 HandleInvalidConversionSpecifier(const analyze_printf::FormatSpecifier &FS,
                                  const char *startSpecifier,
                                  unsigned specifierLen) {
-  
+
   ++NumConversions;
   const analyze_printf::ConversionSpecifier &CS =
-    FS.getConversionSpecifier();  
+    FS.getConversionSpecifier();
   SourceLocation Loc = getLocationOfByte(CS.getStart());
   S.Diag(Loc, diag::warn_printf_invalid_conversion)
       << llvm::StringRef(CS.getStart(), CS.getLength())
-      << getFormatSpecifierRange(startSpecifier, specifierLen);  
+      << getFormatSpecifierRange(startSpecifier, specifierLen);
 }
 
 void CheckPrintfHandler::HandleNullChar(const char *nullCharacter) {
@@ -1139,49 +1137,10 @@
 }
 
 const Expr *CheckPrintfHandler::getDataArg(unsigned i) const {
-  return TheCall->getArg(FormatIdx + i);  
+  return TheCall->getArg(FormatIdx + i);
 }
 
-bool CheckPrintfHandler::MatchType(QualType A, QualType B, bool ignoreSign) {
-  A = S.Context.getCanonicalType(A).getUnqualifiedType();
-  B = S.Context.getCanonicalType(B).getUnqualifiedType();
-  
-  if (A == B)
-    return true;
-  
-  if (ignoreSign) {
-    if (const BuiltinType *BT = B->getAs<BuiltinType>()) {
-      switch (BT->getKind()) {
-        default:
-          return false;
-        case BuiltinType::Char_S:          
-        case BuiltinType::SChar:
-          return A == S.Context.UnsignedCharTy;
-        case BuiltinType::Char_U:
-        case BuiltinType::UChar:
-          return A == S.Context.SignedCharTy;
-        case BuiltinType::Short:
-          return A == S.Context.UnsignedShortTy;
-        case BuiltinType::UShort:
-          return A == S.Context.ShortTy;          
-        case BuiltinType::Int:
-          return A == S.Context.UnsignedIntTy;
-        case BuiltinType::UInt:
-          return A == S.Context.IntTy;
-        case BuiltinType::Long:
-          return A == S.Context.UnsignedLongTy;
-        case BuiltinType::ULong:
-          return A == S.Context.LongTy;
-        case BuiltinType::LongLong:
-          return A == S.Context.UnsignedLongLongTy;
-        case BuiltinType::ULongLong:
-          return A == S.Context.LongLongTy;          
-      }
-      return A == B;
-    }
-  }
-  return false;  
-}
+
 
 void CheckPrintfHandler::HandleFlags(const analyze_printf::FormatSpecifier &FS,
                                      llvm::StringRef flag,
@@ -1205,21 +1164,25 @@
     if (!HasVAListArg) {
       if (NumConversions > NumDataArgs) {
         S.Diag(getLocationOfByte(Amt.getStart()), MissingArgDiag)
-          << getFormatSpecifierRange(startSpecifier, specifierLen);      
+          << getFormatSpecifierRange(startSpecifier, specifierLen);
         // Don't do any more checking.  We will just emit
         // spurious errors.
         return false;
       }
-      
+
       // Type check the data argument.  It should be an 'int'.
       // Although not in conformance with C99, we also allow the argument to be
       // an 'unsigned int' as that is a reasonably safe case.  GCC also
       // doesn't emit a warning for that case.
       const Expr *Arg = getDataArg(NumConversions);
       QualType T = Arg->getType();
-      if (!MatchType(T, S.Context.IntTy, true)) {
+
+      const analyze_printf::ArgTypeResult &ATR = Amt.getArgType(S.Context);
+      assert(ATR.isValid());
+
+      if (!ATR.matchesType(S.Context, T)) {
         S.Diag(getLocationOfByte(Amt.getStart()), BadTypeDiag)
-          << S.Context.IntTy << T
+          << ATR.getRepresentativeType(S.Context) << T
           << getFormatSpecifierRange(startSpecifier, specifierLen)
           << Arg->getSourceRange();
         // Don't do any more checking.  We will just emit
@@ -1248,7 +1211,7 @@
           startSpecifier, specifierLen)) {
     return false;
   }
-    
+
   if (!HandleAmount(FS.getPrecision(),
                     diag::warn_printf_asterisk_precision_missing_arg,
                     diag::warn_printf_asterisk_precision_wrong_type,
@@ -1260,7 +1223,7 @@
   // in a non-ObjC literal.
   if (!IsObjCLiteral && CS.isObjCArg()) {
     HandleInvalidConversionSpecifier(FS, startSpecifier, specifierLen);
-    
+
     // Continue checking the other format specifiers.
     return true;
   }
@@ -1270,27 +1233,27 @@
     // makes no sense.  Worth issuing a warning at some point.
     return true;
   }
-  
-  ++NumConversions;  
-  
+
+  ++NumConversions;
+
   // Are we using '%n'?  Issue a warning about this being
   // a possible security issue.
   if (CS.getKind() == ConversionSpecifier::OutIntPtrArg) {
     S.Diag(getLocationOfByte(CS.getStart()), diag::warn_printf_write_back)
-      << getFormatSpecifierRange(startSpecifier, specifierLen);           
+      << getFormatSpecifierRange(startSpecifier, specifierLen);
     // Continue checking the other format specifiers.
     return true;
   }
 
   if (CS.getKind() == ConversionSpecifier::VoidPtrArg) {
     if (FS.getPrecision().getHowSpecified() != OptionalAmount::NotSpecified)
-      S.Diag(getLocationOfByte(CS.getStart()), 
+      S.Diag(getLocationOfByte(CS.getStart()),
              diag::warn_printf_nonsensical_precision)
         << CS.getCharacters()
         << getFormatSpecifierRange(startSpecifier, specifierLen);
   }
-  if (CS.getKind() == ConversionSpecifier::VoidPtrArg || 
-      CS.getKind() == ConversionSpecifier::CStrArg) {    
+  if (CS.getKind() == ConversionSpecifier::VoidPtrArg ||
+      CS.getKind() == ConversionSpecifier::CStrArg) {
     // FIXME: Instead of using "0", "+", etc., eventually get them from
     // the FormatSpecifier.
     if (FS.hasLeadingZeros())
@@ -1299,42 +1262,38 @@
       HandleFlags(FS, "+", CS.getCharacters(), startSpecifier, specifierLen);
     if (FS.hasSpacePrefix())
       HandleFlags(FS, " ", CS.getCharacters(), startSpecifier, specifierLen);
-  }  
-  
+  }
+
   // The remaining checks depend on the data arguments.
   if (HasVAListArg)
     return true;
-  
+
   if (NumConversions > NumDataArgs) {
     S.Diag(getLocationOfByte(CS.getStart()),
            diag::warn_printf_insufficient_data_args)
-      << getFormatSpecifierRange(startSpecifier, specifierLen);    
+      << getFormatSpecifierRange(startSpecifier, specifierLen);
     // Don't do any more checking.
     return false;
   }
-  
+
   // Now type check the data expression that matches the
   // format specifier.
   const Expr *Ex = getDataArg(NumConversions);
   const analyze_printf::ArgTypeResult &ATR = FS.getArgType(S.Context);
-  
-  if (const QualType *T = ATR.getSpecificType()) {
-    if (!MatchType(*T, Ex->getType(), true)) {
-      // Check if we didn't match because of an implicit cast from a 'char'
-      // or 'short' to an 'int'.  This is done because printf is a varargs
-      // function.
-      if (const ImplicitCastExpr *ICE = dyn_cast<ImplicitCastExpr>(Ex))
-        if (ICE->getType() == S.Context.IntTy)
-          if (MatchType(*T, ICE->getSubExpr()->getType(), true))
-            return true;
+  if (ATR.isValid() && !ATR.matchesType(S.Context, Ex->getType())) {
+    // Check if we didn't match because of an implicit cast from a 'char'
+    // or 'short' to an 'int'.  This is done because printf is a varargs
+    // function.
+    if (const ImplicitCastExpr *ICE = dyn_cast<ImplicitCastExpr>(Ex))
+      if (ICE->getType() == S.Context.IntTy)
+        if (ATR.matchesType(S.Context, ICE->getSubExpr()->getType()))
+          return true;
 
-      S.Diag(getLocationOfByte(CS.getStart()),
-             diag::warn_printf_conversion_argument_type_mismatch)
-      << *T << Ex->getType()
+    S.Diag(getLocationOfByte(CS.getStart()),
+           diag::warn_printf_conversion_argument_type_mismatch)
+      << ATR.getRepresentativeType(S.Context) << Ex->getType()
       << getFormatSpecifierRange(startSpecifier, specifierLen)
       << Ex->getSourceRange();
-    }
-    return true;
   }
 
   return true;
@@ -1361,19 +1320,19 @@
     << OrigFormatExpr->getSourceRange();
     return;
   }
-  
+
   // Str - The format string.  NOTE: this is NOT null-terminated!
   const char *Str = FExpr->getStrData();
-  
+
   // CHECK: empty format string?
   unsigned StrLen = FExpr->getByteLength();
-  
+
   if (StrLen == 0) {
     Diag(FExpr->getLocStart(), diag::warn_printf_empty_format_string)
     << OrigFormatExpr->getSourceRange();
     return;
   }
-  
+
   CheckPrintfHandler H(*this, FExpr, OrigFormatExpr,
                        TheCall->getNumArgs() - firstDataArg,
                        isa<ObjCStringLiteral>(OrigFormatExpr), Str,
@@ -1407,11 +1366,11 @@
       if (C->hasBlockDeclRefExprs())
         Diag(C->getLocStart(), diag::err_ret_local_block)
           << C->getSourceRange();
-    
+
     if (AddrLabelExpr *ALE = dyn_cast<AddrLabelExpr>(RetValExp))
       Diag(ALE->getLocStart(), diag::warn_ret_addr_label)
         << ALE->getSourceRange();
-    
+
   } else if (lhsType->isReferenceType()) {
     // Perform checking for stack values returned by reference.
     // Check for a reference to the stack
@@ -1887,7 +1846,7 @@
       if (BO->getLHS()->getType()->isPointerType())
         return IntRange::forType(C, E->getType());
       // fallthrough
-      
+
     default:
       break;
     }
@@ -2328,7 +2287,7 @@
   CFG *cfg = AC.getCFG();
   if (cfg == 0)
     return;
-  
+
   llvm::BitVector live(cfg->getNumBlockIDs());
   // Mark all live things first.
   count = MarkLive(&cfg->getEntry(), live);
@@ -2527,7 +2486,7 @@
   // which this code would then warn about.
   if (getDiagnostics().hasErrorOccurred())
     return;
-  
+
   bool ReturnsVoid = false;
   bool HasNoReturn = false;
   if (FunctionDecl *FD = dyn_cast<FunctionDecl>(D)) {