ruy_advanced API touchups: MulWithPrepacked does not need prepacked operands to be mutable, and PrepackedMatrix does not need accessor methods.
PiperOrigin-RevId: 308292063
diff --git a/ruy/matrix.h b/ruy/matrix.h
index 70e0726..5f480b7 100644
--- a/ruy/matrix.h
+++ b/ruy/matrix.h
@@ -168,15 +168,6 @@
// Opaque data structure representing a pre-packed matrix, as obtained from
// Ruy's advanced API.
struct PrepackedMatrix {
- void* get_data() const { return data; }
- void set_data(void* ptr) { data = ptr; }
- int get_data_size() const { return data_size; }
- void set_data_size(int value) { data_size = value; }
- void* get_sums() const { return sums; }
- void set_sums(void* ptr) { sums = ptr; }
- int get_sums_size() const { return sums_size; }
- void set_sums_size(int value) { sums_size = value; }
-
void* data = nullptr;
int data_size = 0;
void* sums = nullptr;
diff --git a/ruy/prepack.h b/ruy/prepack.h
index e16561a..3ee6a7c 100644
--- a/ruy/prepack.h
+++ b/ruy/prepack.h
@@ -77,7 +77,7 @@
const Mat<RhsScalar>& rhs,
const MulParamsType& mul_params, Ctx* ctx,
Mat<DstScalar>* dst,
- SidePair<PrepackedMatrix*> prepacked) {
+ SidePair<const PrepackedMatrix*> prepacked) {
profiler::ScopeLabel label("MulWithPrepacked");
EnforceLayoutSupport<MulParamsType>(lhs.layout, rhs.layout, dst->layout);
diff --git a/ruy/ruy_advanced.h b/ruy/ruy_advanced.h
index 8b4c9ef..3a4d9be 100644
--- a/ruy/ruy_advanced.h
+++ b/ruy/ruy_advanced.h
@@ -65,12 +65,13 @@
void MulWithPrepacked(const Matrix<LhsScalar>& lhs,
const Matrix<RhsScalar>& rhs,
const MulParamsType& mul_params, Context* context,
- Matrix<DstScalar>* dst, PrepackedMatrix* prepacked_lhs,
- PrepackedMatrix* prepacked_rhs) {
+ Matrix<DstScalar>* dst,
+ const PrepackedMatrix* prepacked_lhs,
+ const PrepackedMatrix* prepacked_rhs) {
Mat<LhsScalar> internal_lhs = ToInternal(lhs);
Mat<RhsScalar> internal_rhs = ToInternal(rhs);
Mat<DstScalar> internal_dst = ToInternal(*dst);
- SidePair<PrepackedMatrix*> prepacked(prepacked_lhs, prepacked_rhs);
+ SidePair<const PrepackedMatrix*> prepacked(prepacked_lhs, prepacked_rhs);
MulWithPrepackedInternal<CompiledPaths>(internal_lhs, internal_rhs,
mul_params, get_ctx(context),
diff --git a/ruy/test.h b/ruy/test.h
index 5234878..8f50899 100644
--- a/ruy/test.h
+++ b/ruy/test.h
@@ -1950,10 +1950,10 @@
rhs.matrix.set_data(cold_rhs.Next());
result->storage_matrix.matrix.set_data(cold_dst.Next());
if (benchmark_prepack_lhs) {
- result->prepacked_lhs.set_data(cold_prepacked_lhs.Next());
+ result->prepacked_lhs.data = cold_prepacked_lhs.Next();
}
if (benchmark_prepack_rhs) {
- result->prepacked_rhs.set_data(cold_prepacked_rhs.Next());
+ result->prepacked_rhs.data = cold_prepacked_rhs.Next();
}
}
EvalResult(result);
@@ -2014,8 +2014,8 @@
memcpy(orig_dst_data, result->storage_matrix.matrix.data(),
StorageSize(result->storage_matrix.matrix));
result->storage_matrix.matrix.set_data(orig_dst_data);
- result->prepacked_lhs.set_data(orig_prepacked_lhs_data);
- result->prepacked_rhs.set_data(orig_prepacked_rhs_data);
+ result->prepacked_lhs.data = orig_prepacked_lhs_data;
+ result->prepacked_rhs.data = orig_prepacked_rhs_data;
}
}