#ifndef AUTO_MEMORY_H_
#define AUTO_MEMORY_H_

#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#if defined(_MSC_VER) // windows
#if defined(BUILDING_DLL)
#define NN_PUBLIC __declspec(dllexport)
#elif defined(USING_DLL)
#define NN_PUBLIC __declspec(dllimport)
#else
#define NN_PUBLIC
#endif
#else // unix
#define NN_PUBLIC __attribute__((visibility("default")))
#endif

// The allocation functions below are *defined* in this header, so they must be
// inline: otherwise every translation unit that includes the header emits its
// own external definition and the link fails with duplicate symbols (and MSVC
// rejects a non-inline definition of a dllimport function). C99 `inline` has
// different linkage rules, so C translation units get private `static inline`
// copies instead.
#ifdef __cplusplus
#define NN_HEADER_API NN_PUBLIC inline
#else
#define NN_HEADER_API static inline
#endif

/**
 * Round `ptr` up to the next multiple of `alignment`.
 * `alignment` must be a non-zero power of two.
 */
static inline void **alignPointer(void **ptr, size_t alignment) {
    uintptr_t mask = (uintptr_t)alignment - 1;
    return (void **)(((uintptr_t)ptr + mask) & ~mask);
}

/**
 * Number of bytes to request from malloc/calloc for an aligned block:
 * the payload, one slot for the back-pointer, and worst-case padding.
 * Returns 0 if that total would overflow size_t.
 */
static inline size_t memoryAlignPaddedSize(size_t size, size_t alignment) {
    size_t overhead;
    if (alignment > SIZE_MAX - sizeof(void *)) {
        return 0;
    }
    overhead = sizeof(void *) + alignment;
    if (size > SIZE_MAX - overhead) {
        return 0;
    }
    return size + overhead;
}

/**
 * Given a raw block from malloc/calloc (may be NULL), return the aligned
 * address inside it and store the raw pointer just before that address so
 * memoryFreeAlign() can recover it. Returns NULL if `origin` is NULL.
 */
static inline void *memoryAlignOrigin(void **origin, size_t alignment) {
    void **aligned;
    if (origin == NULL) {
        return NULL;
    }
    // Skip at least one pointer-sized slot so aligned[-1] is always writable.
    aligned = alignPointer(origin + 1, alignment);
    aligned[-1] = origin;
    return aligned;
}

#ifdef __cplusplus
extern "C" {
#endif

#define MEMORY_ALIGN_DEFAULT 64

/**
 * Allocate `size` bytes (uninitialized) aligned to `alignment`.
 * @param size      payload size in bytes; must be > 0 (asserted).
 * @param alignment non-zero power of two (asserted).
 * @return aligned pointer, or NULL if allocation fails or the padded size
 *         would overflow. Release only with memoryFreeAlign().
 */
NN_HEADER_API void *memoryAllocAlign(size_t size, size_t alignment) {
    size_t total;
    assert(size > 0);
    assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
    total = memoryAlignPaddedSize(size, alignment);
    if (total == 0) {
        return NULL;
    }
    return memoryAlignOrigin((void **)malloc(total), alignment);
}

/**
 * Like memoryAllocAlign(), but the whole underlying block (and therefore the
 * returned payload) is zero-filled via calloc.
 */
NN_HEADER_API void *memoryCallocAlign(size_t size, size_t alignment) {
    size_t total;
    assert(size > 0);
    assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
    total = memoryAlignPaddedSize(size, alignment);
    if (total == 0) {
        return NULL;
    }
    return memoryAlignOrigin((void **)calloc(total, 1), alignment);
}

/**
 * Free a pointer returned by memoryAllocAlign()/memoryCallocAlign().
 * NULL is accepted and ignored. Passing any other pointer is undefined,
 * because the raw block address is read from the slot just before `aligned`.
 */
NN_HEADER_API void memoryFreeAlign(void *aligned) {
    if (aligned != NULL) {
        free(((void **)aligned)[-1]);
    }
}

#ifdef __cplusplus
}
#endif

#ifdef __cplusplus
namespace NN {

/**
 * Owning, move-only buffer of `size()` elements of T, allocated with
 * memoryAllocAlign() at MEMORY_ALIGN_DEFAULT alignment and released with
 * memoryFreeAlign() on destruction.
 *
 * Elements are raw memory: no T constructors or destructors are run, so T is
 * presumably expected to be trivially copyable — TODO confirm with callers.
 */
template <typename T>
class AutoMemory {
public:
    AutoMemory() = default;

    /// Allocate `size` elements (uninitialized). A non-positive size, an
    /// overflowing byte count, or a failed allocation leaves the object empty
    /// (get() == nullptr, size() == 0). Kept implicit for source
    /// compatibility with existing callers.
    AutoMemory(int size) { allocate(size); }

    ~AutoMemory() { memoryFreeAlign(mData); }

    // Owning a raw buffer: copying would double-free, so copy is disabled.
    AutoMemory(const AutoMemory &) = delete;
    AutoMemory &operator=(const AutoMemory &) = delete;

    AutoMemory(AutoMemory &&other) noexcept : mData(other.mData), mSize(other.mSize) {
        other.mData = nullptr;
        other.mSize = 0;
    }

    AutoMemory &operator=(AutoMemory &&other) noexcept {
        if (this != &other) {
            memoryFreeAlign(mData);
            mData = other.mData;
            mSize = other.mSize;
            other.mData = nullptr;
            other.mSize = 0;
        }
        return *this;
    }

    /// Number of elements currently owned.
    inline int size() const { return mSize; }

    /// Take ownership of `data` (which must come from memoryAllocAlign()/
    /// memoryCallocAlign(), since it will be released with memoryFreeAlign()).
    /// The previously owned buffer is freed unless it is the same pointer.
    void set(T *data, int size) {
        if (mData != data) {
            memoryFreeAlign(mData);
        }
        mData = data;
        mSize = size;
    }

    /// Free the current buffer and allocate a new uninitialized one of `size`
    /// elements; on failure the object is left empty.
    void reset(int size) {
        release();
        allocate(size);
    }

    /// Free the buffer and leave the object empty.
    void release() {
        memoryFreeAlign(mData);
        mData = nullptr;
        mSize = 0;
    }

    /// Zero every byte of the buffer; a no-op when empty (memset on a null
    /// pointer is undefined even for zero bytes).
    void clear() {
        if (mData != nullptr && mSize > 0) {
            ::memset(mData, 0, static_cast<size_t>(mSize) * sizeof(T));
        }
    }

    /// Raw pointer to the buffer (nullptr when empty); ownership is retained.
    T *get() const { return mData; }

private:
    /// Allocate `size` elements into a currently-empty object.
    void allocate(int size) {
        mData = nullptr;
        mSize = 0;
        // Reject non-positive sizes and byte counts that would overflow size_t.
        if (size <= 0 || static_cast<size_t>(size) > SIZE_MAX / sizeof(T)) {
            return;
        }
        mData = static_cast<T *>(
            memoryAllocAlign(sizeof(T) * static_cast<size_t>(size), MEMORY_ALIGN_DEFAULT));
        mSize = (mData != nullptr) ? size : 0;
    }

    T *mData = nullptr;
    int mSize = 0;
};

/// Legacy name: the original constructors were spelled `AutoStorage`.
template <typename T>
using AutoStorage = AutoMemory<T>;

} // namespace NN
#endif // __cplusplus

#endif // AUTO_MEMORY_H_