Last active
December 9, 2022 09:28
-
-
Save adkri/101342a72f475ec396a85440f188b18f to your computer and use it in GitHub Desktop.
Matrix multiplication on macOS with Metal Performance Shaders (MPS)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| // Compile with: clang -fobjc-arc -framework MetalPerformanceShaders -framework Foundation -framework Metal -framework CoreGraphics main.m -o my_matrix_mul | |
| #include <MetalPerformanceShaders/MetalPerformanceShaders.h> | |
| #include <stdio.h> | |
// Multiplies a 2x4 matrix A by a 4x2 matrix B on the default Metal device
// using MPSMatrixMultiplication and prints the 2x2 product C = A * B.
//
// All matrices are row-major and tightly packed (rowBytes equals
// columns * sizeof(float)), because the GPU buffers are filled directly from
// the C arrays below.
//
// Returns 0 on success, or 1 if no Metal device is available, the device does
// not support MPS, the matrix shapes are inconsistent, a GPU buffer or command
// buffer cannot be created, the command buffer does not complete
// successfully, or an Objective-C exception is raised during setup/encoding.
int main() {
  @autoreleasepool {
    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
    if (!device) {
      NSLog(@">> Error: failed to create system device!");
      return 1;  // Nothing below can work without a device.
    }
    if (!MPSSupportsMTLDevice(device)) {
      NSLog(@">> Error: MPS not supported!");
      return 1;
    }

    // Create the two matrices that we want to multiply (row-major).
    float matrixAData[] = {
      1, 2, 3, 4,
      5, 6, 7, 8
    };
    float matrixBData[] = {
      1, 2,
      3, 4,
      5, 6,
      7, 8
    };
    const NSUInteger aRows = 2;
    const NSUInteger aColumns = 4;
    const NSUInteger bRows = 4;
    const NSUInteger bColumns = 2;

    // A * B is only defined when A's column count equals B's row count.
    // Also verify that the dimension constants describe the arrays above, so
    // editing one without the other cannot make MPS read past the data.
    if (aColumns != bRows ||
        sizeof(matrixAData) != aRows * aColumns * sizeof(float) ||
        sizeof(matrixBData) != bRows * bColumns * sizeof(float)) {
      NSLog(@">> Error: matrix dimensions do not match the data or each other!");
      return 1;
    }

    // The result C has aRows rows and bColumns columns.
    const NSUInteger cRows = aRows;
    const NSUInteger cColumns = bColumns;

    const NSUInteger aRowBytes = aColumns * sizeof(float);
    const NSUInteger bRowBytes = bColumns * sizeof(float);
    const NSUInteger cRowBytes = cColumns * sizeof(float);

    @try {
      id<MTLBuffer> matrixABuffer =
          [device newBufferWithBytes:matrixAData
                              length:aRows * aRowBytes
                             options:MTLResourceStorageModeShared];
      id<MTLBuffer> matrixBBuffer =
          [device newBufferWithBytes:matrixBData
                              length:bRows * bRowBytes
                             options:MTLResourceStorageModeShared];
      // Zero-filled storage for the result; beta is 0.0 below, so C's
      // previous contents do not contribute to the product.
      id<MTLBuffer> outputBuffer =
          [device newBufferWithLength:cRows * cRowBytes
                              options:MTLResourceStorageModeShared];
      if (!matrixABuffer || !matrixBBuffer || !outputBuffer) {
        NSLog(@">> Error: failed to allocate Metal buffers!");
        return 1;
      }

      MPSMatrixDescriptor *matrixADesc =
          [MPSMatrixDescriptor matrixDescriptorWithRows:aRows
                                                columns:aColumns
                                               rowBytes:aRowBytes
                                               dataType:MPSDataTypeFloat32];
      MPSMatrixDescriptor *matrixBDesc =
          [MPSMatrixDescriptor matrixDescriptorWithRows:bRows
                                                columns:bColumns
                                               rowBytes:bRowBytes
                                               dataType:MPSDataTypeFloat32];
      MPSMatrixDescriptor *outputDesc =
          [MPSMatrixDescriptor matrixDescriptorWithRows:cRows
                                                columns:cColumns
                                               rowBytes:cRowBytes
                                               dataType:MPSDataTypeFloat32];

      MPSMatrix *matrixA = [[MPSMatrix alloc] initWithBuffer:matrixABuffer
                                                  descriptor:matrixADesc];
      MPSMatrix *matrixB = [[MPSMatrix alloc] initWithBuffer:matrixBBuffer
                                                  descriptor:matrixBDesc];
      MPSMatrix *outputMatrix = [[MPSMatrix alloc] initWithBuffer:outputBuffer
                                                       descriptor:outputDesc];

      // C = 1.0 * A * B + 0.0 * C. Sizes come from the operand shapes rather
      // than repeated literals so they cannot drift out of sync.
      MPSMatrixMultiplication *matrixMul =
          [[MPSMatrixMultiplication alloc] initWithDevice:device
                                            transposeLeft:NO
                                           transposeRight:NO
                                               resultRows:cRows
                                            resultColumns:cColumns
                                          interiorColumns:aColumns
                                                    alpha:1.0
                                                     beta:0.0];

      id<MTLCommandQueue> commandQueue = [device newCommandQueue];
      id<MTLCommandBuffer> commandBuffer = [commandQueue commandBuffer];
      if (!commandBuffer) {
        NSLog(@">> Error: failed to create command buffer!");
        return 1;
      }

      // Encode the multiplication, submit it, and block until the GPU is done.
      [matrixMul encodeToCommandBuffer:commandBuffer
                            leftMatrix:matrixA
                           rightMatrix:matrixB
                          resultMatrix:outputMatrix];
      [commandBuffer commit];
      [commandBuffer waitUntilCompleted];

      if (commandBuffer.status != MTLCommandBufferStatusCompleted) {
        NSLog(@">> Error: command buffer did not complete: %@", commandBuffer.error);
        return 1;
      }

      // Read the result on this thread after the wait, instead of from a
      // completion handler running on a thread Metal chooses. Index by the
      // matrix's actual row stride.
      const float *resultData = (const float *)outputBuffer.contents;
      const NSUInteger rowStride = outputMatrix.rowBytes / sizeof(float);
      printf("Result matrix:\n");
      for (NSUInteger i = 0; i < cRows; i++) {
        for (NSUInteger j = 0; j < cColumns; j++) {
          printf("%f ", resultData[i * rowStride + j]);
        }
        printf("\n");
      }
    } @catch (NSException *exception) {
      NSLog(@"%@", exception.reason);
      return 1;  // Report failure instead of exiting with success.
    }
  }
  return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment