Check that arguments to a scanf call match the format specifier,
and offer fixits when there is a mismatch.

git-svn-id: https://llvm.org/svn/llvm-project/cfe/trunk@146326 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Analysis/ScanfFormatString.cpp b/lib/Analysis/ScanfFormatString.cpp
index 6a8673a..77d9c96 100644
--- a/lib/Analysis/ScanfFormatString.cpp
+++ b/lib/Analysis/ScanfFormatString.cpp
@@ -20,9 +20,11 @@
 using clang::analyze_format_string::LengthModifier;
 using clang::analyze_format_string::OptionalAmount;
 using clang::analyze_format_string::ConversionSpecifier;
+using clang::analyze_scanf::ScanfArgTypeResult;
 using clang::analyze_scanf::ScanfConversionSpecifier;
 using clang::analyze_scanf::ScanfSpecifier;
 using clang::UpdateOnReturn;
+using namespace clang;
 
 typedef clang::analyze_format_string::SpecifierResult<ScanfSpecifier>
         ScanfSpecifierResult;
@@ -190,7 +192,213 @@
   }
   return ScanfSpecifierResult(Start, FS);
 }
-  
+
+ScanfArgTypeResult ScanfSpecifier::getArgType(ASTContext &Ctx) const {
+  const ScanfConversionSpecifier &CS = getConversionSpecifier();
+  // Map the conversion kind and length modifier to the expected argument type.
+  if (!CS.consumesDataArgument())
+    return ScanfArgTypeResult::Invalid();
+
+  switch(CS.getKind()) {
+    // Signed int (dArg, iArg).
+    case ConversionSpecifier::dArg:
+    case ConversionSpecifier::iArg:
+      switch (LM.getKind()) {
+        case LengthModifier::None: return ArgTypeResult(Ctx.IntTy);
+        case LengthModifier::AsChar:
+          return ArgTypeResult(ArgTypeResult::AnyCharTy);
+        case LengthModifier::AsShort: return ArgTypeResult(Ctx.ShortTy);
+        case LengthModifier::AsLong: return ArgTypeResult(Ctx.LongTy);
+        case LengthModifier::AsLongLong: return ArgTypeResult(Ctx.LongLongTy);
+        case LengthModifier::AsIntMax:
+          return ScanfArgTypeResult(Ctx.getIntMaxType(), "intmax_t *");
+        case LengthModifier::AsSizeT:
+          // FIXME: ssize_t. Until then a default-constructed result is returned.
+          return ScanfArgTypeResult();
+        case LengthModifier::AsPtrDiff:
+          return ScanfArgTypeResult(Ctx.getPointerDiffType(), "ptrdiff_t *");
+        case LengthModifier::AsLongDouble: return ScanfArgTypeResult::Invalid();
+      }
+      // NOTE(review): no break; an unhandled LM kind would fall through below.
+    // Unsigned int (oArg, uArg, xArg, XArg).
+    case ConversionSpecifier::oArg:
+    case ConversionSpecifier::uArg:
+    case ConversionSpecifier::xArg:
+    case ConversionSpecifier::XArg:
+      switch (LM.getKind()) {
+        case LengthModifier::None: return ArgTypeResult(Ctx.UnsignedIntTy);
+        case LengthModifier::AsChar: return ArgTypeResult(Ctx.UnsignedCharTy);
+        case LengthModifier::AsShort: return ArgTypeResult(Ctx.UnsignedShortTy);
+        case LengthModifier::AsLong: return ArgTypeResult(Ctx.UnsignedLongTy);
+        case LengthModifier::AsLongLong:
+          return ArgTypeResult(Ctx.UnsignedLongLongTy);
+        case LengthModifier::AsIntMax:
+          return ScanfArgTypeResult(Ctx.getUIntMaxType(), "uintmax_t *");
+        case LengthModifier::AsSizeT:
+          return ScanfArgTypeResult(Ctx.getSizeType(), "size_t *");
+        case LengthModifier::AsPtrDiff:
+          // FIXME: Unsigned version of ptrdiff_t?
+          return ScanfArgTypeResult();
+        case LengthModifier::AsLongDouble: return ScanfArgTypeResult::Invalid();
+      }
+      // NOTE(review): no break; an unhandled LM kind would fall through below.
+    // Float (aArg, AArg, eArg, EArg, fArg, FArg, gArg, GArg).
+    case ConversionSpecifier::aArg:
+    case ConversionSpecifier::AArg:
+    case ConversionSpecifier::eArg:
+    case ConversionSpecifier::EArg:
+    case ConversionSpecifier::fArg:
+    case ConversionSpecifier::FArg:
+    case ConversionSpecifier::gArg:
+    case ConversionSpecifier::GArg:
+      switch (LM.getKind()) {
+        case LengthModifier::None: return ArgTypeResult(Ctx.FloatTy);
+        case LengthModifier::AsLong: return ArgTypeResult(Ctx.DoubleTy);
+        case LengthModifier::AsLongDouble:
+          return ArgTypeResult(Ctx.LongDoubleTy);
+        default:  // Every other length modifier is rejected.
+          return ScanfArgTypeResult::Invalid();
+      }
+
+    // Char, string and scanlist (cArg, sArg, ScanListArg).
+    case ConversionSpecifier::cArg:
+    case ConversionSpecifier::sArg:
+    case ConversionSpecifier::ScanListArg:
+      switch (LM.getKind()) {
+        case LengthModifier::None: return ScanfArgTypeResult::CStrTy;
+        case LengthModifier::AsLong:
+          return ScanfArgTypeResult(ScanfArgTypeResult::WCStrTy, "wchar_t *");
+        default:  // Every other length modifier is rejected.
+          return ScanfArgTypeResult::Invalid();
+      }
+    case ConversionSpecifier::CArg:
+    case ConversionSpecifier::SArg:
+      // FIXME: Mac OS X specific?
+      return ScanfArgTypeResult(ScanfArgTypeResult::WCStrTy, "wchar_t *");
+
+    // Pointer (pArg).
+    case ConversionSpecifier::pArg:
+      return ScanfArgTypeResult(ArgTypeResult(ArgTypeResult::CPointerTy));
+
+    default:
+      break;
+  }
+  // Conversion kinds not handled above get a default-constructed result.
+  return ScanfArgTypeResult();
+}
+
+bool ScanfSpecifier::fixType(QualType QT, const LangOptions &LangOpt)
+{  // Rewrites CS and LM in place to suit QT; returns true on success.
+  if (!QT->isPointerType())
+    return false;
+  // Only pointers to builtin types can be fixed; anything else is refused.
+  QualType PT = QT->getPointeeType();
+  const BuiltinType *BT = PT->getAs<BuiltinType>();
+  if (!BT)
+    return false;
+
+  // Pointer to a character: use sArg, with AsWideChar for wide characters.
+  if (PT->isAnyCharacterType()) {
+    CS.setKind(ConversionSpecifier::sArg);
+    if (PT->isWideCharType())
+      LM.setKind(LengthModifier::AsWideChar);
+    else
+      LM.setKind(LengthModifier::None);
+    return true;
+  }
+
+  // Figure out the length modifier.
+  switch (BT->getKind()) {
+    // no modifier
+    case BuiltinType::UInt:
+    case BuiltinType::Int:
+    case BuiltinType::Float:
+      LM.setKind(LengthModifier::None);
+      break;
+
+    // hh (NOTE(review): char pointees presumably returned above -- confirm)
+    case BuiltinType::Char_U:
+    case BuiltinType::UChar:
+    case BuiltinType::Char_S:
+    case BuiltinType::SChar:
+      LM.setKind(LengthModifier::AsChar);
+      break;
+
+    // h
+    case BuiltinType::Short:
+    case BuiltinType::UShort:
+      LM.setKind(LengthModifier::AsShort);
+      break;
+
+    // l
+    case BuiltinType::Long:
+    case BuiltinType::ULong:
+    case BuiltinType::Double:
+      LM.setKind(LengthModifier::AsLong);
+      break;
+
+    // ll
+    case BuiltinType::LongLong:
+    case BuiltinType::ULongLong:
+      LM.setKind(LengthModifier::AsLongLong);
+      break;
+
+    // L
+    case BuiltinType::LongDouble:
+      LM.setKind(LengthModifier::AsLongDouble);
+      break;
+
+    // Don't know.
+    default:
+      return false;
+  }
+
+  // Handle size_t, ptrdiff_t, etc. that have dedicated length modifiers in C99.
+  // A matching typedef name overrides the modifier chosen by the switch above.
+  if (isa<TypedefType>(PT) && (LangOpt.C99 || LangOpt.CPlusPlus0x)) {
+    const IdentifierInfo *Identifier = QT.getBaseTypeIdentifier();
+    if (Identifier->getName() == "size_t") {
+      LM.setKind(LengthModifier::AsSizeT);
+    } else if (Identifier->getName() == "ssize_t") {
+      // Not C99, but common in Unix.
+      LM.setKind(LengthModifier::AsSizeT);
+    } else if (Identifier->getName() == "intmax_t") {
+      LM.setKind(LengthModifier::AsIntMax);
+    } else if (Identifier->getName() == "uintmax_t") {
+      LM.setKind(LengthModifier::AsIntMax);  // Same modifier as intmax_t.
+    } else if (Identifier->getName() == "ptrdiff_t") {
+      LM.setKind(LengthModifier::AsPtrDiff);
+    }
+  }
+
+  // Figure out the conversion specifier.
+  if (PT->isRealFloatingType())
+    CS.setKind(ConversionSpecifier::fArg);
+  else if (PT->isSignedIntegerType())
+    CS.setKind(ConversionSpecifier::dArg);
+  else if (PT->isUnsignedIntegerType()) {
+    // Preserve the original formatting, e.g. 'X', 'o'.
+    if (!CS.isUIntArg()) {
+      CS.setKind(ConversionSpecifier::uArg);
+    }
+  } else
+    llvm_unreachable("Unexpected type");
+
+  return true;
+}
+
+void ScanfSpecifier::toString(raw_ostream &os) const {
+  os << "%";
+  // Then, in order: optional "<index>$", optional "*", width, length, kind.
+  if (usesPositionalArg())
+    os << getPositionalArgIndex() << "$";
+  if (SuppressAssignment)
+    os << "*";
+
+  FieldWidth.toString(os);
+  os << LM.toString();
+  os << CS.toString();
+}
+
 bool clang::analyze_format_string::ParseScanfString(FormatStringHandler &H,
                                                     const char *I,
                                                     const char *E) {
@@ -218,4 +426,47 @@
   return false;
 }
 
+bool ScanfArgTypeResult::matchesType(ASTContext& C, QualType argTy) const {
+  switch (K) {  // Decide per kind whether argTy satisfies this expectation.
+    case InvalidTy:
+      llvm_unreachable("ArgTypeResult must be valid");
+    case UnknownTy:  // No known expectation: every argument type is accepted.
+      return true;
+    case CStrTy:  // Delegate to the generic ArgTypeResult C-string check.
+      return ArgTypeResult(ArgTypeResult::CStrTy).matchesType(C, argTy);
+    case WCStrTy:  // Delegate to the generic ArgTypeResult wide-string check.
+      return ArgTypeResult(ArgTypeResult::WCStrTy).matchesType(C, argTy);
+    case PtrToArgTypeResultTy: {  // Needs a pointer whose pointee matches A.
+      const PointerType *PT = argTy->getAs<PointerType>();
+      if (!PT)
+        return false;
+      return A.matchesType(C, PT->getPointeeType());
+    }
+  }
 
+  return false; // Unreachable, but we still get a warning.
+}
+
+QualType ScanfArgTypeResult::getRepresentativeType(ASTContext &C) const {
+  switch (K) {  // Produce one concrete QualType standing for each kind.
+    case InvalidTy:
+      llvm_unreachable("No representative type for Invalid ArgTypeResult");
+    case UnknownTy:  // Nothing concrete to offer: default-constructed QualType.
+      return QualType();
+    case CStrTy:  // Pointer to C.CharTy.
+      return C.getPointerType(C.CharTy);
+    case WCStrTy:  // Pointer to the context's wchar type.
+      return C.getPointerType(C.getWCharType());
+    case PtrToArgTypeResultTy:  // Pointer to A's own representative type.
+      return C.getPointerType(A.getRepresentativeType(C));
+  }
+
+  return QualType(); // Not reachable.
+}
+
+std::string ScanfArgTypeResult::getRepresentativeTypeName(ASTContext& C) const {
+  std::string S = getRepresentativeType(C).getAsString();
+  if (!Name)  // No alias name recorded: return just the quoted type string.
+    return std::string("'") + S + "'";
+  return std::string("'") + Name + "' (aka '" + S + "')";  // 'Name' (aka 'S')
+}
diff --git a/lib/Sema/SemaChecking.cpp b/lib/Sema/SemaChecking.cpp
index db60f23..1c93931 100644
--- a/lib/Sema/SemaChecking.cpp
+++ b/lib/Sema/SemaChecking.cpp
@@ -2371,8 +2371,38 @@
   if (!CheckNumArgs(FS, CS, startSpecifier, specifierLen, argIndex))
     return false;
   
