Merge pull request #863 from shodoco/master

Support macros in call arguments
diff --git a/src/cc/frontends/clang/b_frontend_action.cc b/src/cc/frontends/clang/b_frontend_action.cc
index 9370386..7012146 100644
--- a/src/cc/frontends/clang/b_frontend_action.cc
+++ b/src/cc/frontends/clang/b_frontend_action.cc
@@ -336,6 +336,12 @@
 // to:
 //  bpf_table_foo_elem(bpf_pseudo_fd(table), &key [,&leaf])
 bool BTypeVisitor::VisitCallExpr(CallExpr *Call) {
+  // Get rewritten text given a source range, w/ expansion range applied
+  auto getRewrittenText = [this] (SourceRange R) {
+    auto r = rewriter_.getSourceMgr().getExpansionRange(R);
+    return rewriter_.getRewrittenText(r);
+  };
+
   // make sure node is a reference to a bpf table, which is assured by the
   // presence of the section("maps/<typename>") GNU __attribute__
   if (MemberExpr *Memb = dyn_cast<MemberExpr>(Call->getCallee()->IgnoreImplicit())) {
@@ -345,9 +351,8 @@
         if (!A->getName().startswith("maps"))
           return true;
 
-        SourceRange argRange(Call->getArg(0)->getLocStart(),
-                             Call->getArg(Call->getNumArgs()-1)->getLocEnd());
-        string args = rewriter_.getRewrittenText(argRange);
+        string args = getRewrittenText(SourceRange(Call->getArg(0)->getLocStart(),
+                                                   Call->getArg(Call->getNumArgs() - 1)->getLocEnd()));
 
         // find the table fd, which was opened at declaration time
         auto table_it = tables_.begin();
@@ -366,10 +371,8 @@
         if (memb_name == "lookup_or_init") {
           map_update_policy = "BPF_NOEXIST";
           string name = Ref->getDecl()->getName();
-          string arg0 = rewriter_.getRewrittenText(SourceRange(Call->getArg(0)->getLocStart(),
-                                                               Call->getArg(0)->getLocEnd()));
-          string arg1 = rewriter_.getRewrittenText(SourceRange(Call->getArg(1)->getLocStart(),
-                                                               Call->getArg(1)->getLocEnd()));
+          string arg0 = getRewrittenText(Call->getArg(0)->getSourceRange());
+          string arg1 = getRewrittenText(Call->getArg(1)->getSourceRange());
           string lookup = "bpf_map_lookup_elem_(bpf_pseudo_fd(1, " + fd + ")";
           string update = "bpf_map_update_elem_(bpf_pseudo_fd(1, " + fd + ")";
           txt  = "({typeof(" + name + ".leaf) *leaf = " + lookup + ", " + arg0 + "); ";
@@ -381,8 +384,7 @@
           txt += "leaf;})";
         } else if (memb_name == "increment") {
           string name = Ref->getDecl()->getName();
-          string arg0 = rewriter_.getRewrittenText(SourceRange(Call->getArg(0)->getLocStart(),
-                                                               Call->getArg(0)->getLocEnd()));
+          string arg0 = getRewrittenText(Call->getArg(0)->getSourceRange());
           string lookup = "bpf_map_lookup_elem_(bpf_pseudo_fd(1, " + fd + ")";
           string update = "bpf_map_update_elem_(bpf_pseudo_fd(1, " + fd + ")";
           txt  = "({ typeof(" + name + ".key) _key = " + arg0 + "; ";
@@ -394,21 +396,16 @@
           txt += "if (_leaf) (*_leaf)++; })";
         } else if (memb_name == "perf_submit") {
           string name = Ref->getDecl()->getName();
-          string arg0 = rewriter_.getRewrittenText(SourceRange(Call->getArg(0)->getLocStart(),
-                                                               Call->getArg(0)->getLocEnd()));
-          string args_other = rewriter_.getRewrittenText(SourceRange(Call->getArg(1)->getLocStart(),
-                                                                     Call->getArg(2)->getLocEnd()));
+          string arg0 = getRewrittenText(Call->getArg(0)->getSourceRange());
+          string args_other = getRewrittenText(SourceRange(Call->getArg(1)->getLocStart(),
+                                                           Call->getArg(2)->getLocEnd()));
           txt = "bpf_perf_event_output(" + arg0 + ", bpf_pseudo_fd(1, " + fd + ")";
           txt += ", bpf_get_smp_processor_id(), " + args_other + ")";
         } else if (memb_name == "perf_submit_skb") {
-          string skb = rewriter_.getRewrittenText(SourceRange(Call->getArg(0)->getLocStart(),
-                                                               Call->getArg(0)->getLocEnd()));
-          string skb_len = rewriter_.getRewrittenText(SourceRange(Call->getArg(1)->getLocStart(),
-                                                                  Call->getArg(1)->getLocEnd()));
-          string meta = rewriter_.getRewrittenText(SourceRange(Call->getArg(2)->getLocStart(),
-                                                               Call->getArg(2)->getLocEnd()));
-          string meta_len = rewriter_.getRewrittenText(SourceRange(Call->getArg(3)->getLocStart(),
-                                                                   Call->getArg(3)->getLocEnd()));
+          string skb = getRewrittenText(Call->getArg(0)->getSourceRange());
+          string skb_len = getRewrittenText(Call->getArg(1)->getSourceRange());
+          string meta = getRewrittenText(Call->getArg(2)->getSourceRange());
+          string meta_len = getRewrittenText(Call->getArg(3)->getSourceRange());
           txt = "bpf_perf_event_output(" +
             skb + ", " +
             "bpf_pseudo_fd(1, " + fd + "), " +
@@ -417,8 +414,8 @@
             meta_len + ");";
         } else if (memb_name == "get_stackid") {
             if (table_it->type == BPF_MAP_TYPE_STACK_TRACE) {
-              string arg0 = rewriter_.getRewrittenText(SourceRange(Call->getArg(0)->getLocStart(),
-                                                                   Call->getArg(0)->getLocEnd()));
+              string arg0 = getRewrittenText(Call->getArg(0)->getSourceRange());
               txt = "bpf_get_stackid(";
               txt += "bpf_pseudo_fd(1, " + fd + "), " + arg0;
-              rewrite_end = Call->getArg(0)->getLocEnd();
+              // Like arg0's text, take the end from the expansion range: macro locations are not rewritable.
+              rewrite_end = rewriter_.getSourceMgr().getExpansionRange(Call->getArg(0)->getSourceRange()).getEnd();
@@ -474,7 +471,7 @@
 
         vector<string> args;
         for (auto arg : Call->arguments())
-          args.push_back(rewriter_.getRewrittenText(SourceRange(arg->getLocStart(), arg->getLocEnd())));
+          args.push_back(getRewrittenText(arg->getSourceRange()));
 
         string text;
         if (Decl->getName() == "incr_cksum_l3") {
diff --git a/tests/python/test_clang.py b/tests/python/test_clang.py
index 2d6e5bf..4725a84 100755
--- a/tests/python/test_clang.py
+++ b/tests/python/test_clang.py
@@ -352,5 +352,49 @@
         with self.assertRaises(Exception):
             b = BPF(text=text)
 
+    def test_call_macro_arg(self):
+        # jmp.call() taking an enum constant and a macro as its index must
+        # compile; BPF() raises on failure, so that is the real check.
+        text = """
+BPF_TABLE("prog", u32, u32, jmp, 32);
+
+#define JMP_IDX_PIPE (1U << 1)
+
+enum action {
+    ACTION_PASS
+};
+
+int process(struct xdp_md *ctx) {
+    jmp.call((void *)ctx, ACTION_PASS);
+    jmp.call((void *)ctx, JMP_IDX_PIPE);
+    return XDP_PASS;
+}
+        """
+        b = BPF(text=text)
+        t = b["jmp"]
+        self.assertEqual(len(t), 32)
+
+    def test_update_macro_arg(self):
+        # act.increment() taking an enum constant and a macro as its key must
+        # compile; BPF() raises on failure, so that is the real check.
+        text = """
+BPF_TABLE("array", u32, u32, act, 32);
+
+#define JMP_IDX_PIPE (1U << 1)
+
+enum action {
+    ACTION_PASS
+};
+
+int process(struct xdp_md *ctx) {
+    act.increment(ACTION_PASS);
+    act.increment(JMP_IDX_PIPE);
+    return XDP_PASS;
+}
+        """
+        b = BPF(text=text)
+        t = b["act"]
+        self.assertEqual(len(t), 32)
+
 if __name__ == "__main__":
     main()