Emit operator== for structs with derive(PartialEq)
diff --git a/gen/src/write.rs b/gen/src/write.rs
index 60936a0..54c0bba 100644
--- a/gen/src/write.rs
+++ b/gen/src/write.rs
@@ -5,8 +5,8 @@
use crate::syntax::atom::Atom::{self, *};
use crate::syntax::symbol::Symbol;
use crate::syntax::{
- mangle, Api, Enum, ExternFn, ExternType, Pair, ResolvableName, Signature, Struct, Type, Types,
- Var,
+ derive, mangle, Api, Enum, ExternFn, ExternType, Pair, ResolvableName, Signature, Struct,
+ Trait, Type, Types, Var,
};
use proc_macro2::Ident;
use std::collections::{HashMap, HashSet};
@@ -128,6 +128,7 @@
if !out.header {
for api in apis {
match api {
+ Api::Struct(strct) => write_struct_operator_decls(out, strct),
Api::CxxFunction(efn) => write_cxx_function_shim(out, efn),
Api::RustFunction(efn) => write_rust_function_decl(out, efn),
_ => {}
@@ -136,9 +137,13 @@
}
for api in apis {
- if let Api::RustFunction(efn) = api {
- out.next_section();
- write_rust_function_shim(out, efn);
+ match api {
+ Api::Struct(strct) => write_struct_operators(out, strct),
+ Api::RustFunction(efn) => {
+ out.next_section();
+ write_rust_function_shim(out, efn);
+ }
+ _ => {}
}
}
}
@@ -175,6 +180,8 @@
}
fn write_struct<'a>(out: &mut OutFile<'a>, strct: &'a Struct, methods: &[&ExternFn]) {
+ let operator_eq = derive::contains(&strct.derives, Trait::PartialEq);
+
out.set_namespace(&strct.name.namespace);
let guard = format!("CXXBRIDGE1_STRUCT_{}", strct.name.to_symbol());
writeln!(out, "#ifndef {}", guard);
@@ -188,7 +195,7 @@
write_type_space(out, &field.ty);
writeln!(out, "{};", field.ident);
}
- if !methods.is_empty() {
+ if !methods.is_empty() || operator_eq {
writeln!(out);
}
for method in methods {
@@ -198,6 +205,18 @@
write_rust_function_shim_decl(out, &local_name, sig, false);
writeln!(out, ";");
}
+ if operator_eq {
+ writeln!(
+ out,
+ " bool operator==(const {} &) const noexcept;",
+ strct.name.cxx,
+ );
+ writeln!(
+ out,
+ " bool operator!=(const {} &) const noexcept;",
+ strct.name.cxx,
+ );
+ }
writeln!(out, "}};");
writeln!(out, "#endif // {}", guard);
}
@@ -330,6 +349,50 @@
);
}
+/// Declares the `extern "C"` equality shim for a struct deriving `PartialEq`.
+///
+/// The declaration takes both operands by const reference and is `noexcept`.
+fn write_struct_operator_decls<'a>(out: &mut OutFile<'a>, strct: &'a Struct) {
+    out.set_namespace(&strct.name.namespace);
+    out.begin_block(Block::ExternC);
+    if derive::contains(&strct.derives, Trait::PartialEq) {
+        // Same mangled symbol that the out-of-line `operator==` definition calls.
+        let eq_symbol = mangle::operator(&strct.name, "__operator_eq");
+        let ty = &strct.name.cxx;
+        writeln!(out, "bool {}(const {1} &, const {1} &) noexcept;", eq_symbol, ty);
+    }
+
+    out.end_block(Block::ExternC);
+}
+
+/// Emits the out-of-line C++ `operator==` / `operator!=` definitions for a struct
+/// that derives `PartialEq`. Does nothing when generating the header.
+fn write_struct_operators<'a>(out: &mut OutFile<'a>, strct: &'a Struct) {
+    // Operator definitions are only emitted for non-header output.
+    if out.header {
+        return;
+    }
+    out.set_namespace(&strct.name.namespace);
+
+    if !derive::contains(&strct.derives, Trait::PartialEq) {
+        return;
+    }
+    let eq_symbol = mangle::operator(&strct.name, "__operator_eq");
+    let ty = &strct.name.cxx;
+
+    // operator== forwards to the extern "C" shim with the same mangled link name.
+    out.next_section();
+    writeln!(out, "bool {0}::operator==(const {0} &rhs) const noexcept {{", ty);
+    writeln!(out, " return {}(*this, rhs);", eq_symbol);
+    writeln!(out, "}}");
+
+    // operator!= is expressed via operator== rather than a second shim call.
+    out.next_section();
+    writeln!(out, "bool {0}::operator!=(const {0} &rhs) const noexcept {{", ty);
+    writeln!(out, " return !(*this == rhs);");
+    writeln!(out, "}}");
+}
+
fn write_cxx_function_shim<'a>(out: &mut OutFile<'a>, efn: &'a ExternFn) {
out.next_section();
out.set_namespace(&efn.name.namespace);