Skip to content

Instantly share code, notes, and snippets.

@nullhook
Last active February 16, 2024 16:43
Show Gist options
  • Select an option

  • Save nullhook/11d74c02dc42e061ade9528973fae7f4 to your computer and use it in GitHub Desktop.

Select an option

Save nullhook/11d74c02dc42e061ade9528973fae7f4 to your computer and use it in GitHub Desktop.
compute in metal
#include <cstdio>
#include <cstring>
#include <iostream>
#define NS_PRIVATE_IMPLEMENTATION
#define CA_PRIVATE_IMPLEMENTATION
#define MTL_PRIVATE_IMPLEMENTATION
#include "Metal.hpp"
// Output and input buffers bound to the E_ kernel; main() allocates both.
// Namespace-scope pointers are zero-initialized anyway — the explicit
// nullptr just makes the "not yet allocated" state visible to readers.
MTL::Buffer* outputs = nullptr;
MTL::Buffer* input0 = nullptr;
int main() {
// both represents a metal context
MTL::Device* device = MTL::CreateSystemDefaultDevice();
MTL::CommandQueue* command_queue = device->newCommandQueue();
MTL::Library* library = device->newDefaultLibrary();
if (!library) assert(false);
MTL::Function* E_ = library->newFunction(NS::String::string("E_", NS::StringEncoding::UTF8StringEncoding));
NS::Error* error = nullptr;
MTL::ComputePipelineState* pso = device->newComputePipelineState(E_, &error);
if (!pso) {
std::cerr << error->localizedDescription()->utf8String() << "\n";
assert(false);
}
// all buffers are effectively shared on apple silicon
// due to unified memory architecture you can get away with
// forgetting to call didModifyRange.
// API contract says you must call didModifyRange, but the driver doesn’t enforce that on apple silicon
outputs = device->newBuffer(4, MTL::ResourceStorageModeManaged);
input0 = device->newBuffer(4, MTL::ResourceStorageModeManaged);
const float a{10.0};
memcpy(input0->contents(), &a, sizeof(float));
// input0->didModifyRange(NS::Range::Make( 0, sizeof(float)));
MTL::CommandBuffer* cmd_buff = command_queue->commandBuffer();
MTL::ComputeCommandEncoder* cmd_enc = cmd_buff->computeCommandEncoder();
// pass buffers to compute shaders
cmd_enc->setComputePipelineState(pso);
cmd_enc->setBuffer(outputs, 0, 0);
cmd_enc->setBuffer(input0, 0, 1);
cmd_enc->dispatchThreadgroups(MTL::Size({1, 1, 1}), MTL::Size({1, 1, 1}));
cmd_enc->endEncoding();
// cmd_buff->addCompletedHandler([](const MTL::CommandBuffer* ignored) {
// float* out = static_cast<float*>(outputs->contents());
// if (out != nullptr) {
// printf("%.2f\n", *out);
// }
// });
cmd_buff->commit();
// figure out runloop so we can read output buf
cmd_buff->waitUntilCompleted();
float* out = static_cast<float*>(outputs->contents());
if (out != nullptr) {
printf("%.2f\n", *out);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment