Skip to content

Instantly share code, notes, and snippets.

@adkri
Last active December 9, 2022 09:28
Show Gist options
  • Select an option

  • Save adkri/101342a72f475ec396a85440f188b18f to your computer and use it in GitHub Desktop.

Select an option

Save adkri/101342a72f475ec396a85440f188b18f to your computer and use it in GitHub Desktop.
Matrix multiplication on macOS with MPS (Metal Performance Shaders)
// Compile with: clang -framework MetalPerformanceShaders -framework Foundation -framework Metal -framework CoreGraphics main.m -o my_matrix_mul
#include <MetalPerformanceShaders/MetalPerformanceShaders.h>
#include <stdio.h>
// Multiplies a 2x4 matrix A by a 4x2 matrix B on the GPU using
// MPSMatrixMultiplication and prints the 2x2 product to stdout.
// For the sample data below the expected product is:
//     50  60
//    114 140
//
// Returns 0 on success, 1 if the device/MPS is unavailable, a Metal resource
// could not be created, the command buffer did not complete, or an exception
// was raised while setting up or encoding the work.
int main(void) {
    @autoreleasepool {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice();
        if (!device) {
            NSLog(@">> Error: failed to create system device!");
            return 1;  // nothing below can work without a device
        }
        if (!MPSSupportsMTLDevice(device)) {
            NSLog(@">> Error: MPS not supported!");
            return 1;
        }

        // Row-major, tightly packed input data.
        float matrixAData[] = {
            1, 2, 3, 4,
            5, 6, 7, 8
        };
        float matrixBData[] = {
            1, 2,
            3, 4,
            5, 6,
            7, 8
        };

        const NSUInteger aRows = 2;
        const NSUInteger aColumns = 4;
        const NSUInteger bRows = 4;
        const NSUInteger bColumns = 2;

        // The product A * B has A's row count and B's column count.
        const NSUInteger resultRows = aRows;
        const NSUInteger resultColumns = bColumns;

        // Tightly packed rows: stride in bytes == columns * element size.
        const NSUInteger aRowBytes = aColumns * sizeof(float);
        const NSUInteger bRowBytes = bColumns * sizeof(float);
        const NSUInteger resultRowBytes = resultColumns * sizeof(float);

        if (aColumns != bRows) {
            NSLog(@">> Error: inner matrix dimensions do not match!");
            return 1;
        }

        @try {
            // Wrap the left operand in a shared (CPU+GPU visible) buffer.
            id<MTLBuffer> matrixABuffer = [device newBufferWithBytes:matrixAData
                                                              length:sizeof(matrixAData)
                                                             options:MTLResourceStorageModeShared];
            // Wrap the right operand the same way.
            id<MTLBuffer> matrixBBuffer = [device newBufferWithBytes:matrixBData
                                                              length:sizeof(matrixBData)
                                                             options:MTLResourceStorageModeShared];
            // Destination buffer; beta == 0 below, so its initial contents are not read.
            id<MTLBuffer> outputBuffer = [device newBufferWithLength:resultRows * resultRowBytes
                                                             options:MTLResourceStorageModeShared];
            if (!matrixABuffer || !matrixBBuffer || !outputBuffer) {
                NSLog(@">> Error: failed to allocate Metal buffers!");
                return 1;
            }

            MPSMatrixDescriptor *matrixADesc =
                [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
                                                      columns:aColumns
                                                     rowBytes:aRowBytes
                                                     dataType:MPSDataTypeFloat32];
            MPSMatrix *matrixA = [[MPSMatrix alloc] initWithBuffer:matrixABuffer
                                                        descriptor:matrixADesc];

            MPSMatrixDescriptor *matrixBDesc =
                [MPSMatrixDescriptor matrixDescriptorWithRows:bRows
                                                      columns:bColumns
                                                     rowBytes:bRowBytes
                                                     dataType:MPSDataTypeFloat32];
            MPSMatrix *matrixB = [[MPSMatrix alloc] initWithBuffer:matrixBBuffer
                                                        descriptor:matrixBDesc];

            MPSMatrixDescriptor *outputDesc =
                [MPSMatrixDescriptor matrixDescriptorWithRows:resultRows
                                                      columns:resultColumns
                                                     rowBytes:resultRowBytes
                                                     dataType:MPSDataTypeFloat32];
            MPSMatrix *outputMatrix = [[MPSMatrix alloc] initWithBuffer:outputBuffer
                                                             descriptor:outputDesc];

            // Kernel computing result = alpha * A * B + beta * result.
            // Dimensions are derived from the operands rather than hard-coded.
            MPSMatrixMultiplication *matrixMul =
                [[MPSMatrixMultiplication alloc] initWithDevice:device
                                                  transposeLeft:NO
                                                 transposeRight:NO
                                                     resultRows:resultRows
                                                  resultColumns:resultColumns
                                                interiorColumns:aColumns
                                                          alpha:1.0
                                                           beta:0.0];

            id<MTLCommandQueue> commandQueue = [device newCommandQueue];
            id<MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
            if (!commandBuffer) {
                // Also covers a nil queue: messaging nil yields a nil command buffer.
                NSLog(@">> Error: failed to create command buffer!");
                return 1;
            }

            // Encode the multiplication, submit it, and block until the GPU is done.
            [matrixMul encodeToCommandBuffer:commandBuffer
                                  leftMatrix:matrixA
                                 rightMatrix:matrixB
                                resultMatrix:outputMatrix];
            [commandBuffer commit];
            [commandBuffer waitUntilCompleted];

            // Do not print stale buffer contents if the GPU work failed.
            if (commandBuffer.status != MTLCommandBufferStatusCompleted) {
                NSLog(@">> Error: command buffer failed: %@", commandBuffer.error);
                return 1;
            }

            // Shared storage: the result is directly readable from the CPU here.
            const float *resultData = (const float *)outputMatrix.data.contents;
            const NSUInteger resultStride = resultRowBytes / sizeof(float);
            printf("Result matrix:\n");
            for (NSUInteger row = 0; row < resultRows; row++) {
                for (NSUInteger column = 0; column < resultColumns; column++) {
                    printf("%f ", resultData[row * resultStride + column]);
                }
                printf("\n");
            }
        } @catch (NSException *exception) {
            NSLog(@"%@", exception.reason);
            return 1;  // an exception means the product was not computed
        }
    }
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment