Skip to content

Instantly share code, notes, and snippets.

@jettify
Forked from zeryx/CMakeLists.txt
Created October 24, 2020 01:40
Show Gist options
  • Save jettify/597694c870e35cfc78fa0f50b90c1d15 to your computer and use it in GitHub Desktop.
Save jettify/597694c870e35cfc78fa0f50b90c1d15 to your computer and use it in GitHub Desktop.

Revisions

  1. @zeryx zeryx created this gist Oct 11, 2018.
    20 changes: 20 additions & 0 deletions CMakeLists.txt
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,20 @@
    cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
    project(cpp_shim)
    set(CMAKE_PREFIX_PATH ../libtorch)
    find_package(Torch REQUIRED)
    find_package(OpenCV REQUIRED)

    add_executable(testing main.cpp)

    message(STATUS "OpenCV library status:")
    message(STATUS " config: ${OpenCV_DIR}")
    message(STATUS " version: ${OpenCV_VERSION}")
    message(STATUS " libraries: ${OpenCV_LIBS}")
    message(STATUS " include path: ${OpenCV_INCLUDE_DIRS}")

    message(STATUS "TORCHLIB: ${TORCH_LIBRARIES}")
    #target_include_directories(testing PRIVATE ${TORCH_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS})
    target_link_libraries(testing ${OpenCV_LIBS})
    target_link_libraries(testing ${TORCH_LIBRARIES})

    target_compile_definitions(testing PRIVATE -D_GLIBCXX_USE_CXX11_ABI=0)
    38 changes: 38 additions & 0 deletions generator.py
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,38 @@
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.jit import ScriptModule, script_method, trace



    class MyScriptModule(ScriptModule):
    # class MyScriptModule(nn.Module):
    def __init__(self):
    super(MyScriptModule, self).__init__()
    # trace produces a ScriptModule's conv1 and conv2
    self.conv1 = trace(nn.Conv2d(3, 2, 5).to("cpu"), torch.rand(1, 3, 1266, 1900))
    self.conv2 = trace(nn.Conv2d(2, 1, 5).to("cpu"), torch.rand(1, 2, 1266, 1900))
    self.lin = trace(nn.Linear(1258*1892, 5), torch.rand(1258*1892))

    @script_method
    def forward(self, input):
    input = F.relu(self.conv1(input))
    input = F.relu(self.conv2(input))
    input = input.squeeze()
    input = input.view(1258*1892)
    output = self.lin(input)
    return output

    test_module = MyScriptModule()
    print(test_module.graph)
    if __name__ == "__main__":
    test_module.save("/tmp/model.pl")

    # if __name__ == "__main__":
    # import numpy as np
    # from PIL import Image
    # img_path = "/tmp/cat_image.jpg"
    # img = np.asarray(Image.open(img_path))
    # tensor = torch.from_numpy(img).float()
    # tensor = tensor.view(1, 3, tensor.shape[0], tensor.shape[1])
    # test_module.forward(tensor)
    33 changes: 33 additions & 0 deletions main.cpp
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,33 @@
    //
    // Created by zeryx on 10/5/18.
    //
    #include <torch/script.h>
    #include <iostream>
    #include <memory>

    #include <opencv2/core/core.hpp>
    #include <opencv2/highgui/highgui.hpp>

    using namespace cv;

    int main() {
    std::string model_path = "/tmp/model.pl";
    std::string image_path = "/tmp/cat_image.jpg";

    Mat image = imread(image_path);
    std::vector<int64_t> sizes = {1, 3, image.rows, image.cols};
    at::TensorOptions options(at::ScalarType::Byte);
    at::Tensor tensor_image = torch::from_blob(image.data, at::IntList(sizes), options);
    tensor_image = tensor_image.toType(at::kFloat);

    std::ifstream is (model_path, std::ifstream::binary);
    std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(is);
    std::vector<torch::jit::IValue> inputs;
    inputs.emplace_back(tensor_image);
    at::Tensor result = module->forward(inputs).toTensor();
    auto max_result = result.max(0, true);
    auto max_index = std::get<1>(max_result).item<float>();


    std::cout << max_index << std::endl;
    }