import torch import torch.nn as nn import torch.onnx import onnxruntime as ort import numpy as np # 간단한 PyTorch 모델 정의 class SimpleModel(nn.Module): def __init__(self): super(SimpleModel, self).__init__() self.fc = nn.Linear(10, 5) # 입력 10, 출력 5 def forward(self, x): return self.fc(x) # 모델 초기화 model = SimpleModel() model.eval() # 추론 모드로 전환 # 더미 입력 데이터 생성 (ONNX에서 동작 확인용) dummy_input = torch.randn(1, 10) # 배치 크기 1, 입력 크기 10 # ONNX 파일로 저장 onnx_file_path = "simple_model.onnx" torch.onnx.export( model, dummy_input, onnx_file_path, input_names=["input"], # 입력 텐서 이름 output_names=["output"], # 출력 텐서 이름 dynamic_axes={ "input": {0: "batch_size"}, "output": {0: "batch_size"}, }, # 동적 배치 처리 opset_version=11, # ONNX Opset 버전 ) print(f"ONNX 모델이 {onnx_file_path}에 저장되었습니다.") # ONNX 모델 로드 onnx_file_path = "simple_model.onnx" session = ort.InferenceSession(onnx_file_path) # ONNX 모델 입력/출력 이름 확인 sess_input = session.get_inputs()[0] sess_output = session.get_outputs()[0] print("sess_input and sess_output shape", sess_input.shape, sess_output.shape) # 더미 입력 데이터 생성 dummy_input = np.random.randn(1, 10).astype(np.float32) print("더미 입력 데이터:", dummy_input) # 모델 실행 input_name = sess_input.name output_name = sess_output.name outputs = session.run([output_name], {input_name: dummy_input}) print("ONNX 모델 출력:", outputs)