-  // FIXME: Check that the argument type matches the format specifier.
-  
+  // Check that the argument type matches the format specifier.
+  const Expr *Ex = getDataArg(argIndex);
+  const analyze_scanf::ScanfArgTypeResult &ATR = FS.getArgType(S.Context);
+  if (ATR.isValid() && !ATR.matchesType(S.Context, Ex->getType())) {
+    ScanfSpecifier fixedFS = FS;
+    bool success = fixedFS.fixType(Ex->getType(), S.getLangOptions());
+
+    if (success) {
+      // Get the fix string from the fixed format specifier.
+      llvm::SmallString<128> buf;
+      llvm::raw_svector_ostream os(buf);
+      fixedFS.toString(os);
+
+      EmitFormatDiagnostic(
+        S.PDiag(diag::warn_printf_conversion_argument_type_mismatch)
+          << ATR.getRepresentativeTypeName(S.Context) << Ex->getType()
+          << Ex->getSourceRange(),
+        getLocationOfByte(CS.getStart()),
+        /*IsStringLocation*/true,
+        getSpecifierRange(startSpecifier, specifierLen),
+        FixItHint::CreateReplacement(
+          getSpecifierRange(startSpecifier, specifierLen),
+          os.str()));
+    } else {
+      S.Diag(getLocationOfByte(CS.getStart()),
+             diag::warn_printf_conversion_argument_type_mismatch)
+          << ATR.getRepresentativeTypeName(S.Context) << Ex->getType()
+          << getSpecifierRange(startSpecifier, specifierLen)
+          << Ex->getSourceRange();
+    }
+  }
+
   return true;
 }