Skip to content

Instantly share code, notes, and snippets.

@jhojin7
Created November 30, 2024 03:19
Show Gist options
  • Save jhojin7/e81bdb1396384bcd1f2a7a09d5f28b94 to your computer and use it in GitHub Desktop.
Save jhojin7/e81bdb1396384bcd1f2a7a09d5f28b94 to your computer and use it in GitHub Desktop.
import torch
import torch.nn as nn
import torch.onnx
import onnxruntime as ort
import numpy as np
# 간단한 PyTorch 모델 정의
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = nn.Linear(10, 5) # 입력 10, 출력 5
def forward(self, x):
return self.fc(x)
# 모델 초기화
model = SimpleModel()
model.eval() # 추론 모드로 전환
# 더미 입력 데이터 생성 (ONNX에서 동작 확인용)
dummy_input = torch.randn(1, 10) # 배치 크기 1, 입력 크기 10
# ONNX 파일로 저장
onnx_file_path = "simple_model.onnx"
torch.onnx.export(
model,
dummy_input,
onnx_file_path,
input_names=["input"], # 입력 텐서 이름
output_names=["output"], # 출력 텐서 이름
dynamic_axes={
"input": {0: "batch_size"},
"output": {0: "batch_size"},
}, # 동적 배치 처리
opset_version=11, # ONNX Opset 버전
)
print(f"ONNX 모델이 {onnx_file_path}에 저장되었습니다.")
# ONNX 모델 로드
onnx_file_path = "simple_model.onnx"
session = ort.InferenceSession(onnx_file_path)
# ONNX 모델 입력/출력 이름 확인
sess_input = session.get_inputs()[0]
sess_output = session.get_outputs()[0]
print("sess_input and sess_output shape", sess_input.shape, sess_output.shape)
# 더미 입력 데이터 생성
dummy_input = np.random.randn(1, 10).astype(np.float32)
print("더미 입력 데이터:", dummy_input)
# 모델 실행
input_name = sess_input.name
output_name = sess_output.name
outputs = session.run([output_name], {input_name: dummy_input})
print("ONNX 모델 출력:", outputs)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment