blob: 4f91606bfa421e9033db9448a33a7e9a3a8447c8 [file] [log] [blame]
//===--- RewriteBlocks.cpp ----------------------------------------------===//
//
// The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// Hacks and fun related to the closure rewriter.
//
//===----------------------------------------------------------------------===//
#include "ASTConsumers.h"
#include "clang/Rewrite/Rewriter.h"
#include "clang/AST/AST.h"
#include "clang/AST/ASTConsumer.h"
#include "clang/Basic/SourceManager.h"
#include "clang/Basic/IdentifierTable.h"
#include "clang/Basic/Diagnostic.h"
#include "clang/Basic/LangOptions.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include <sstream>
using namespace clang;
using llvm::utostr;
namespace {
// RewriteBlocks - An ASTConsumer that rewrites closure ("block") constructs
// found in the main file into plain C:
//   - block literals become compound literals of a per-block struct whose
//     first member is a struct __closure_impl,
//   - calls through block pointers become calls through that struct's
//     Invoke field,
//   - block-pointer declarations become "struct __closure_impl *".
// The rewritten main-file buffer is printed to stdout on destruction.
class RewriteBlocks : public ASTConsumer {
Rewriter Rewrite;
Diagnostic &Diags;
const LangOptions &LangOpts;
// Custom warning IDs registered with Diags in the constructor.
unsigned RewriteFailedDiag;
unsigned NoNestedBlockCalls;
// Filled in by Initialize().
ASTContext *Context;
SourceManager *SM;
unsigned MainFileID;
const char *MainFileStart, *MainFileEnd;
// Block expressions.
// Block literals seen in the function/method currently being rewritten;
// SynthesizeBlockLiterals() emits their code and then clears this list.
llvm::SmallVector<BlockExpr *, 32> Blocks;
// References to imported variables inside the block being processed.
llvm::SmallVector<BlockDeclRefExpr *, 32> BlockDeclRefs;
// Imported block-pointer reference -> the call made through it
// (populated by GetBlockCallExprs()).
llvm::DenseMap<BlockDeclRefExpr *, CallExpr *> BlockCallExprs;
// Block related declarations.
// The unique imported variables of the current block, split by capture kind.
llvm::SmallPtrSet<ValueDecl *, 8> BlockByCopyDecls;
llvm::SmallPtrSet<ValueDecl *, 8> BlockByRefDecls;
// The function/method we are rewriting.
FunctionDecl *CurFunctionDef;
ObjCMethodDecl *CurMethodDef;
// True when the input file name has a header extension (see IsHeaderFile);
// controls whether the emitted preamble keeps "#pragma once".
bool IsHeader;
std::string InFileName;
// NOTE(review): OutFileName is stored but not read anywhere in the code
// visible here; the destructor writes to stdout instead.
std::string OutFileName;
public:
RewriteBlocks(std::string inFile, std::string outFile, Diagnostic &D,
const LangOptions &LOpts);
// Print the rewritten main file to stdout, or "No changes" when nothing in
// the main file was rewritten.
~RewriteBlocks() {
// Get the buffer corresponding to MainFileID.
// If we haven't changed it, then we are done.
if (const RewriteBuffer *RewriteBuf =
Rewrite.getRewriteBufferFor(MainFileID)) {
std::string S(RewriteBuf->begin(), RewriteBuf->end());
printf("%s\n", S.c_str());
} else {
printf("No changes\n");
}
}
void Initialize(ASTContext &context);
// Wrappers around the Rewriter that emit RewriteFailedDiag on failure.
void InsertText(SourceLocation Loc, const char *StrData, unsigned StrLen);
void ReplaceText(SourceLocation Start, unsigned OrigLength,
const char *NewStr, unsigned NewLength);
// Top Level Driver code.
virtual void HandleTopLevelDecl(Decl *D);
void HandleDeclInMainFile(Decl *D);
// Top level
Stmt *RewriteFunctionBody(Stmt *S);
void InsertBlockLiteralsWithinFunction(FunctionDecl *FD);
void InsertBlockLiteralsWithinMethod(ObjCMethodDecl *MD);
// Block specific rewrite rules.
void RewriteBlockExpr(BlockExpr *Exp);
void RewriteBlockCall(CallExpr *Exp);
void RewriteBlockPointerDecl(NamedDecl *VD);
void RewriteBlockPointerFunctionArgs(FunctionDecl *FD);
// Text generators for the per-block function, struct, and call expression.
std::string SynthesizeBlockFunc(BlockExpr *CE, int i,
const char *funcName, std::string Tag);
std::string SynthesizeBlockImpl(BlockExpr *CE, std::string Tag);
std::string SynthesizeBlockCall(CallExpr *Exp);
void SynthesizeBlockLiterals(SourceLocation FunLocStart,
const char *FunName);
// AST walkers that fill BlockDeclRefs / BlockCallExprs.
void GetBlockDeclRefExprs(Stmt *S);
void GetBlockCallExprs(Stmt *S);
// We avoid calling Type::isBlockPointerType(), since it operates on the
// canonical type. We only care if the top-level type is a closure pointer.
bool isBlockPointerType(QualType T) { return isa<BlockPointerType>(T); }
// FIXME: This predicate seems like it would be useful to add to ASTContext.
// isObjCType - True for 'id', 'Class', and pointers to an ObjC interface or
// qualified-id type; always false when neither ObjC1 nor ObjC2 is enabled.
bool isObjCType(QualType T) {
if (!LangOpts.ObjC1 && !LangOpts.ObjC2)
return false;
QualType OCT = Context->getCanonicalType(T).getUnqualifiedType();
if (OCT == Context->getCanonicalType(Context->getObjCIdType()) ||
OCT == Context->getCanonicalType(Context->getObjCClassType()))
return true;
if (const PointerType *PT = OCT->getAsPointerType()) {
if (isa<ObjCInterfaceType>(PT->getPointeeType()) ||
isa<ObjCQualifiedIdType>(PT->getPointeeType()))
return true;
}
return false;
}
// ObjC rewrite methods.
void RewriteInterfaceDecl(ObjCInterfaceDecl *ClassDecl);
void RewriteCategoryDecl(ObjCCategoryDecl *CatDecl);
void RewriteProtocolDecl(ObjCProtocolDecl *PDecl);
void RewriteMethodDecl(ObjCMethodDecl *MDecl);
};
}
/// IsHeaderFile - Return true when \p Filename ends in a header extension.
/// The extension is whatever follows the last '.', so a name with no '.'
/// has no extension. Recognized: C ".h" and C++ ".hh" / ".H".
static bool IsHeaderFile(const std::string &Filename) {
  const std::string::size_type LastDot = Filename.rfind('.');
  if (LastDot == std::string::npos)
    return false;

  const std::string Suffix = Filename.substr(LastDot + 1);
  static const char *const HeaderSuffixes[] = { "h", "hh", "H" };
  const unsigned NumSuffixes =
    sizeof(HeaderSuffixes) / sizeof(HeaderSuffixes[0]);
  for (unsigned Idx = 0; Idx != NumSuffixes; ++Idx) {
    if (Suffix == HeaderSuffixes[Idx])
      return true;
  }
  return false;
}
/// RewriteBlocks - Record the input/output file names and register the
/// custom diagnostics.
///
/// Members that Initialize() fills in are given defined values here. The
/// destructor reads MainFileID unconditionally, so it must be defined even
/// if Initialize() is never called.
RewriteBlocks::RewriteBlocks(std::string inFile, std::string outFile,
                             Diagnostic &D, const LangOptions &LOpts) :
  Diags(D), LangOpts(LOpts) {
  IsHeader = IsHeaderFile(inFile);
  InFileName = inFile;
  OutFileName = outFile;
  CurFunctionDef = 0;
  CurMethodDef = 0;
  // Placeholders until Initialize() runs.
  Context = 0;
  SM = 0;
  MainFileID = 0;
  MainFileStart = 0;
  MainFileEnd = 0;
  RewriteFailedDiag = Diags.getCustomDiagID(Diagnostic::Warning,
                                            "rewriting failed");
  NoNestedBlockCalls = Diags.getCustomDiagID(Diagnostic::Warning,
    "Rewrite support for closure calls nested within closure blocks is incomplete");
}
/// CreateBlockRewriter - Factory for the closure rewriter consumer.
/// Returns a heap-allocated consumer. Ownership passes to the caller
/// (presumably the driver deletes it; confirm at the call site).
ASTConsumer *clang::CreateBlockRewriter(const std::string& InFile,
                                        const std::string& OutFile,
                                        Diagnostic &Diags,
                                        const LangOptions &LangOpts) {
  RewriteBlocks *Consumer = new RewriteBlocks(InFile, OutFile, Diags, LangOpts);
  return Consumer;
}
/// Initialize - Cache the AST context, the source manager and the main-file
/// buffer bounds. Then attach the rewriter to the source manager and insert
/// the closure runtime preamble at offset 0 of the main file.
void RewriteBlocks::Initialize(ASTContext &context) {
  Context = &context;
  SM = &Context->getSourceManager();

  // Remember which file is the main one and where its text lives.
  MainFileID = SM->getMainFileID();
  const llvm::MemoryBuffer *MainBuf = SM->getBuffer(MainFileID);
  MainFileStart = MainBuf->getBufferStart();
  MainFileEnd = MainBuf->getBufferEnd();

  Rewrite.setSourceMgr(*SM);

  // Declares struct __closure_impl and the flag bits. The CLOSURE_IMPL guard
  // keeps the definitions from being emitted twice.
  const char *Preamble = "#pragma once\n"
                         "#ifndef CLOSURE_IMPL\n"
                         "struct __closure_impl {\n"
                         " long Reserved;\n"
                         " int Flags;\n"
                         " int Size;\n"
                         " void *Invoke;\n"
                         "};\n"
                         "enum {\n"
                         " HAS_NONPOD = (1<<25),\n"
                         " HAS_BYREF = (1<<26)\n"
                         "};\n"
                         "#define CLOSURE_IMPL\n"
                         "#endif\n";
  // Only a header that is being rewritten keeps the leading "#pragma once".
  if (!IsHeader)
    Preamble += strlen("#pragma once\n");
  InsertText(SourceLocation::getFileLoc(MainFileID, 0), Preamble,
             strlen(Preamble));
}
/// InsertText - Insert \p StrLen bytes of \p StrData at \p Loc. Emits the
/// "rewriting failed" warning at \p Loc if the rewriter rejects the edit.
void RewriteBlocks::InsertText(SourceLocation Loc, const char *StrData,
                               unsigned StrLen) {
  // Rewriter::InsertText signals failure by returning true.
  bool Failed = Rewrite.InsertText(Loc, StrData, StrLen);
  if (Failed)
    Diags.Report(Context->getFullLoc(Loc), RewriteFailedDiag);
}
/// ReplaceText - Replace \p OrigLength bytes at \p Start with \p NewLength
/// bytes of \p NewStr. Emits the "rewriting failed" warning if the rewriter
/// rejects the edit.
void RewriteBlocks::ReplaceText(SourceLocation Start, unsigned OrigLength,
                                const char *NewStr, unsigned NewLength) {
  // Rewriter::ReplaceText signals failure by returning true.
  bool Failed = Rewrite.ReplaceText(Start, OrigLength, NewStr, NewLength);
  if (Failed)
    Diags.Report(Context->getFullLoc(Start), RewriteFailedDiag);
}
// RewriteMethodDecl - If any parameter of Method has a block-pointer type,
// fuzzy-scan the method's source text (getLocStart()..getLocEnd()) and, for
// every ":(type)" whose parenthesized type contains a '^', replace the text
// between the parentheses with "struct __closure_impl *".
void RewriteBlocks::RewriteMethodDecl(ObjCMethodDecl *Method) {
bool haveBlockPtrs = false;
for (ObjCMethodDecl::param_iterator I = Method->param_begin(),
E = Method->param_end(); I != E; ++I)
if (isBlockPointerType((*I)->getType()))
haveBlockPtrs = true;
if (!haveBlockPtrs)
return;
// Do a fuzzy rewrite.
// We have 1 or more arguments that have closure pointers.
SourceLocation Loc = Method->getLocStart();
SourceLocation LocEnd = Method->getLocEnd();
const char *startBuf = SM->getCharacterData(Loc);
const char *endBuf = SM->getCharacterData(LocEnd);
const char *methodPtr = startBuf;
std::string Tag = "struct __closure_impl *";
// NOTE(review): the post-increment means the first character is never
// examined. The '!=' test only stops the scan if methodPtr lands exactly on
// endBuf. The jumps below (methodPtr++ and methodPtr = scanType) could in
// principle step past it, in which case scanning continues until a NUL.
while (*methodPtr++ && (methodPtr != endBuf)) {
switch (*methodPtr) {
case ':':
methodPtr++;
if (*methodPtr == '(') {
// Scan to the ')' matching this '(' and note any '^' inside.
// Assumes balanced parentheses; there is no end-of-buffer check.
const char *scanType = ++methodPtr;
bool foundBlockPointer = false;
unsigned parenCount = 1;
while (parenCount) {
switch (*scanType) {
case '(':
parenCount++;
break;
case ')':
parenCount--;
break;
case '^':
foundBlockPointer = true;
break;
}
scanType++;
}
if (foundBlockPointer) {
// advance the location to startArgList.
// Loc is advanced by the distance since the last replacement, so it
// names the position of methodPtr (the first char inside the '(').
Loc = Loc.getFileLocWithOffset(methodPtr-startBuf);
assert((Loc.isValid()) && "Invalid Loc");
// scanType is one past the matching ')', so this replaces exactly
// the text between the parentheses.
ReplaceText(Loc, scanType-methodPtr-1, Tag.c_str(), Tag.size());
// Advance startBuf. Since the underlying buffer has changed,
// it's very important to advance startBuf (so we can correctly
// compute a relative Loc the next time around).
// NOTE(review): the buffer SM returns presumably isn't modified by
// ReplaceText(). The reset is required because Loc is advanced
// incrementally relative to startBuf.
startBuf = methodPtr;
}
// Advance the method ptr to the end of the type.
methodPtr = scanType;
}
break;
}
}
return;
}
/// RewriteInterfaceDecl - Run RewriteMethodDecl() over every instance method
/// and then every class method declared in the @interface \p ClassDecl.
void RewriteBlocks::RewriteInterfaceDecl(ObjCInterfaceDecl *ClassDecl) {
  ObjCInterfaceDecl::instmeth_iterator InstIt = ClassDecl->instmeth_begin(),
                                       InstEnd = ClassDecl->instmeth_end();
  while (InstIt != InstEnd) {
    RewriteMethodDecl(*InstIt);
    ++InstIt;
  }

  ObjCInterfaceDecl::classmeth_iterator ClsIt = ClassDecl->classmeth_begin(),
                                        ClsEnd = ClassDecl->classmeth_end();
  while (ClsIt != ClsEnd) {
    RewriteMethodDecl(*ClsIt);
    ++ClsIt;
  }
}
/// RewriteCategoryDecl - Run RewriteMethodDecl() over every instance method
/// and then every class method declared in the category \p CatDecl.
void RewriteBlocks::RewriteCategoryDecl(ObjCCategoryDecl *CatDecl) {
  ObjCCategoryDecl::instmeth_iterator InstIt = CatDecl->instmeth_begin(),
                                      InstEnd = CatDecl->instmeth_end();
  while (InstIt != InstEnd) {
    RewriteMethodDecl(*InstIt);
    ++InstIt;
  }

  ObjCCategoryDecl::classmeth_iterator ClsIt = CatDecl->classmeth_begin(),
                                       ClsEnd = CatDecl->classmeth_end();
  while (ClsIt != ClsEnd) {
    RewriteMethodDecl(*ClsIt);
    ++ClsIt;
  }
}
/// RewriteProtocolDecl - Run RewriteMethodDecl() over every instance method
/// and then every class method declared in the @protocol \p PDecl.
void RewriteBlocks::RewriteProtocolDecl(ObjCProtocolDecl *PDecl) {
  ObjCProtocolDecl::instmeth_iterator InstIt = PDecl->instmeth_begin(),
                                      InstEnd = PDecl->instmeth_end();
  while (InstIt != InstEnd) {
    RewriteMethodDecl(*InstIt);
    ++InstIt;
  }

  ObjCProtocolDecl::classmeth_iterator ClsIt = PDecl->classmeth_begin(),
                                       ClsEnd = PDecl->classmeth_end();
  while (ClsIt != ClsEnd) {
    RewriteMethodDecl(*ClsIt);
    ++ClsIt;
  }
}
//===----------------------------------------------------------------------===//
// Top Level Driver Code
//===----------------------------------------------------------------------===//
/// HandleTopLevelDecl - Entry point for every top-level declaration.
///
/// ObjC interface, category and protocol declarations have their method
/// signatures rewritten wherever they come from. The main-file check below
/// only gates HandleDeclInMainFile().
void RewriteBlocks::HandleTopLevelDecl(Decl *D) {
  SourceLocation DeclLoc = D->getLocation();
  DeclLoc = SM->getLogicalLoc(DeclLoc);

  // Builtins have no valid location; there is nothing to rewrite.
  if (DeclLoc.isInvalid())
    return;

  if (ObjCInterfaceDecl *IFace = dyn_cast<ObjCInterfaceDecl>(D))
    RewriteInterfaceDecl(IFace);
  else if (ObjCCategoryDecl *Cat = dyn_cast<ObjCCategoryDecl>(D))
    RewriteCategoryDecl(Cat);
  else if (ObjCProtocolDecl *Proto = dyn_cast<ObjCProtocolDecl>(D))
    RewriteProtocolDecl(Proto);

  // Decls from #included files get no further treatment.
  if (SM->getDecomposedFileLoc(DeclLoc).first != MainFileID)
    return;
  HandleDeclInMainFile(D);
}
/// SynthesizeBlockFunc - Build the C text of the static function that
/// implements block number \p i of \p funcName.
///
/// The function takes a pointer to the block's struct (\p Tag) as an implicit
/// first argument named __cself. Its body has three parts:
///   - locals initialized from __cself for every imported variable
///     (by-ref ones are pointers),
///   - a copy of the block body in which each by-ref use gets a leading '*',
///   - rewritten calls through imported block pointers.
///
/// Expects BlockDeclRefs, BlockByRefDecls and BlockByCopyDecls to already
/// describe this block; SynthesizeBlockImpl() fills them.
std::string RewriteBlocks::SynthesizeBlockFunc(BlockExpr *CE, int i,
                                               const char *funcName,
                                               std::string Tag) {
  const FunctionType *AFT = CE->getFunctionType();
  QualType RT = AFT->getResultType();
  std::string S = "static " + RT.getAsString() + " __" +
                  funcName + "_" + "closure_" + utostr(i);

  if (isa<FunctionTypeNoProto>(AFT)) {
    S += "()";
  } else if (CE->arg_empty()) {
    S += "(" + Tag + " *__cself)";
  } else {
    // cast<> itself asserts that AFT really is a prototype.
    const FunctionTypeProto *FT = cast<FunctionTypeProto>(AFT);
    S += '(';
    // first add the implicit argument.
    S += Tag + " *__cself, ";
    std::string ParamStr;
    for (BlockExpr::arg_iterator AI = CE->arg_begin(),
         E = CE->arg_end(); AI != E; ++AI) {
      if (AI != CE->arg_begin()) S += ", ";
      ParamStr = (*AI)->getName();
      (*AI)->getType().getAsStringInternal(ParamStr);
      S += ParamStr;
    }
    // This branch only runs when the block has named arguments, so the
    // ellipsis always follows at least one parameter.
    if (FT->isVariadic())
      S += ", ...";
    S += ')';
  }
  S += " {\n";

  bool haveByRefDecls = false;
  // Create local declarations to avoid rewriting all closure decl ref exprs.
  // First, emit a declaration for all "by ref" decls.
  for (llvm::SmallPtrSet<ValueDecl*,8>::iterator I = BlockByRefDecls.begin(),
       E = BlockByRefDecls.end(); I != E; ++I) {
    // Note: It is not possible to have "by ref" closure pointer decls.
    haveByRefDecls = true;
    S += " ";
    std::string Name = (*I)->getName();
    Context->getPointerType((*I)->getType()).getAsStringInternal(Name);
    S += Name + " = __cself->" + (*I)->getName() + "; // bound by ref\n";
  }
  // Next, emit a declaration for all "by copy" declarations.
  for (llvm::SmallPtrSet<ValueDecl*,8>::iterator I = BlockByCopyDecls.begin(),
       E = BlockByCopyDecls.end(); I != E; ++I) {
    S += " ";
    std::string Name = (*I)->getName();
    // Handle nested closure invocation. For example:
    //
    //   void (^myImportedClosure)(void);
    //   myImportedClosure = ^(void) { setGlobalInt(x + y); };
    //
    //   void (^anotherClosure)(void);
    //   anotherClosure = ^(void) {
    //     myImportedClosure(); // import and invoke the closure
    //   };
    //
    if (isBlockPointerType((*I)->getType()))
      S += "struct __closure_impl *";
    else
      (*I)->getType().getAsStringInternal(Name);
    S += Name + " = __cself->" + (*I)->getName() + "; // bound by copy\n";
  }

  // Start from a verbatim copy of the body's source text.
  SourceLocation BodyLocStart = CE->getBody()->getLocStart();
  SourceLocation BodyLocEnd = CE->getBody()->getLocEnd();
  const char *BodyStartBuf = SM->getCharacterData(BodyLocStart);
  const char *BodyEndBuf = SM->getCharacterData(BodyLocEnd);
  std::string BodyBuf(BodyStartBuf, BodyEndBuf-BodyStartBuf+1);

  // Delta = how much longer BodyBuf is than the original body text so far.
  // An offset into the original text maps to (offset + Delta) in BodyBuf.
  // This assumes the references are visited in ascending source order,
  // which is how GetBlockDeclRefExprs() collects them.
  long Delta = 0;
  for (unsigned j = 0, e = BlockDeclRefs.size(); j != e; ++j) {
    BlockDeclRefExpr *Ref = BlockDeclRefs[j];
    if (Ref->isByRef()) {
      // Add a level of indirection: the local is a pointer.
      // FIXME: Wrap the *x with parens, just in case x is a more complex
      // expression, like x->member, which needs to be rewritten to
      // (*x)->member.
      const char *StarBuf = SM->getCharacterData(Ref->getLocStart());
      BodyBuf.insert((StarBuf-BodyStartBuf)+Delta, 1, '*');
      ++Delta;
    } else if (isBlockPointerType(Ref->getType())) {
      Diags.Report(NoNestedBlockCalls);
      GetBlockCallExprs(CE);
      // Need the CallExpr made through this BlockDeclRef. A block pointer
      // that is only referenced (e.g. passed along) has none.
      llvm::DenseMap<BlockDeclRefExpr *, CallExpr *>::iterator Found =
        BlockCallExprs.find(Ref);
      if (Found == BlockCallExprs.end() || !Found->second)
        continue;
      CallExpr *Call = Found->second;
      // The character-based equivalent of RewriteBlockCall(). The span stops
      // before the call's end location, whose ')' closes the synthesized
      // argument list.
      // FIXME: by-ref references inside this call's arguments still map into
      // the replaced span.
      std::string BlockCall = SynthesizeBlockCall(Call);
      const char *CallStart = SM->getCharacterData(Call->getLocStart());
      const char *CallEnd = SM->getCharacterData(Call->getLocEnd());
      long CallBytes = CallEnd-CallStart;
      BodyBuf.replace((CallStart-BodyStartBuf)+Delta, CallBytes, BlockCall);
      Delta += (long)BlockCall.size() - CallBytes;
    }
  }

  if (haveByRefDecls) {
    // Remove the |...| list naming the by-ref variables. Searching the
    // rewritten copy keeps the offsets consistent with the edits above.
    // It also cannot run past the end of the body.
    std::string::size_type FirstBar = BodyBuf.find('|');
    if (FirstBar != std::string::npos) {
      std::string::size_type SecondBar = BodyBuf.find('|', FirstBar + 1);
      if (SecondBar != std::string::npos)
        BodyBuf.erase(FirstBar, SecondBar - FirstBar + 1);
    }
  }
  S += " ";
  S += BodyBuf;
  S += "\n}\n";
  return S;
}
/// SynthesizeBlockImpl - Build the definition of the per-block struct
/// \p Tag. It holds a struct __closure_impl followed by one field per
/// imported variable: by-copy fields first, then by-ref fields as pointers.
///
/// As a side effect, this collects the block's references into BlockDeclRefs
/// and fills BlockByCopyDecls / BlockByRefDecls for SynthesizeBlockFunc().
std::string RewriteBlocks::SynthesizeBlockImpl(BlockExpr *CE,
                                               std::string Tag) {
  std::string Result = Tag + " {\n struct __closure_impl impl;\n";
  GetBlockDeclRefExprs(CE);

  if (!BlockDeclRefs.empty()) {
    // Deduplicate the imported variables, split by capture kind.
    for (unsigned Idx = 0, N = BlockDeclRefs.size(); Idx != N; ++Idx) {
      BlockDeclRefExpr *Ref = BlockDeclRefs[Idx];
      if (Ref->isByRef())
        BlockByRefDecls.insert(Ref->getDecl());
      else
        BlockByCopyDecls.insert(Ref->getDecl());
    }

    typedef llvm::SmallPtrSet<ValueDecl*,8>::iterator DeclIter;

    // By-copy fields. An imported block pointer (a nested closure
    // invocation such as calling myImportedBlock() from inside anotherBlock)
    // is stored as a struct __closure_impl pointer.
    for (DeclIter DI = BlockByCopyDecls.begin(), DE = BlockByCopyDecls.end();
         DI != DE; ++DI) {
      std::string Field = (*DI)->getName();
      Result += " ";
      if (isBlockPointerType((*DI)->getType()))
        Result += "struct __closure_impl *";
      else
        (*DI)->getType().getAsStringInternal(Field);
      Result += Field + ";\n";
    }

    // By-ref fields: pointers to the captured variables.
    for (DeclIter DI = BlockByRefDecls.begin(), DE = BlockByRefDecls.end();
         DI != DE; ++DI) {
      std::string Field = (*DI)->getName();
      Result += " ";
      if (isBlockPointerType((*DI)->getType()))
        Result += "struct __closure_impl *";
      else
        Context->getPointerType((*DI)->getType()).getAsStringInternal(Field);
      Result += Field + "; // by ref\n";
    }
  }
  Result += "};\n";
  return Result;
}
/// SynthesizeBlockLiterals - For every block collected in Blocks, insert
/// its struct definition and then its implementation function at
/// \p FunLocStart. Per-block state is reset after each block and Blocks is
/// emptied at the end.
void RewriteBlocks::SynthesizeBlockLiterals(SourceLocation FunLocStart,
                                            const char *FunName) {
  const std::string TagPrefix =
    "struct __" + std::string(FunName) + "_closure_impl_";
  for (unsigned BlockNo = 0; BlockNo < Blocks.size(); ++BlockNo) {
    BlockExpr *Block = Blocks[BlockNo];
    const std::string Tag = TagPrefix + utostr(BlockNo);

    // The struct must come first: it fills the per-block decl sets that the
    // function synthesis reads.
    const std::string Impl = SynthesizeBlockImpl(Block, Tag);
    InsertText(FunLocStart, Impl.c_str(), Impl.size());
    const std::string Func = SynthesizeBlockFunc(Block, BlockNo, FunName, Tag);
    InsertText(FunLocStart, Func.c_str(), Func.size());

    BlockDeclRefs.clear();
    BlockByRefDecls.clear();
    BlockByCopyDecls.clear();
    BlockCallExprs.clear();
  }
  Blocks.clear();
}
/// InsertBlockLiteralsWithinFunction - Emit the code for every block literal
/// found in \p FD at FD's location, named after the function.
// NOTE(review): FD->getLocation() is presumably the function name rather
// than the start of the declaration. If so, the synthesized code lands
// between the return type and the name. Confirm.
void RewriteBlocks::InsertBlockLiteralsWithinFunction(FunctionDecl *FD) {
  SynthesizeBlockLiterals(FD->getLocation(), FD->getName());
}
/// InsertBlockLiteralsWithinMethod - Emit the code for every block literal
/// found in \p MD at the method's start. Blocks are named after the
/// selector, with each ':' turned into '_' so it forms a valid C identifier.
void RewriteBlocks::InsertBlockLiteralsWithinMethod(ObjCMethodDecl *MD) {
  std::string FuncName = MD->getSelector().getName();
  for (std::string::size_type Pos = 0; Pos != FuncName.size(); ++Pos) {
    if (FuncName[Pos] == ':')
      FuncName[Pos] = '_';
  }
  SynthesizeBlockLiterals(MD->getLocStart(), FuncName.c_str());
}
/// HandleDeclInMainFile - This is called for each top-level decl defined in the
/// main file of the input. It has three jobs:
///   - rewrite the types of block-pointer parameters, variables and typedefs,
///   - walk function/method bodies for block literals and block calls,
///   - emit the synthesized code for the blocks found in those bodies.
void RewriteBlocks::HandleDeclInMainFile(Decl *D) {
  if (FunctionDecl *FD = dyn_cast<FunctionDecl>(D)) {
    // Prototypes carry no ParmVarDecls, so inspect the function type instead.
    // This handles declarations and definitions the same way.
    QualType FuncTy = FD->getType();
    if (FunctionTypeProto *Proto = dyn_cast<FunctionTypeProto>(FuncTy)) {
      FunctionTypeProto::arg_type_iterator ArgIt = Proto->arg_type_begin(),
                                           ArgEnd = Proto->arg_type_end();
      for (; ArgIt && (ArgIt != ArgEnd); ++ArgIt) {
        if (!isBlockPointerType(*ArgIt))
          continue;
        // A single call rewrites every block-pointer argument; stop so they
        // are not rewritten twice.
        RewriteBlockPointerDecl(FD);
        break;
      }
    }
    Stmt *Body = FD->getBody();
    if (Body) {
      CurFunctionDef = FD;
      FD->setBody(RewriteFunctionBody(Body));
      InsertBlockLiteralsWithinFunction(FD);
      CurFunctionDef = 0;
    }
    return;
  }

  if (ObjCMethodDecl *MD = dyn_cast<ObjCMethodDecl>(D)) {
    RewriteMethodDecl(MD);
    Stmt *Body = MD->getBody();
    if (Body) {
      CurMethodDef = MD;
      RewriteFunctionBody(Body);
      InsertBlockLiteralsWithinMethod(MD);
      CurMethodDef = 0;
    }
  }

  if (ValueDecl *VD = dyn_cast<ValueDecl>(D)) {
    if (isBlockPointerType(VD->getType()))
      RewriteBlockPointerDecl(VD);
    return;
  }

  TypedefDecl *TD = dyn_cast<TypedefDecl>(D);
  if (TD && isBlockPointerType(TD->getUnderlyingType()))
    RewriteBlockPointerDecl(TD);
}
/// GetBlockDeclRefExprs - Append to BlockDeclRefs every BlockDeclRefExpr in
/// the tree rooted at \p S, skipping references to functions. A node's
/// children are visited before the node itself.
void RewriteBlocks::GetBlockDeclRefExprs(Stmt *S) {
  for (Stmt::child_iterator Child = S->child_begin(), End = S->child_end();
       Child != End; ++Child) {
    if (Stmt *Sub = *Child)
      GetBlockDeclRefExprs(Sub);
  }

  BlockDeclRefExpr *Ref = dyn_cast<BlockDeclRefExpr>(S);
  if (!Ref)
    return;
  // FIXME: Handle enums.
  if (isa<FunctionDecl>(Ref->getDecl()))
    return;
  BlockDeclRefs.push_back(Ref);
}
/// GetBlockCallExprs - Record in BlockCallExprs every call through a block
/// pointer in the tree rooted at \p S, keyed by the BlockDeclRefExpr used as
/// the callee.
///
/// Calls whose callee is not a BlockDeclRefExpr cannot be keyed this way and
/// are skipped. Previously dyn_cast<> returned null for them and the call was
/// stored under a null key.
void RewriteBlocks::GetBlockCallExprs(Stmt *S) {
  for (Stmt::child_iterator CI = S->child_begin(), E = S->child_end();
       CI != E; ++CI)
    if (*CI)
      GetBlockCallExprs(*CI);

  CallExpr *CE = dyn_cast<CallExpr>(S);
  if (!CE || !CE->getCallee()->getType()->isBlockPointerType())
    return;
  if (BlockDeclRefExpr *Ref = dyn_cast<BlockDeclRefExpr>(CE->getCallee()))
    BlockCallExprs[Ref] = CE;
}
//===----------------------------------------------------------------------===//
// Function Body / Expression rewriting
//===----------------------------------------------------------------------===//
/// RewriteFunctionBody - Recursively rewrite the statement tree \p S:
///   - block literals among the children are replaced by compound literals,
///   - calls through block pointers are rewritten,
///   - block-pointer variables and typedefs declared in DeclStmts get their
///     types replaced.
/// Returns \p S itself.
Stmt *RewriteBlocks::RewriteFunctionBody(Stmt *S) {
  for (Stmt::child_iterator It = S->child_begin(), End = S->child_end();
       It != End; ++It) {
    Stmt *Child = *It;
    if (!Child)
      continue;
    if (BlockExpr *Literal = dyn_cast<BlockExpr>(Child)) {
      // Deliberately do not descend into the block body here; its text is
      // reproduced later when the block literals are synthesized.
      RewriteBlockExpr(Literal);
      continue;
    }
    if (Stmt *Replacement = RewriteFunctionBody(Child))
      *It = Replacement;
  }

  CallExpr *Call = dyn_cast<CallExpr>(S);
  if (Call && Call->getCallee()->getType()->isBlockPointerType())
    RewriteBlockCall(Call);

  if (DeclStmt *DS = dyn_cast<DeclStmt>(S)) {
    ScopedDecl *SD = DS->getDecl();
    ValueDecl *VD = dyn_cast<ValueDecl>(SD);
    if (VD && isBlockPointerType(VD->getType()))
      RewriteBlockPointerDecl(VD);
    TypedefDecl *TD = dyn_cast<TypedefDecl>(SD);
    if (TD && isBlockPointerType(TD->getUnderlyingType()))
      RewriteBlockPointerDecl(TD);
  }

  return S;
}
/// SynthesizeBlockCall - Build the C text that replaces the call \p Exp made
/// through a block pointer. The block's Invoke field is cast to the matching
/// function-pointer type and called with the block itself as an implicit
/// first argument, followed by the pretty-printed original arguments.
///
/// The returned text deliberately omits the closing ')'. Callers replace the
/// source only up to (not including) the character at the call's end
/// location, so the original ')' closes the synthesized argument list.
std::string RewriteBlocks::SynthesizeBlockCall(CallExpr *Exp) {
  // Navigate to relevant type information. Both start out null, so an
  // unexpected callee kind trips the assertions below instead of reading
  // indeterminate values.
  const char *closureName = 0;
  const BlockPointerType *CPT = 0;
  if (const DeclRefExpr *DRE = dyn_cast<DeclRefExpr>(Exp->getCallee())) {
    closureName = DRE->getDecl()->getName();
    CPT = DRE->getType()->getAsBlockPointerType();
  } else if (BlockDeclRefExpr *CDRE = dyn_cast<BlockDeclRefExpr>(Exp->getCallee())) {
    closureName = CDRE->getDecl()->getName();
    CPT = CDRE->getType()->getAsBlockPointerType();
  } else {
    assert(0 && "RewriteBlockClass: Bad type");
  }
  assert(CPT && "RewriteBlockClass: Bad type");
  const FunctionType *FT = CPT->getPointeeType()->getAsFunctionType();
  assert(FT && "RewriteBlockClass: Bad type");
  const FunctionTypeProto *FTP = dyn_cast<FunctionTypeProto>(FT);
  // FTP will be null for closures that don't take arguments.

  // Build a closure call - start with a paren expr to enforce precedence.
  std::string BlockCall = "(";

  // Synthesize the cast.
  BlockCall += "(" + Exp->getType().getAsString() + "(*)";
  BlockCall += "(struct __closure_impl *";
  if (FTP) {
    for (FunctionTypeProto::arg_type_iterator I = FTP->arg_type_begin(),
         E = FTP->arg_type_end(); I && (I != E); ++I)
      BlockCall += ", " + (*I).getAsString();
  }
  BlockCall += "))"; // close the argument list and paren expression.

  // Invoke the closure.
  BlockCall += closureName;
  BlockCall += "->Invoke)";

  // Add the arguments; the block itself goes first.
  BlockCall += "(";
  BlockCall += closureName;
  for (CallExpr::arg_iterator I = Exp->arg_begin(),
       E = Exp->arg_end(); I != E; ++I) {
    std::string syncExprBufS;
    llvm::raw_string_ostream Buf(syncExprBufS);
    (*I)->printPretty(Buf);
    BlockCall += ", " + Buf.str();
  }
  return BlockCall;
}
/// RewriteBlockCall - Replace the source text of the block call \p Exp with
/// its synthesized form. The replaced span stops just before the character
/// at the call's end location, which the synthesized text does not include.
void RewriteBlocks::RewriteBlockCall(CallExpr *Exp) {
  SourceLocation CallStart = Exp->getLocStart();
  const char *FirstChar = SM->getCharacterData(CallStart);
  const char *LastChar = SM->getCharacterData(Exp->getLocEnd());
  std::string Replacement = SynthesizeBlockCall(Exp);
  ReplaceText(CallStart, LastChar - FirstChar,
              Replacement.c_str(), Replacement.size());
}
/// RewriteBlockPointerFunctionArgs - Fuzzy-parse \p FD's parameter list
/// straight from the source text. Every top-level parameter whose text
/// contains a '^' is replaced with "struct __closure_impl *", followed by the
/// parameter name when \p FD is a definition.
///
/// Each parameter is replaced as the span [start, separator), where the
/// separator is the ',' or ')' that ends it. Separators are never consumed.
/// Each parameter gets its own tag, so names do not accumulate across
/// several block-pointer parameters.
void RewriteBlocks::RewriteBlockPointerFunctionArgs(FunctionDecl *FD) {
  SourceLocation DeclLoc = FD->getLocation();
  // We have 1 or more arguments that have closure pointers.
  const char *startBuf = SM->getCharacterData(DeclLoc);
  const char *startArgList = strchr(startBuf, '(');
  assert(startArgList && "Rewriter fuzzy parser confused");
  if (!startArgList)
    return;

  // Every parameter location is derived from the '(' using offsets into the
  // original buffer. Earlier ReplaceText() calls therefore do not disturb it.
  SourceLocation ArgListLoc =
    DeclLoc.getFileLocWithOffset(startArgList-startBuf);
  assert((ArgListLoc.isValid()) && "Invalid DeclLoc");

  const std::string BaseTag = "struct __closure_impl *";
  const char *argStart = startArgList + 1; // first char of current parameter
  unsigned parenCount = 1, nArgs = 0;
  bool scannedBlockDecl = false;

  for (const char *argPtr = startArgList + 1; *argPtr && parenCount; ++argPtr) {
    bool atArgEnd = false;
    switch (*argPtr) {
    case '^':
      scannedBlockDecl = true;
      break;
    case '(':
      parenCount++;
      break;
    case ')':
      // The ')' closing the argument list ends the last parameter.
      if (--parenCount == 0)
        atArgEnd = true;
      break;
    case ',':
      // Only top-level commas separate parameters.
      if (parenCount == 1) {
        // Make sure the function takes more than one argument.
        assert((FD->getNumParams() > 1) && "Rewriter fuzzy parser confused");
        atArgEnd = true;
      }
      break;
    }
    if (!atArgEnd)
      continue;

    if (scannedBlockDecl) {
      // If we are rewriting a definition, don't forget the arg name.
      std::string Tag = BaseTag;
      if (FD->getBody())
        Tag += FD->getParamDecl(nArgs)->getName();
      SourceLocation ArgLoc =
        ArgListLoc.getFileLocWithOffset(argStart-startArgList);
      assert((ArgLoc.isValid()) && "Invalid DeclLoc");
      ReplaceText(ArgLoc, argPtr-argStart, Tag.c_str(), Tag.size());
      scannedBlockDecl = false; // reset.
    }
    nArgs++;

    // The next parameter starts after the separator and any whitespace.
    argStart = argPtr + 1;
    while (*argStart == ' ' || *argStart == '\t' ||
           *argStart == '\n' || *argStart == '\r')
      argStart++;
  }
}
/// RewriteBlockPointerDecl - Replace the declared type of a block-pointer
/// declaration with "struct __closure_impl *<name>".
///
/// Functions are handed to RewriteBlockPointerFunctionArgs(). For variables
/// and typedefs, the text from the decl's location up to the first '=' or ';'
/// is replaced. A leading "typedef " is skipped. Other declaration kinds are
/// left untouched; previously they read uninitialized buffer pointers.
void RewriteBlocks::RewriteBlockPointerDecl(NamedDecl *ND) {
  if (FunctionDecl *FD = dyn_cast<FunctionDecl>(ND)) {
    RewriteBlockPointerFunctionArgs(FD);
    return;
  }

  SourceLocation DeclLoc = ND->getLocation();
  const char *startBuf = 0;
  if (isa<VarDecl>(ND)) {
    startBuf = SM->getCharacterData(DeclLoc);
  } else if (isa<TypedefDecl>(ND)) {
    startBuf = SM->getCharacterData(DeclLoc);
    if (!strncmp("typedef ", startBuf, 8)) {
      startBuf += 8; // skip the typedef...
      DeclLoc = DeclLoc.getFileLocWithOffset(8);
    }
  } else {
    // Not a declaration kind this fuzzy rewriter knows how to scan.
    return;
  }

  // FIXME: need to skip past the argument list...then check for ','.
  const char *endBuf = startBuf;
  while (*endBuf && *endBuf != '=' && *endBuf != ';')
    endBuf++;

  std::string Tag = "struct __closure_impl *" + std::string(ND->getName());
  ReplaceText(DeclLoc, endBuf-startBuf, Tag.c_str(), Tag.size());
}
/// RewriteBlockExpr - Record the block literal \p Exp and replace its source
/// text with a compound literal of its per-block struct. The compound
/// literal is cast to "struct __closure_impl *" to prevent warnings from the
/// C compiler.
///
/// The initializer fills the impl header in this order:
///   - Reserved = 0,
///   - Flags = HAS_BYREF if anything is captured by reference, else 0,
///   - Size,
///   - Invoke = the block's synthesized function.
/// It then lists the by-copy variables (ObjC objects are retained and
/// autoreleased), followed by the addresses of the by-ref variables.
void RewriteBlocks::RewriteBlockExpr(BlockExpr *Exp) {
  Blocks.push_back(Exp);

  // Collect and deduplicate the imported variables by capture kind.
  bool haveByRefDecls = false;
  GetBlockDeclRefExprs(Exp);
  if (!BlockDeclRefs.empty()) {
    for (unsigned Idx = 0, N = BlockDeclRefs.size(); Idx != N; ++Idx) {
      BlockDeclRefExpr *Ref = BlockDeclRefs[Idx];
      if (!Ref->isByRef()) {
        BlockByCopyDecls.insert(Ref->getDecl());
        continue;
      }
      haveByRefDecls = true;
      BlockByRefDecls.insert(Ref->getDecl());
    }
  }

  // Name the block after the enclosing function, or after the method
  // selector with ':' turned into '_'.
  std::string FuncName;
  if (CurFunctionDef) {
    FuncName = CurFunctionDef->getName();
  } else if (CurMethodDef) {
    FuncName = CurMethodDef->getSelector().getName();
    for (std::string::size_type Pos = 0; Pos != FuncName.size(); ++Pos) {
      if (FuncName[Pos] == ':')
        FuncName[Pos] = '_';
    }
  }

  const std::string BlockNumber = utostr(Blocks.size()-1);
  const std::string Tag =
    "struct __" + FuncName + "_closure_impl_" + BlockNumber;
  const std::string Func = "__" + FuncName + "_" + "closure_" + BlockNumber;

  std::string Init = "(struct __closure_impl *)&(" + Tag + "){{0,";
  Init += haveByRefDecls ? "HAS_BYREF," : "0,";
  Init += "sizeof(" + Tag + ")," + Func + "}";

  if (!BlockDeclRefs.empty()) {
    typedef llvm::SmallPtrSet<ValueDecl*,8>::iterator DeclIter;
    for (DeclIter DI = BlockByCopyDecls.begin(), DE = BlockByCopyDecls.end();
         DI != DE; ++DI) {
      const char *VarName = (*DI)->getName();
      Init += ",";
      if (isObjCType((*DI)->getType())) {
        Init += "[[";
        Init += VarName;
        Init += " retain] autorelease]";
      } else {
        Init += VarName;
      }
    }
    for (DeclIter DI = BlockByRefDecls.begin(), DE = BlockByRefDecls.end();
         DI != DE; ++DI) {
      Init += ",&";
      Init += (*DI)->getName();
    }
  }
  Init += "}";

  BlockDeclRefs.clear();
  BlockByRefDecls.clear();
  BlockByCopyDecls.clear();

  // Replace the whole literal, including its last character.
  SourceLocation LitStart = Exp->getLocStart();
  const char *FirstChar = SM->getCharacterData(LitStart);
  const char *LastChar = SM->getCharacterData(Exp->getLocEnd());
  ReplaceText(LitStart, LastChar-FirstChar+1, Init.c_str(), Init.size());
}