Last active
February 16, 2024 16:43
-
-
Save nullhook/11d74c02dc42e061ade9528973fae7f4 to your computer and use it in GitHub Desktop.
compute in metal
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>

// The *_PRIVATE_IMPLEMENTATION macros must be defined before Metal.hpp is included.
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include "Metal.hpp"
| MTL::Buffer* outputs; | |
| MTL::Buffer* input0; | |
| int main() { | |
| // both represents a metal context | |
| MTL::Device* device = MTL::CreateSystemDefaultDevice(); | |
| MTL::CommandQueue* command_queue = device->newCommandQueue(); | |
| MTL::Library* library = device->newDefaultLibrary(); | |
| if (!library) assert(false); | |
| MTL::Function* E_ = library->newFunction(NS::String::string("E_", NS::StringEncoding::UTF8StringEncoding)); | |
| NS::Error* error = nullptr; | |
| MTL::ComputePipelineState* pso = device->newComputePipelineState(E_, &error); | |
| if (!pso) { | |
| std::cerr << error->localizedDescription()->utf8String() << "\n"; | |
| assert(false); | |
| } | |
| // all buffers are effectively shared on apple silicon | |
| // due to unified memory architecture you can get away with | |
| // forgetting to call didModifyRange. | |
| // API contract says you must call didModifyRange, but the driver doesn’t enforce that on apple silicon | |
| outputs = device->newBuffer(4, MTL::ResourceStorageModeManaged); | |
| input0 = device->newBuffer(4, MTL::ResourceStorageModeManaged); | |
| const float a{10.0}; | |
| memcpy(input0->contents(), &a, sizeof(float)); | |
| // input0->didModifyRange(NS::Range::Make( 0, sizeof(float))); | |
| MTL::CommandBuffer* cmd_buff = command_queue->commandBuffer(); | |
| MTL::ComputeCommandEncoder* cmd_enc = cmd_buff->computeCommandEncoder(); | |
| // pass buffers to compute shaders | |
| cmd_enc->setComputePipelineState(pso); | |
| cmd_enc->setBuffer(outputs, 0, 0); | |
| cmd_enc->setBuffer(input0, 0, 1); | |
| cmd_enc->dispatchThreadgroups(MTL::Size({1, 1, 1}), MTL::Size({1, 1, 1})); | |
| cmd_enc->endEncoding(); | |
| // cmd_buff->addCompletedHandler([](const MTL::CommandBuffer* ignored) { | |
| // float* out = static_cast<float*>(outputs->contents()); | |
| // if (out != nullptr) { | |
| // printf("%.2f\n", *out); | |
| // } | |
| // }); | |
| cmd_buff->commit(); | |
| // figure out runloop so we can read output buf | |
| cmd_buff->waitUntilCompleted(); | |
| float* out = static_cast<float*>(outputs->contents()); | |
| if (out != nullptr) { | |
| printf("%.2f\n", *out); | |
| } | |
| } |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment