fn _main() { println!("Hello, world!"); let registry = DialectRegistry::new(); register_all_dialects(®istry); let context = Context::new(); context.append_dialect_registry(®istry); context.load_all_available_dialects(); let location = Location::unknown(&context); let mut module = Module::new(location); let index_type = Type::index(&context); let float32_type = Type::float32(&context); let vector2_float32_type = Type::vector(&[2], float32_type); let i64_type: Type = IntegerType::new(&context, 64).into(); let tensor_type: Type = RankedTensorType::new(&[1], float32_type, None).into(); let _memref_load_type: Type = MemRefType::new(vector2_float32_type, &[1], None, None).into(); // see https://github.com/raviqqe/melior/issues/180 let array_attr: Attribute = unsafe { let raw_attr = mlirArrayAttrGet(context.to_raw(), 1, &IntegerAttribute::new(0, i64_type).to_raw()); Attribute::from_raw(raw_attr) }; // add two floats together module.body().append_operation(func::func( &context, StringAttribute::new(&context, "add"), TypeAttribute::new(FunctionType::new(&context, &[float32_type, float32_type], &[float32_type]).into()), { let block = Block::new(&[(float32_type, location), (float32_type, location)]); let sum = block.append_operation(arith::addf( block.argument(0).unwrap().into(), block.argument(1).unwrap().into(), location )); block.append_operation(func::r#return(&[sum.result(0).unwrap().into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location )); // testing vector ops module.body().append_operation(func::func( &context, StringAttribute::new(&context, "firstInVector"), TypeAttribute::new(FunctionType::new(&context, &[vector2_float32_type], &[float32_type]).into()), { // block arguments must match type attribute arguments let block = Block::new(&[(vector2_float32_type, location)]); let vector_extract_op = OperationBuilder::new("vector.extract", location) .add_attributes(&[( Identifier::new(&context, "position"), array_attr )]) .add_operands(&[block.argument(0).unwrap().into()]) .add_results(&[float32_type]) .build(); let vector_extract_op = block.append_operation(vector_extract_op); block.append_operation(func::r#return(&[vector_extract_op.result(0).unwrap().into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location )); // testing tensor ops module.body().append_operation(func::func( &context, StringAttribute::new(&context, "firstInTensor"), TypeAttribute::new(FunctionType::new(&context, &[tensor_type], &[float32_type]).into()), { // block arguments must match type attribute arguments let block = Block::new(&[(tensor_type, location)]); let constant_op = index::constant(&context, IntegerAttribute::new(0, index_type), location); let constant_op = block.append_operation(constant_op); let tensor_extract_op = OperationBuilder::new("tensor.extract", location) .add_operands(&[ // tensor: ranked tensor of any type values block.argument(0).unwrap().into(), // indices: index constant_op.result(0).unwrap().into() ]) .add_results(&[float32_type]) .build(); let tensor_extract_op = block.append_operation(tensor_extract_op); block.append_operation(func::r#return(&[tensor_extract_op.result(0).unwrap().into()], location)); let region = Region::new(); region.append_block(block); region }, &[], location )); let module_op = module.as_operation(); module_op.dump(); assert!(module_op.verify()); if true { // llvm let pass_manager = PassManager::new(&context); pass_manager.add_pass(pass::conversion::create_arith_to_llvm()); pass_manager.add_pass(pass::conversion::create_math_to_llvm()); pass_manager.add_pass(pass::conversion::create_func_to_llvm()); pass_manager.add_pass(pass::conversion::create_vector_to_llvm()); pass_manager.add_pass(pass::conversion::create_tensor_to_linalg()); pass_manager.add_pass(pass::conversion::create_linalg_to_llvm()); pass_manager.add_pass(pass::conversion::create_index_to_llvm_pass()); /* pass_manager.add_pass(pass::conversion::create_tensor_to_linalg()); pass_manager.add_pass(pass::conversion::create_linalg_to_standard()); pass_manager.add_pass(pass::conversion::create_linalg_to_llvm()); pass_manager.add_pass(pass::conversion::create_mem_ref_to_llvm()); pass_manager.add_pass(pass::conversion::create_gpu_to_llvm()); */ pass_manager.run(&mut module).unwrap(); } else { // spirv let pass_manager = PassManager::new(&context); pass_manager.add_pass(pass::conversion::create_arith_to_spirv()); pass_manager.add_pass(pass::conversion::create_math_to_spirv()); pass_manager.add_pass(pass::conversion::create_func_to_spirv()); pass_manager.add_pass(pass::conversion::create_vector_to_spirv()); //pass_manager.add_pass(pass::conversion::create_index_to_llvm_pass()); /* pass_manager.add_pass(pass::conversion::create_tensor_to_linalg()); pass_manager.add_pass(pass::conversion::create_linalg_to_standard()); pass_manager.add_pass(pass::conversion::create_linalg_to_llvm()); pass_manager.add_pass(pass::conversion::create_mem_ref_to_llvm()); pass_manager.add_pass(pass::conversion::create_gpu_to_llvm()); */ pass_manager.run(&mut module).unwrap(); } let module_op = module.as_operation(); module_op.dump(); assert!(module_op.verify()) }