-
-
Save jettify/597694c870e35cfc78fa0f50b90c1d15 to your computer and use it in GitHub Desktop.
Revisions
-
zeryx created this gist
Oct 11, 2018 .There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,20 @@ cmake_minimum_required(VERSION 3.0 FATAL_ERROR) project(cpp_shim) set(CMAKE_PREFIX_PATH ../libtorch) find_package(Torch REQUIRED) find_package(OpenCV REQUIRED) add_executable(testing main.cpp) message(STATUS "OpenCV library status:") message(STATUS " config: ${OpenCV_DIR}") message(STATUS " version: ${OpenCV_VERSION}") message(STATUS " libraries: ${OpenCV_LIBS}") message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}") message(STATUS "TORCHLIB: ${TORCH_LIBRARIES}") #target_include_directories(testing PRIVATE ${TORCH_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS}) target_link_libraries(testing ${OpenCV_LIBS}) target_link_libraries(testing ${TORCH_LIBRARIES}) target_compile_definitions(testing PRIVATE -D_GLIBCXX_USE_CXX11_ABI=0) This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,38 @@ import torch import torch.nn as nn import torch.nn.functional as F from torch.jit import ScriptModule, script_method, trace class MyScriptModule(ScriptModule): # class MyScriptModule(nn.Module): def __init__(self): super(MyScriptModule, self).__init__() # trace produces a ScriptModule's conv1 and conv2 self.conv1 = trace(nn.Conv2d(3, 2, 5).to("cpu"), torch.rand(1, 3, 1266, 1900)) self.conv2 = trace(nn.Conv2d(2, 1, 5).to("cpu"), torch.rand(1, 2, 1266, 1900)) self.lin = trace(nn.Linear(1258*1892, 5), torch.rand(1258*1892)) @script_method def forward(self, input): input = F.relu(self.conv1(input)) input = F.relu(self.conv2(input)) input = input.squeeze() input = input.view(1258*1892) output = self.lin(input) return output test_module = MyScriptModule() print(test_module.graph) if __name__ == "__main__": test_module.save("/tmp/model.pl") # if __name__ == "__main__": # import numpy as np # from PIL import Image # img_path = "/tmp/cat_image.jpg" # img = np.asarray(Image.open(img_path)) # tensor = torch.from_numpy(img).float() # tensor = tensor.view(1, 3, tensor.shape[0], tensor.shape[1]) # test_module.forward(tensor) This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode charactersOriginal file line number Diff line number Diff line change @@ -0,0 +1,33 @@ // // Created by zeryx on 10/5/18. // #include <torch/script.h> #include <iostream> #include <memory> #include <opencv2/core/core.hpp> #include <opencv2/highgui/highgui.hpp> using namespace cv; int main() { std::string model_path = "/tmp/model.pl"; std::string image_path = "/tmp/cat_image.jpg"; Mat image = imread(image_path); std::vector<int64_t> sizes = {1, 3, image.rows, image.cols}; at::TensorOptions options(at::ScalarType::Byte); at::Tensor tensor_image = torch::from_blob(image.data, at::IntList(sizes), options); tensor_image = tensor_image.toType(at::kFloat); std::ifstream is (model_path, std::ifstream::binary); std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(is); std::vector<torch::jit::IValue> inputs; inputs.emplace_back(tensor_image); at::Tensor result = module->forward(inputs).toTensor(); auto max_result = result.max(0, true); auto max_index = std::get<1>(max_result).item<float>(); std::cout << max_index << std::endl; }