일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |
- hadoop2
- LSTM
- 딥러닝
- NumPy
- 그래프이론
- Sort
- Java
- 텐서플로
- collections
- codingthematrix
- yarn
- RNN
- graph
- scrapy
- 하이브
- 주식분석
- 파이썬
- HelloWorld
- 하둡2
- 코딩더매트릭스
- C언어
- tensorflow
- recursion
- effective python
- GRU
- C
- 선형대수
- hive
- 알고리즘
- python
- Today
- Total
EXCELSIOR
[러닝 텐서플로]Chap10.1 - 모델 익스포트와 서빙, Saver 본문
[러닝 텐서플로]Chap10.1 - 모델 익스포트와 서빙, Saver
Excelsior-JH 2018. 7. 5. 13:34학습한 모델을 저장하고 내보내는 방법에 대해 NumPy의
.savez()
와 텐서플로의Saver
를 사용해 학습된 가중치를 저장하고 로드해보자.
10.1 모델을 저장하고 내보내기
텐서플로를 이용해 모델을 만들고 학습한 뒤 학습된 모델 즉, 매개변수(weight, bias)를 저장하는 방법에 대해 알아보자. 이렇게 학습된 모델을 저장해놓으면 나중에 모델을 처음 부터 다시 학습시킬 필요가 없기 때문에 편리하다.
학습된 모델을 저장하기 위해 NumPy를 이용해 매개변수를 저장하는 방법을 알아보고, 텐서플로의 Saver
를 이용해 모델을 저장하고 관리하는 방법에 대해 알아보자.
10.1.1 로딩된 가중치 할당
먼저, NumPy의 savez
를 이용해 학습된 가중치 값을 저장하고, 불러오는 방법에 대해 알아보자. 이를 위해, Chap02-텐서플로 설치 및 실행에서 살펴본 Softmax Regression을 이용해 MNIST 데이터 분류 모델을 만들어 준다.
savez()
를 이용해 저장해주자. savez()
는 NumPy의 array형식을 .npz
파일로 저장해주는 기능을 한다.
위의 코드를 통해 저장된 weight_storage.npz
파일을 불러와 tf.Variable()
의 .assign()
메소드를 통해 학습된 가중치들을 할당해줄 수 있다.
아래의 코드는 위에서 구현한 Softmax Regression을 학습된 가중치를 가지고 정확도(accuracy
)를 구하는 코드이다.
이번에는 간단한 CNN 모델을 가지고 위와 동일한 방법으로 NumPy의 cnn_weight_storage.npz
파일로 가중치를 저장한 뒤 로드해 테스트를 해보자.
먼저, CNN 모델을 클래스 형태로 구성해준다.
SimpleCNN
클래스를 이용해 학습을 하고 가중치를 저장한다.
cnn_weight_storage.npz
를 로드하여 학습된 가중치를 이용해 테스트셋을 분류해보자.
텐서플로는 자체적으로 학습된 모델을 저장하고 로드할 수 있는 기능인 Saver
라는 클래스를 제공한다. Saver
는 체크포인트 파일(checkpoint file)인 이진 파일을 이용하여 모델의 매개변수를 저장하고 복원한다.
텐서플로의 Saver
는 tf.train.Saver()
를 통해 사용할 수 있으며, tf.train.Saver()
의 .save()
메소드를 이용해 체크포인트 파일을 저장한다.
10.1.1에서 구현한 Softmax Regression 모델을 텐서플로의 Saver
를 이용해 저장하고 불러와 보도록 하자.
Saver
를 이용해 학습된 가중치를 저장하였으므로, 이번에는 Saver.restore()
을 이용해 체크포인트를 복원하여 학습된 가중치를 모델에 할당해보자.
tf.reset_default_graph()
# placeholder and variable
inputs = tf.placeholder(tf.float32, [None, 28*28], name='inputs')
weights = tf.Variable(
tf.truncated_normal(shape=[28*28, 10], stddev=0.01),
name='weights')
logits = tf.matmul(inputs, weights)
labels = tf.placeholder(tf.float32, [None, 10])
# accuracy
correct_mask = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
accuracy = tf.reduce_mean(tf.cast(correct_mask, tf.float32))
# Saver
saver = tf.train.Saver()
# Test
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess, os.path.join(MODEL_PATH, 'model_ckpt-1000'))
acc = sess.run(accuracy, feed_dict={inputs: test_x,
labels: test_y})
print("Accuarcy: {:.5f}".format(acc))
INFO:tensorflow:Restoring parameters from ./saved_model/model_ckpt-1000
Accuarcy: 0.91260
Saver를 이용해 그래프 복원하기
텐서플로 Saver
의 장점은 연산 그래프를 저장해 다시 불러올 수 있다는 것이다. 위의 에제에서는 저장된 가중치 파일을 로드하여 그래프를 다시 구성한 뒤에 테스트를 수행했다. 텐서플로의 Saver
는 기본적으로 체크포인트 파일을 저장할때 그래프 정보를 담고있는 .meta
파일도 같이 저장한다.
이렇게 저장된 .meta
파일을 텐서플로 tf.train.import_meta_graph()
를 이용해 그래프를 불러온다. 아래의 예제코드는 tf.train.import_meta_graph()
를 이용해 그래프를 불러와 테스트를 수행하는 코드이다. 학습단계에서 텐서플로의 컬렉션(collection)에 테스트 단계에 사용할 변수 inputs, labels, accuracy
를추가하고 Saver
의 .export_meta_graph()
메소드의 인자 collection_list
에 넣어준다.
tf.train.import_meta_graph()
를 이용해 저장한 그래프를 불러오고 tf.get_collection()
을 통해 텐서플로 컬렉션에 저장한 변수들을 할당해준 후 테스트데이터로 테스트 해보자.
tf.reset_default_graph()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# Saver
saver = tf.train.import_meta_graph(os.path.join(MODEL_PATH, 'model_ckpt.meta'))
saver.restore(sess, os.path.join(MODEL_PATH, 'model_ckpt-1000'))
inputs = tf.get_collection('train_var')[0]
labels = tf.get_collection('train_var')[1]
accuracy = tf.get_collection('train_var')[2]
test_acc = sess.run(accuracy, feed_dict={inputs: test_x,
labels: test_y})
print("Accuarcy: {:.5f}".format(test_acc))
INFO:tensorflow:Restoring parameters from ./saved_model/model_ckpt-1000
Accuarcy: 0.91270
10.1.3 정리
텐서플로에서 학습한 가중치를 NumPy와 텐서플로의 Saver를 이용해 저장하고, 불러오는 방법에 대해 알아보았다. 여러번 학습을 시켜야 하거나, 학습된 모델을 바로 테스트 하는데 이러한 방법들을 이용해 편리하게 테스트할 수 있다.
'DeepLearning > Learning TensorFlow' 카테고리의 다른 글
Custom Estimator (0) | 2018.09.29 |
---|---|
TensorFlow Feature Column (0) | 2018.09.22 |
[러닝 텐서플로]Chap09 - 분산 텐서플로 (0) | 2018.07.04 |
[러닝 텐서플로]Chap08 - tf.FIFOQueue, tf.QueueRunner, TFRecords (1) | 2018.07.04 |
[러닝 텐서플로]Chap07.4 - 텐서플로 추상화와 간소화, TF-Slim (1) | 2018.07.01 |