Trace external pointers through function returns (#1821)
* Trace external pointers through function returns
Surprisingly, the rewriter wasn't able to trace external pointers
returned by inlined functions until now. This commit fixes it by
adding functions that return an external pointer to ProbeVisitor's
set of external pointers, along with the levels of indirection.
This change requires reversing a few traversals so that called functions
are visited before their call expressions are checked. Then, we check for
the presence of an external pointer in return statements and retrieve that
information at the call expression.
* Test dereferences of external pointers returned by inlined functions
* tcpdrop: remove unnecessary bpf_probe_read calls
e783567a makes these calls unnecessary.
diff --git a/src/cc/frontends/clang/b_frontend_action.cc b/src/cc/frontends/clang/b_frontend_action.cc
index ab9b6f4..dad1e29 100644
--- a/src/cc/frontends/clang/b_frontend_action.cc
+++ b/src/cc/frontends/clang/b_frontend_action.cc
@@ -107,6 +107,23 @@
: ProbeChecker(arg, ptregs, is_transitive, false) {}
bool VisitCallExpr(CallExpr *E) {
needs_probe_ = false;
+
+ if (is_assign_) {
+ // We're looking for a function that returns an external pointer,
+ // regardless of the number of dereferences.
+ for(auto p : ptregs_) {
+ if (std::get<0>(p) == E->getDirectCallee()) {
+ needs_probe_ = true;
+ nb_derefs_ += std::get<1>(p);
+ return false;
+ }
+ }
+ } else {
+ tuple<Decl *, int> pt = make_tuple(E->getDirectCallee(), nb_derefs_);
+ if (ptregs_.find(pt) != ptregs_.end())
+ needs_probe_ = true;
+ }
+
if (!track_helpers_)
return false;
if (VarDecl *V = dyn_cast<VarDecl>(E->getCalleeDecl()))
@@ -220,6 +237,12 @@
return true;
}
+ /* If the expression contains a call to another function, we need to visit
+ * that function first to know if a rewrite is necessary (i.e., if the
+ * function returns an external pointer). */
+ if (!TraverseStmt(E))
+ return false;
+
ProbeChecker checker = ProbeChecker(E, ptregs_, track_helpers_,
true);
if (checker.is_transitive()) {
@@ -231,8 +254,8 @@
return true;
}
- if (E->getStmtClass() == Stmt::CallExprClass) {
- CallExpr *Call = dyn_cast<CallExpr>(E);
+ if (E->IgnoreParenCasts()->getStmtClass() == Stmt::CallExprClass) {
+ CallExpr *Call = dyn_cast<CallExpr>(E->IgnoreParenCasts());
if (MemberExpr *Memb = dyn_cast<MemberExpr>(Call->getCallee()->IgnoreImplicit())) {
StringRef memb_name = Memb->getMemberDecl()->getName();
if (DeclRefExpr *Ref = dyn_cast<DeclRefExpr>(Memb->getBase())) {
@@ -301,12 +324,49 @@
}
if (fn_visited_.find(F) == fn_visited_.end()) {
fn_visited_.insert(F);
+ /* Maintains a stack of the number of dereferences for the external
+ * pointers returned by each function in the call stack or -1 if the
+ * function didn't return an external pointer. */
+ ptregs_returned_.push_back(-1);
TraverseDecl(F);
+ int nb_derefs = ptregs_returned_.back();
+ ptregs_returned_.pop_back();
+ if (nb_derefs != -1) {
+ tuple<Decl *, int> pt = make_tuple(F, nb_derefs);
+ ptregs_.insert(pt);
+ }
}
}
}
return true;
}
+bool ProbeVisitor::VisitReturnStmt(ReturnStmt *R) {
+ /* ptregs_returned_ only has entries while a called function is being
+ * traversed; when it is empty nobody consumes the result, so skip the check. */
+ if (ptregs_returned_.size() == 0)
+ return true;
+
+ /* Traverse the returned expression before checking it. If it calls a
+ * function that itself returns an external pointer, that callee has to be
+ * visited first so that it is already recorded in ptregs_ when the
+ * ProbeChecker below inspects the expression. */
+ if (!TraverseStmt(R->getRetValue()))
+ return false;
+
+ ProbeChecker checker = ProbeChecker(R->getRetValue(), ptregs_,
+ track_helpers_, true);
+ if (checker.needs_probe()) {
+ int curr_nb_derefs = ptregs_returned_.back();
+ /* The top of the stack starts at -1 (no external pointer returned yet).
+ * Across several return statements, keep the highest level of indirection
+ * and leave it to the user to manually handle the other cases. */
+ if (checker.get_nb_derefs() > curr_nb_derefs) {
+ ptregs_returned_.pop_back();
+ ptregs_returned_.push_back(checker.get_nb_derefs());
+ }
+ }
+ return true;
+}
bool ProbeVisitor::VisitBinaryOperator(BinaryOperator *E) {
if (!E->isAssignmentOp())
return true;
@@ -359,6 +419,15 @@
return false;
}
+ /* If the base of the dereference is a call to another function, we need to
+ * visit that function first to know if a rewrite is necessary (i.e., if the
+ * function returns an external pointer). */
+ if (base->IgnoreParenCasts()->getStmtClass() == Stmt::CallExprClass) {
+ CallExpr *Call = dyn_cast<CallExpr>(base->IgnoreParenCasts());
+ if (!TraverseStmt(Call))
+ return false;
+ }
+
// Checks to see if the expression references something that needs to be run
// through bpf_probe_read.
if (!ProbeChecker(base, ptregs_, track_helpers_).needs_probe())
diff --git a/src/cc/frontends/clang/b_frontend_action.h b/src/cc/frontends/clang/b_frontend_action.h
index 72498db..cf3cc59 100644
--- a/src/cc/frontends/clang/b_frontend_action.h
+++ b/src/cc/frontends/clang/b_frontend_action.h
@@ -98,6 +98,7 @@
bool VisitVarDecl(clang::VarDecl *Decl);
bool TraverseStmt(clang::Stmt *S);
bool VisitCallExpr(clang::CallExpr *Call);
+ bool VisitReturnStmt(clang::ReturnStmt *R);
bool VisitBinaryOperator(clang::BinaryOperator *E);
bool VisitUnaryOperator(clang::UnaryOperator *E);
bool VisitMemberExpr(clang::MemberExpr *E);
@@ -120,6 +121,7 @@
std::set<clang::Decl *> &m_;
clang::Decl *ctx_;
bool track_helpers_;
+ std::list<int> ptregs_returned_;
};
// A helper class to the frontend action, walks the decls
diff --git a/tests/python/test_clang.py b/tests/python/test_clang.py
index e291a9f..19bbb7b 100755
--- a/tests/python/test_clang.py
+++ b/tests/python/test_clang.py
@@ -940,6 +940,74 @@
b = BPF(text=text)
fn = b.load_func("test", BPF.SCHED_CLS)
+ def test_probe_read_return(self):
+ text = """
+#define KBUILD_MODNAME "foo"
+#include <uapi/linux/ptrace.h>
+#include <linux/tcp.h>
+static inline unsigned char *my_skb_transport_header(struct sk_buff *skb) {
+ return skb->head + skb->transport_header;
+}
+int test(struct pt_regs *ctx, struct sock *sk, struct sk_buff *skb) {
+ struct tcphdr *th = (struct tcphdr *)my_skb_transport_header(skb);
+ return th->seq;
+}
+"""
+ b = BPF(text=text)
+ fn = b.load_func("test", BPF.KPROBE)
+
+ def test_probe_read_multiple_return(self):
+ text = """
+#define KBUILD_MODNAME "foo"
+#include <uapi/linux/ptrace.h>
+#include <linux/tcp.h>
+static inline u64 error_function() {
+ return 0;
+}
+static inline unsigned char *my_skb_transport_header(struct sk_buff *skb) {
+ if (skb)
+ return skb->head + skb->transport_header;
+ return (unsigned char *)error_function();
+}
+int test(struct pt_regs *ctx, struct sock *sk, struct sk_buff *skb) {
+ struct tcphdr *th = (struct tcphdr *)my_skb_transport_header(skb);
+ return th->seq;
+}
+"""
+ b = BPF(text=text)
+ fn = b.load_func("test", BPF.KPROBE)
+
+ def test_probe_read_return_expr(self):
+ text = """
+#define KBUILD_MODNAME "foo"
+#include <uapi/linux/ptrace.h>
+#include <linux/tcp.h>
+static inline unsigned char *my_skb_transport_header(struct sk_buff *skb) {
+ return skb->head + skb->transport_header;
+}
+int test(struct pt_regs *ctx, struct sock *sk, struct sk_buff *skb) {
+ u32 *seq = (u32 *)my_skb_transport_header(skb) + offsetof(struct tcphdr, seq);
+ return *seq;
+}
+"""
+ b = BPF(text=text)
+ fn = b.load_func("test", BPF.KPROBE)
+
+ def test_probe_read_return_call(self):
+ text = """
+#define KBUILD_MODNAME "foo"
+#include <uapi/linux/ptrace.h>
+#include <linux/tcp.h>
+static inline struct tcphdr *my_skb_transport_header(struct sk_buff *skb) {
+ return (struct tcphdr *)skb->head + skb->transport_header;
+}
+int test(struct pt_regs *ctx, struct sock *sk, struct sk_buff *skb) {
+ return my_skb_transport_header(skb)->seq;
+}
+"""
+ b = BPF(text=text)
+ fn = b.load_func("test", BPF.KPROBE)
+
if __name__ == "__main__":
main()
diff --git a/tools/tcpdrop.py b/tools/tcpdrop.py
index 77ad752..9667868 100755
--- a/tools/tcpdrop.py
+++ b/tools/tcpdrop.py
@@ -107,16 +107,16 @@
u8 tcpflags = 0;
struct tcphdr *tcp = skb_to_tcphdr(skb);
struct iphdr *ip = skb_to_iphdr(skb);
- bpf_probe_read(&sport, sizeof(sport), &tcp->source);
- bpf_probe_read(&dport, sizeof(dport), &tcp->dest);
+ sport = tcp->source;
+ dport = tcp->dest;
bpf_probe_read(&tcpflags, sizeof(tcpflags), &tcp_flag_byte(tcp));
sport = ntohs(sport);
dport = ntohs(dport);
if (family == AF_INET) {
struct ipv4_data_t data4 = {.pid = pid, .ip = 4};
- bpf_probe_read(&data4.saddr, sizeof(u32), &ip->saddr);
- bpf_probe_read(&data4.daddr, sizeof(u32), &ip->daddr);
+ data4.saddr = ip->saddr;
+ data4.daddr = ip->daddr;
data4.dport = dport;
data4.sport = sport;
data4.state = state;