SageMaker Serverless Inferenceで作る格安画像認識API
前回の記事で鳥画像を二値分類する画像認識モデルを作成しました。このモデルを外部から利用するために、SageMaker Serverless Inferenceを使ってAPIを構築します。
SageMaker Serverless Inferenceを採用した理由
あまり頻繁に利用するものでもないのでコストは低く抑えたいという要望がある場合、サーバレスなサービスを利用するのが最適です。
lambdaもありますが、機械学習用のパッケージを導入したdockerイメージの準備が必要となるので、気軽に試せるSageMaker Serverless Inferenceを採用しました。
※以前はlambdaで利用できるパッケージの容量が250MBに制限される、という課題がありましたが、現在は新機能がリリースされ、10GBのDockerコンテナのデプロイ可能と改善されています。
構築までの流れ
大まかに分けて以下の3つの手順を踏む必要があります。今回は訓練済みモデルを利用するので、ノートブックで学習を行うフェーズはなく、S3へ学習モデルをアップロードするところから始めます。

機械学習モデルをS3にアップロードする
エンドポイントに必要な情報をデプロイするにあたって、まず初めに画像認識モデルをS3にアップロードする必要があります。アップロードしたS3のモデルはNotebookを通じてエンドポイントインスタンスにアップロードされる流れとなります。

モデルを圧縮する
S3にアップロードするにあたってはモデルをtar.gz形式に圧縮する必要があります。今回は鳥画像の二値分類モデル「bird_torch.pth」を圧縮します。
bird_torch % tree ./
./
└── bird_torch.pth
bird_torch % tar czfv ../bird_torch.tar.gz .
a .
a ./bird_torch.pth
bird_torch % ls ../
./ ../ bird_torch/ bird_torch.tar.gz
S3へアップロードする
先ほど作成したtar.gzファイルをS3にアップロードします。非公開でも公開でも構いませんが、SageMaker Notebook側でS3へのアクセス権限を付与する必要があります。

ノートブックからエンドポイントを作成する
次にSageMakerノートブックを作成し、エンドポイントデプロイ用のスクリプトなどを準備してきます。
ノートブックインスタンスの作成
下記の手順でノートブックを作成します。

ノートブックインスタンス設定ではインスタンス名は任意の名前を、インスタンスタイプは「ml.t2.medium」、その他はデフォルト値で設定します。

アクセス許可と暗号化では新しいIAMを作成し、任意のS3バケットへアクセスできるように設定します。


他の値はデフォルト値のままでインスタンスを作成します。
ノートブックの準備
作成したインスタンスからJupyterを開き、ノートブックを作成します。今回は「conda_pytorch_p38」を採用します。

必要情報の読み込み
下記の通り必要情報をノートブックで実行します。
#モジュールの読み込み、ロールの設定などを行う
import sagemaker
from sagemaker.pytorch import PyTorchModel
from sagemaker.serverless import ServerlessInferenceConfig
sagemaker_session = sagemaker.Session()
role = sagemaker.get_execution_role()

