Support overriding memory allocation functions
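- Let users provide their own memory management functions for XNNPACK by passing a `struct xnn_allocator` to `xnn_initialize()`; passing NULL selects the default XNNPACK allocation functions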
PiperOrigin-RevId: 281355722
diff --git a/src/init.c b/src/init.c
index f43fcf1..f63d7a8 100644
--- a/src/init.c
+++ b/src/init.c
@@ -9,6 +9,7 @@
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
+#include <string.h>
#include <pthread.h>
@@ -31,6 +32,7 @@
#include <xnnpack/log.h>
#include <xnnpack/lut.h>
#include <xnnpack/maxpool.h>
+#include <xnnpack/memory.h>
#include <xnnpack/pad.h>
#include <xnnpack/params.h>
#include <xnnpack/pavgpool.h>
@@ -1150,7 +1152,7 @@
xnn_params.initialized = true;
}
-enum xnn_status xnn_initialize(void) {
+enum xnn_status xnn_initialize(const struct xnn_allocator* allocator) {
#ifndef __EMSCRIPTEN__
if (!cpuinfo_initialize()) {
return xnn_status_out_of_memory;
@@ -1158,6 +1160,15 @@
#endif
pthread_once(&init_guard, &init);
if (xnn_params.initialized) {
+ if (allocator != NULL) {
+ memcpy(&xnn_params.allocator, allocator, sizeof(struct xnn_allocator));
+ } else {
+ xnn_params.allocator.allocate = &xnn_allocate;
+ xnn_params.allocator.reallocate = &xnn_reallocate;
+ xnn_params.allocator.deallocate = &xnn_deallocate;
+ xnn_params.allocator.aligned_allocate = &xnn_aligned_allocate;
+ xnn_params.allocator.aligned_deallocate = &xnn_aligned_deallocate;
+ }
return xnn_status_success;
} else {
return xnn_status_unsupported_hardware;