Android ONNX 연동

 

본 포스트는 Pytorch로 개발한 커스텀 모델을 Android에서 실행시키기 위해 Onnx를 활용하고 싶은 분들을 위한 포스트 입니다.

참고 블로그: https://beeny-ds.tistory.com/22

ONNX

ONNX(Open Neural Network Exchange)는 서로 다른 DNN 프레임워크로 만들어진 모델들이 존재할 때, 모델끼리 서로 호환되면서 사용할 수 있도록 만들어진 공유 플랫폼이다.
따라서, Deploy 단계에서 다양한 디바이스(ex. 스마트폰)에서 사용할 때 활용하면 좋다.
또한 TensorRT등의 가속 라이브러리와 연동도 가능하다고 하니 실시간(Real-Time) ai 서비스를 위해서라면 꼭 익혀두는것이 좋을것으로 보인다.

코드를 통한 실행 과정

Juneer Deeplearning cookbook에서 전체 코드를 확인할 수 있다.

onnx export

def export_onnx(model:torch.nn, save_path:str):
    # 아래 예시는 배치사이즈 1, Fashion MNIST 데이터를 학습한 pytorch 모델에 대한 것임
    model.eval()
    dummy_input = torch.randn(1, 1, 28, 28)

    # onnx 모델의 가중치를 접근하기위해 name specifying이 가능한것으로 보임
    # input 부분에 대한 name specifying만 해도 충분할 것으로 보임
    input_names = ['actual_input_1'] + [f'learned_{i}' for i in range(20)]
    output_names = ['output1']

    torch.onnx.export(
        model, 
        dummy_input,
        os.path.join(save_path, 'ts_mn.onnx'),
        verbose=True,       # export 할 때 사람이 읽을 수 있도록 print문으로 콘솔창에 출력
        input_names=input_names, 
        output_names=output_names
        )

onnx inference

세션을 만들고, run을 추가적으로 해줘야 하는 작업이 필요하다.

def create_session(onnx_path:str):
    ort_sess = ort.InferenceSession(onnx_path)
    outputs = ort_sess.run(
        None,
        {'actual_input_1':np.random.randn(1,1,28,28).astype(np.float32)},  # 여기서는 numpy를 사용해야 한다고 함
                                                                           # fashion mnist의 데이터에 맞게, batch:1, channel: 1, height: 28, width: 28로 구성했다.
    )
    print(outputs)      # Fashion mnist 데이터셋의 10개 클래스에 대한 softmax 결과를 확인할 수 있다.

onnx 시각화

Netron 웹페이지에 onnx 파일을 upload하면 모델의 전체적인 architecture를 확인할 수 있다.

우측의 MODEL PROPERTIES 창은 ctrl + enter로 확인할 수 있다.