利用するモデルの設定
デプロイするモデルの設定を行います。
- S3にアップロードしたモデルファイルのパス
- 学習したpytorchのframeworkバージョン
- pythonのバージョン
- entory_point用のスクリプトパス(後述)
を指定します。
# モデルの設定
pytorch_model = PyTorchModel(model_data="s3://*****/models/bird_torch.tar.gz",
role=role,
framework_version='1.11.0',
py_version="py38",
entry_point="entry_point.py")
※S3パスの「*****/models/bird_torch.tar.gz」は自身のモデルを配置したパスを指定してください。
※私はpytorchのバージョン1.12.1で学習したのですが、2022/8月時点ではノートブックでそのバージョンが利用できず1.11.0を指定しましたが、問題なく動きました。
entory_point用スクリプトの作成
さてここでデプロイするエンドポイントでモデルの読み込み、実行、返却などを実行するentory_pointスクリプトを作成します。
以下にentory_point.pyの全容を掲載します。
#entory_point.py
import logging
import torch
import torch.nn as nn
import numpy as np
from six import BytesIO
import os
import json
import torchvision.models as models
from torchvision import transforms
import cv2
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
JSON_CONTENT_TYPE = 'application/json'
PNG_CONTENT_TYPE = 'image/png'
JPG_CONTENT_TYPE = 'image/jpeg'
NPY_CONTENT_TYPE = 'application/x-npy'
class ImageTransform(object):
def __init__(self, resize, mean, std):
self.data_trasnform = {
'train': transforms.Compose([
# データオーグメンテーション
transforms.RandomHorizontalFlip(),
# 画像をresize×resizeの大きさに統一する
transforms.Resize((resize, resize)),
# Tensor型に変換する
transforms.ToTensor(),
# 色情報の標準化をする
transforms.Normalize(mean, std)
]),
'valid': transforms.Compose([
# 画像をresize×resizeの大きさに統一する
transforms.Resize((resize, resize)),
# Tensor型に変換する
transforms.ToTensor(),
# 色情報の標準化をする
transforms.Normalize(mean, std)
])
}
def __call__(self, img, phase='train'):
return self.data_trasnform[phase](img)
def model_fn(model_dir):
"""モデルのロード."""
logger.info('START model_fn')
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, 2)
# モデルのパラメータ設定
with open(os.path.join(model_dir, 'bird_torch.pth'), 'rb') as f:
model.load_state_dict(torch.load(f, map_location=torch.device('cpu')))
logger.info('END model_fn')
return model
def input_fn(request_body, content_type=PNG_CONTENT_TYPE):
"""入力データの形式変換."""
logger.info('START input_fn')
logger.info(f'content_type: {content_type}')
logger.info(f'request_body: {request_body}')
logger.info(f'type: {type(request_body)}')
if content_type == PNG_CONTENT_TYPE:#PNGの場合の処理
#受信した画像データをバイナリデータとして受け取る
stream = BytesIO(request_body)
#バイナリデータをNumpy配列に変換
input_data = np.asarray(bytearray(stream.read()), dtype=np.uint8)
#Numpy配列を画像形式のTensorに変換する
input_data = cv2.imdecode(input_data, 1)
#Tensorを画像に変換する
transform_PIL = transforms.ToPILImage()
input_data = transform_PIL(input_data)
elif content_type == JPG_CONTENT_TYPE:#JPGの場合。PNGと同様の処理です
stream = BytesIO(request_body)
input_data = np.asarray(bytearray(stream.read()), dtype=np.uint8)
input_data = cv2.imdecode(input_data, 1)
transform_PIL = transforms.ToPILImage()
input_data = transform_PIL(input_data)
elif content_type == NPY_CONTENT_TYPE:
stream = BytesIO(request_body)
transform_PIL = transforms.ToPILImage()
input_data = transform_PIL(np.load(stream))
else:
# TODO: content_typeに応じてデータ型変換
logger.error(f"content_type invalid: {content_type}")
input_data = {"errors": [f"content_type invalid: {content_type}"]}
logger.info('END input_fn')
return input_data
def predict_fn(input_data, model):
"""推論."""
logger.info('START predict_fn')
if isinstance(input_data, dict) and 'errors' in input_data:
logger.info('SKIP predict_fn')
logger.info('END predict_fn')
return input_data
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.eval()
# 説明変数の標準化
# リサイズ先の画像サイズ
resize = 300
# 今回は簡易的に(0.5, 0.5, 0.5)で標準化
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)
transform = ImageTransform(resize, mean, std)
img_transformed = transform(input_data, 'valid').unsqueeze(0)
# 推論
with torch.no_grad():
logger.info('END predict_fn')
return model(img_transformed.to(device))
def output_fn(prediction, accept=JSON_CONTENT_TYPE):
"""出力データの形式変換."""
logger.info('START output_fn')
logger.info(f"accept: {accept}")
if isinstance(prediction, dict) and 'errors' in prediction:
logger.info('SKIP output_fn')
response = json.dumps(prediction)
content_type = JSON_CONTENT_TYPE
elif accept == JSON_CONTENT_TYPE:
#[0]だと「可愛い」判定、[1]だと「かっこいい」判定です
pred = []
pred += [int(l.argmax()) for l in prediction]
m = nn.Softmax(dim=1)#「可愛い」と「かっこいい」
response = json.dumps({"results": pred[0], "eval":{'kawaii':m(prediction)[0].tolist()[0], 'kakkoii':m(prediction)[0].tolist()[1]}})
content_type = JSON_CONTENT_TYPE
else:
#[0]だと「可愛い」判定、[1]だと「かっこいい」判定です
pred = []
pred += [int(l.argmax()) for l in prediction]
m = nn.Softmax(dim=1)#「可愛い」と「かっこいい」
response = json.dumps({"results": pred[0], "eval":{'kawaii':m(prediction)[0].tolist()[0], 'kakkoii':m(prediction)[0].tolist()[1]}})
content_type = JSON_CONTENT_TYPE
logger.info('END output_fn')
return response, content_type
if __name__ == '__main__':
logger.info("process main!")
pass
前項で説明した「PyTorchModel」関数を使って推論を行う場合、以下の関数が呼び出されます。
- model_fn
- input_fn
- predict_fn
- output_fn
デフォルトではこちらのgithubに記載の処理が走りますが、entory_pointで関数を上書きすることで、処理を変更することが可能です。
また、各関数は以下の流れで実行されます。(詳しくはこちらを参照のこと)
input_object = input_fn(request_body, request_content_type)
prediction = predict_fn(input_object, model)
output = output_fn(prediction, response_content_type)
さて、以下にentory_point.pyの各クラス、関数についての説明を記載します。
ImageTransformクラス
このクラスはinput_dataをモデルに読み込ませるために画像のサイズ変更、色情報の標準化などを行います。
model_fn関数
この関数ではS3に用意した「bird_torch.pth」の読み込みを行なっています。今回用意したモデルはPytrochに標準搭載されているresnet50を利用しているので、models.resnet50()でネットワークを設定した後、model.load_state_dictで学習済みのパラメータを読み込んでいます。
注意点として、デフォルトのresnet50モデルは1000種類の画像分類なのに対して、今回は二値分類を実行するモデルなので、「nn.Linear(model.fc.in_features, 2)」でネットワークの最後に二値分類用の層を追加しています。
model = models.resnet50()
model.fc = nn.Linear(model.fc.in_features, 2)
# モデルのパラメータ設定
with open(os.path.join(model_dir, 'bird_torch.pth'), 'rb') as f:
model.load_state_dict(torch.load(f, map_location=torch.device('cpu')))
input_fn関数
input_fnでは受け取ったデータをモデルで読み込めるように前処理を行います。受け取ったデータはバイナリデータとなっているので、ImageTransformクラスで処理可能な画像形式に変換するための前処理を走らせています。
※かなりハマったのですが、PNGやJPEGを受信した場合、Numpy配列に変換しただけでは画像に適したTensor形式になっていないため、cv2.imdecodeを使って正しいTensor形式に変換する必要がありました。
--------------PNGまたはJPEGを受信した場合の前処理--------------
#受信した画像データをバイナリデータとして受け取る
stream = BytesIO(request_body)
#バイナリデータをNumpy配列に変換
input_data = np.asarray(bytearray(stream.read()), dtype=np.uint8)
#print(input_data.shape)-->(82906,)
#Numpy配列を画像形式のTensorに変換する
input_data = cv2.imdecode(input_data, 1)
#print(input_data.shape)-->(480, 480, 3)
#Tensorを画像に変換する
transform_PIL = transforms.ToPILImage()
input_data = transform_PIL(input_data)
--------------ノートブック上での確認用の前処理--------------
stream = BytesIO(request_body)
transform_PIL = transforms.ToPILImage()
input_data = transform_PIL(np.load(stream))
predict_fn関数
input_dataをImageTransformクラスで読み込んで、モデルに流し込みます。
output_fn関数
predict_fnで吐き出した出力結果をJSON形式にして送信します。
エンドポイントをデプロイする
前項で説明したentory_point.pyをjupyterにアップロードし、下記を実行します。Serverless Inferenceを利用する場合はserverless_inference_configを指定するだけでOKです。
注意点としてinstance_typeがt2.mediumではメモリが足りなかったのでt3.mediumを使用しています。
# デプロイパラメータ
serverless_config = ServerlessInferenceConfig(
memory_size_in_mb=2048, #利用するメモリサイズ。インスタンスの最大メモリより大きくはできない
max_concurrency=2 #同時実行可能な数
)
deploy_params = {
'instance_type' : 'ml.t3.medium',
'initial_instance_count' : 1,
'serverless_inference_config' : serverless_config #ここでサーバレスインスタンスの利用設定を行う
}
# デプロイ
predictor = pytorch_model.deploy(**deploy_params)
デプロイには2, 3分かかります。成功すると、エンドポイントが下記のように作成されます。Serverless Inferenceを指定した場合はタイプがサーバレスとなっているはずです。

