Rust: Deserialize parcelables in place
When deserializing inout- and out-argument parcelables, we want to
avoid overwriting fields that the server did not include in the reply.
This patch implements in-place deserialization for parcelables,
which takes an existing parcelable and overwrites only the
decoded fields instead of replacing the entire structure.
Bug: 186724059
Test: atest aidl_integration_test
Change-Id: I906f1efdcb2ecb27638550eda45cbaf42066a12f
diff --git a/generate_rust.cpp b/generate_rust.cpp
index 3817f64..1e96802 100644
--- a/generate_rust.cpp
+++ b/generate_rust.cpp
@@ -187,7 +187,7 @@
}
for (const AidlArgument* arg : method.GetOutArguments()) {
- out << "*" << kArgumentPrefix << arg->GetName() << " = _aidl_reply.read()?;\n";
+ out << "_aidl_reply.read_onto(" << kArgumentPrefix << arg->GetName() << ")?;\n";
}
}
@@ -570,17 +570,16 @@
prologue << "if (parcel.get_data_position() - start_pos) == parcelable_size {\n";
// We assume the lhs can never be > parcelable_size, because then the read
// immediately preceding this check would have returned NOT_ENOUGH_DATA
- prologue << " return Ok(Some(result));\n";
+ prologue << " return Ok(());\n";
prologue << "}\n";
string prologue_str = prologue.str();
- out << "let mut result = Self::default();\n";
for (const auto& variable : parcel->GetFields()) {
out << prologue_str;
if (!TypeHasDefault(variable->GetType(), typenames)) {
- out << "result." << variable->GetName() << " = Some(parcel.read()?);\n";
+ out << "self." << variable->GetName() << " = Some(parcel.read()?);\n";
} else {
- out << "result." << variable->GetName() << " = parcel.read()?;\n";
+ out << "self." << variable->GetName() << " = parcel.read()?;\n";
}
}
// Now we read all fields.
@@ -588,7 +587,7 @@
out << "unsafe {\n";
out << " parcel.set_data_position(start_pos + parcelable_size)?;\n";
out << "}\n";
- out << "Ok(Some(result))\n";
+ out << "Ok(())\n";
}
void GenerateParcelBody(CodeWriter& out, const AidlUnionDecl* parcel,
@@ -668,7 +667,8 @@
} else {
out << "parcel.read()?;\n";
}
- out << "Ok(Some(Self::" << variable->GetCapitalizedName() << "(value)))\n";
+ out << "*self = Self::" << variable->GetCapitalizedName() << "(value);\n";
+ out << "Ok(())\n";
out.Dedent();
out << "}\n";
}
@@ -713,23 +713,15 @@
template <typename ParcelableType>
void GenerateParcelDeserialize(CodeWriter& out, const ParcelableType* parcel,
const AidlTypenames& typenames) {
- out << "impl binder::parcel::Deserialize for " << parcel->GetName() << " {\n";
- out << " fn deserialize(parcel: &binder::parcel::Parcel) -> binder::Result<Self> {\n";
- out << " <Self as binder::parcel::DeserializeOption>::deserialize_option(parcel)\n";
- out << " .transpose()\n";
- out << " .unwrap_or(Err(binder::StatusCode::UNEXPECTED_NULL))\n";
- out << " }\n";
- out << "}\n";
+ out << "binder::impl_deserialize_for_parcelable!(" << parcel->GetName() << ");\n";
- out << "impl binder::parcel::DeserializeArray for " << parcel->GetName() << " {}\n";
-
- out << "impl binder::parcel::DeserializeOption for " << parcel->GetName() << " {\n";
+ // The actual deserialization code lives in the private
+ // deserialize_parcelable() method which we emit here.
+ out << "impl " << parcel->GetName() << " {\n";
out.Indent();
- out << "fn deserialize_option(parcel: &binder::parcel::Parcel) -> binder::Result<Option<Self>> "
- "{\n";
+ out << "fn deserialize_parcelable(&mut self, "
+ "parcel: &binder::parcel::Parcel) -> binder::Result<()> {\n";
out.Indent();
- out << "let status: i32 = parcel.read()?;\n";
- out << "if status == 0 { return Ok(None); }\n";
GenerateParcelDeserializeBody(out, parcel, typenames);