Emit operator== for structs with derive(PartialEq)
diff --git a/gen/src/write.rs b/gen/src/write.rs
index 60936a0..54c0bba 100644
--- a/gen/src/write.rs
+++ b/gen/src/write.rs
@@ -5,8 +5,8 @@
use crate::syntax::atom::Atom::{self, *};
use crate::syntax::symbol::Symbol;
use crate::syntax::{
- mangle, Api, Enum, ExternFn, ExternType, Pair, ResolvableName, Signature, Struct, Type, Types,
- Var,
+ derive, mangle, Api, Enum, ExternFn, ExternType, Pair, ResolvableName, Signature, Struct,
+ Trait, Type, Types, Var,
};
use proc_macro2::Ident;
use std::collections::{HashMap, HashSet};
@@ -128,6 +128,7 @@
if !out.header {
for api in apis {
match api {
+ Api::Struct(strct) => write_struct_operator_decls(out, strct),
Api::CxxFunction(efn) => write_cxx_function_shim(out, efn),
Api::RustFunction(efn) => write_rust_function_decl(out, efn),
_ => {}
@@ -136,9 +137,13 @@
}
for api in apis {
- if let Api::RustFunction(efn) = api {
- out.next_section();
- write_rust_function_shim(out, efn);
+ match api {
+ Api::Struct(strct) => write_struct_operators(out, strct),
+ Api::RustFunction(efn) => {
+ out.next_section();
+ write_rust_function_shim(out, efn);
+ }
+ _ => {}
}
}
}
@@ -175,6 +180,8 @@
}
fn write_struct<'a>(out: &mut OutFile<'a>, strct: &'a Struct, methods: &[&ExternFn]) {
+ let operator_eq = derive::contains(&strct.derives, Trait::PartialEq);
+
out.set_namespace(&strct.name.namespace);
let guard = format!("CXXBRIDGE1_STRUCT_{}", strct.name.to_symbol());
writeln!(out, "#ifndef {}", guard);
@@ -188,7 +195,7 @@
write_type_space(out, &field.ty);
writeln!(out, "{};", field.ident);
}
- if !methods.is_empty() {
+ if !methods.is_empty() || operator_eq {
writeln!(out);
}
for method in methods {
@@ -198,6 +205,18 @@
write_rust_function_shim_decl(out, &local_name, sig, false);
writeln!(out, ";");
}
+ if operator_eq {
+ writeln!(
+ out,
+ " bool operator==(const {} &) const noexcept;",
+ strct.name.cxx,
+ );
+ writeln!(
+ out,
+ " bool operator!=(const {} &) const noexcept;",
+ strct.name.cxx,
+ );
+ }
writeln!(out, "}};");
writeln!(out, "#endif // {}", guard);
}
@@ -330,6 +349,50 @@
);
}
+fn write_struct_operator_decls<'a>(out: &mut OutFile<'a>, strct: &'a Struct) {
+ out.set_namespace(&strct.name.namespace);
+ out.begin_block(Block::ExternC);
+
+ if derive::contains(&strct.derives, Trait::PartialEq) {
+ let link_name = mangle::operator(&strct.name, "__operator_eq");
+ writeln!(
+ out,
+ "bool {}(const {1} &, const {1} &) noexcept;",
+ link_name, strct.name.cxx,
+ );
+ }
+
+ out.end_block(Block::ExternC);
+}
+
+fn write_struct_operators<'a>(out: &mut OutFile<'a>, strct: &'a Struct) {
+ if out.header {
+ return;
+ }
+
+ out.set_namespace(&strct.name.namespace);
+
+ if derive::contains(&strct.derives, Trait::PartialEq) {
+ let link_name = mangle::operator(&strct.name, "__operator_eq");
+ out.next_section();
+ writeln!(
+ out,
+ "bool {0}::operator==(const {0} &rhs) const noexcept {{",
+ strct.name.cxx,
+ );
+ writeln!(out, " return {}(*this, rhs);", link_name);
+ writeln!(out, "}}");
+ out.next_section();
+ writeln!(
+ out,
+ "bool {0}::operator!=(const {0} &rhs) const noexcept {{",
+ strct.name.cxx,
+ );
+ writeln!(out, " return !(*this == rhs);");
+ writeln!(out, "}}");
+ }
+}
+
fn write_cxx_function_shim<'a>(out: &mut OutFile<'a>, efn: &'a ExternFn) {
out.next_section();
out.set_namespace(&efn.name.namespace);
diff --git a/macro/src/derive.rs b/macro/src/derive.rs
index 91d4c64..28dfae3 100644
--- a/macro/src/derive.rs
+++ b/macro/src/derive.rs
@@ -2,6 +2,8 @@
use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, quote_spanned, ToTokens};
+pub use crate::syntax::derive::*;
+
pub fn expand_struct(strct: &Struct, actual_derives: &mut Option<TokenStream>) -> TokenStream {
let mut expanded = TokenStream::new();
let mut traits = Vec::new();
diff --git a/macro/src/expand.rs b/macro/src/expand.rs
index 62ab95f..bd1aefa 100644
--- a/macro/src/expand.rs
+++ b/macro/src/expand.rs
@@ -5,7 +5,7 @@
use crate::syntax::symbol::Symbol;
use crate::syntax::{
self, check, mangle, Api, Enum, ExternFn, ExternType, Impl, Pair, ResolvableName, Signature,
- Struct, Type, TypeAlias, Types,
+ Struct, Trait, Type, TypeAlias, Types,
};
use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote, quote_spanned, ToTokens};
@@ -40,7 +40,10 @@
for api in apis {
match api {
Api::Include(_) | Api::Impl(_) => {}
- Api::Struct(strct) => expanded.extend(expand_struct(strct)),
+ Api::Struct(strct) => {
+ expanded.extend(expand_struct(strct));
+ hidden.extend(expand_struct_operators(strct));
+ }
Api::Enum(enm) => expanded.extend(expand_enum(enm)),
Api::CxxType(ety) => {
let ident = &ety.name.rust;
@@ -154,6 +157,30 @@
}
}
+fn expand_struct_operators(strct: &Struct) -> TokenStream {
+ let ident = &strct.name.rust;
+ let mut operators = TokenStream::new();
+
+ for derive in &strct.derives {
+ let span = derive.span;
+ match derive.what {
+ Trait::PartialEq => operators.extend({
+ let link_name = mangle::operator(&strct.name, "__operator_eq");
+ quote_spanned! {span=>
+ #[doc(hidden)]
+ #[export_name = #link_name]
+ extern "C" fn __operator_eq(lhs: &#ident, rhs: &#ident) -> bool {
+ *lhs == *rhs
+ }
+ }
+ }),
+ _ => {}
+ }
+ }
+
+ operators
+}
+
fn expand_enum(enm: &Enum) -> TokenStream {
let ident = &enm.name.rust;
let doc = &enm.doc;
diff --git a/syntax/mangle.rs b/syntax/mangle.rs
index 480a0ff..ef605ef 100644
--- a/syntax/mangle.rs
+++ b/syntax/mangle.rs
@@ -1,5 +1,5 @@
use crate::syntax::symbol::{self, Symbol};
-use crate::syntax::{ExternFn, Types};
+use crate::syntax::{ExternFn, Pair, Types};
use proc_macro2::Ident;
const CXXBRIDGE: &str = "cxxbridge1";
@@ -25,6 +25,10 @@
}
}
+pub fn operator(receiver: &Pair, operator: &'static str) -> Symbol {
+ join!(receiver.namespace, CXXBRIDGE, receiver.cxx, operator)
+}
+
// The C half of a function pointer trampoline.
pub fn c_trampoline(efn: &ExternFn, var: &Ident, types: &Types) -> Symbol {
join!(extern_fn(efn, types), var, 0)