エンドポイントのテスト
さて、デプロイしたエンドポイントが上手く動作するかどうかノートブック上で確認します。「sample_1.jpeg」と名前をつけた画像をjupyterにアップロードし、以下のコードを実行してみてください。
from PIL import Image
#入力データ
file_name = 'sample_1.jpeg'
img = Image.open(file_name)
# 推論
results = predictor.predict(img)
print(results)
上手く動作すれば下記のような結果が得られます。

ちなみに、使った画像は以下の通りです。

外部からAPIでアクセスする
最後に、ローカルからエンドポイントへアクセスできるかどうかを確認します。ローカルで利用しているawsユーザーに対して、下記の通りSageMakerへのアクセス権限を付与します。
※ローカル環境でのaws cliの利用設定などの説明は省略します。「aws cli インストール」などで検索して調べてみてください。

権限を付与した状態で以下のスクリプトを実行してください。
# predict_request.py
import boto3
import json
import requests
def request2api():
# read image data
f = open("sample_0.jpeg", "rb")
reqbody = f.read()
f.close()
# Request
client = boto3.client('sagemaker-runtime')
endpoint_response = client.invoke_endpoint(
EndpointName='pytorch-inference-*****', #自分のエンドポイント名を記載してください
ContentType='image/jpeg',
Accept='application/json',
Body=reqbody
)
results = endpoint_response['Body'].read()
detections = json.loads(results)
print(detections)
if __name__ == '__main__':
request2api()
そうすると、下記の通りの結果が得られると思います。
% python predict_request.py
{'results': 0, 'eval': {'kawaii': 0.9999431371688843, 'kakkoii': 5.68593641219195e-05}}
さて、sagemaker serverless inferenceを使った画像認識APIの作り方の説明は以上になります。
今回の構築にあたっての感想として、entory_point.pyを正常に動作させるのがかなり大変でした。エラーが出るたびに、entry_point.pyをアップロードしてエンドポイントを作成する必要があり時間がかかったのが最大の要因です。どなたか素早くエラー解消をする方法をご存知でしたら教えていただけると恐縮です😅
最後にentry_point.pyなどのスクリプトをこちらにアップロードしています。よろしければご利用ください。
それではまた。
[参考文献]