본문 바로가기
데이터 어쩌구/기술 써보기

Model 저장과 사용 (tf.keras)

by annmunju 2023. 8. 26.

저장

# 1. 모델 통째로 저장
model.save('./my_model')

# 2. weight만 저장
model.save_weights('./my_model/epoch_001')

# 3. callbacks를 사용하여 저장
# 체크포인트 경로 지정({}변수 에 epoch 값이 들어가도록 epoch fotmat을 포함시켜야 한다.)
checkpoint_path = "./checkpoints/epoch_{epoch:03d}.ckpt"

# 체크포인트 콜백 만들기
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
																								 period=1, # 1개의 epoch마다 저장
                                                 verbose=1)
model.fit(train_x, train_y,
          epochs=10, callbacks=[cp_callback],
          verbose=1)

사용

# 1. 모델 통째로 불러오기
model = keras.models.load_model('./my_model')

# 2. weight만 불러오기
model = Model()
model.load_weight('./my_model/epoch_001')

# 3. 위 콜백으로 저장된 것 중 latest모델 가져오기
latest = tf.train.latest_checkpoints('./checkpoints')
model.load_weight(latest)
728x90