통계, IT, AI

[TensorFlow] 모델 저장 및 복원 본문

IT/기타

[TensorFlow] 모델 저장 및 복원

Harold_Finch 2017. 9. 6. 23:36

1. 개요

    TensorFlow를 이용한 모델링은 오랜 시간이 걸릴 수 있다. 경우에 따라서는 한번의 Session으로 학습을 마치기 어려운 경우도 있다. 본 포스팅에서는 학습의 결과인 TensorFlow의 Variable을 저장하고 복원하는 방법을 살펴본다. 아래의 내용은 Udacity에서 발췌하였으며 사용된 코드는 이곳에서 확인할 수 있다. python은 3.5.3, TensorFlow는 1.3.0 사용한다.

2. Saving and Loading Variables

    아래는 weights와 bias라는 Variable을 만들고 저장하는 예시이다.[각주:1]

# -*- coding: utf-8 -*-

import tensorflow as tf

# Set random seed to make equal result
tf.set_random_seed(1234)

# The file path to save the data
save_file = './model.ckpt'

# Two Tensor Variables: weights and bias
weights = tf.Variable(tf.truncated_normal([2, 3]))
bias = tf.Variable(tf.truncated_normal([3]))

# Class used to save and/or restore Tensor Variables
saver = tf.train.Saver()

with tf.Session() as sess:
    
    # Initialize all the Variables
    sess.run(tf.global_variables_initializer())
    
    # Show the values of weights and bias
    print('Weights:')
    print(sess.run(weights))
    print('Bias:')
    print(sess.run(bias))
    
    # Save the model
    saver.save(sess, save_file)

    위 코드의 결과는 다음과 같다.

Weights:

[[-0.13862522 -0.24789245 -0.22179745]

[ 0.91565138 0.3255454 -0.5017547 ]]

Bias:

[-0.70187312 -0.81546098 -0.31579655]


    weights와 bias Variable은 truncated normal distribution에서 생성된 난수이다. Variable은 tf.train.Saver.save()함수를 통하여 save_file에 저장된다. TensorFlow 버전이 0.11.0RC1 이상인 경우, model.ckpt.meta 파일이 추가로 생성되는데, 이 파일에는 TensorFlow graph가 저장된다.아래는 위에서 저장한 Variable을 복원하는 코드이다.[각주:2]

# -*- coding: utf-8 -*-

import tensorflow as tf

# The file path to save the data
save_file = './model.ckpt'

# Remove the previous weights and bias
tf.reset_default_graph()

# Two Tensor Variables: weights and bias
weights = tf.Variable(tf.truncated_normal([2, 3]))
bias = tf.Variable(tf.truncated_normal([3]))

# Class used to save and/or restore Tensor Variables
saver = tf.train.Saver()

with tf.Session() as sess:
    # Load the weights and bias
    saver.restore(sess, save_file)
	
	# Show the values of weights and bias
    print('Weights:')
    print(sess.run(weights))
    print('Bias:')
    print(sess.run(bias))

    수행 결과는 save.py를 수행한 결과와 동일하다. tf.train.Saver.restore() 함수는 저장된 Variable을 복원한 후 같은 이름을 가진 변수에 할당하므로 retore()을 호출하기 전에 weights와 bias를 선언해야 한다. 그리고 restore()가 호출될 때에는 모든 TensorFlow Variable이 초기화되기 때문에 tf.global_variables_initializer()를 따로 호출할 필요는 없다.

3. Naming Error

    앞에서 restore() 함수가 Variable을 복원한 후 같은 이름을 가진 변수에 할당한다고 하였다. 이때 "같은 이름"은 변수명을 의미하는 것이 아니다. TensorFlow는 tensor, operation 등을 구분하기 위하여 name을 id로 사용한다. Variable을 선언할 때 name을 명시하지 않으면 TensforFlow가 자동으로 설정한다.


    이때 사용하는 규칙은 선언되는 node의 <type>과 선언되는 순서를 합쳐 <type>_<number>이다. 아래의 코드는 선언되는 Variable의 name을 다르게 하여 restore()의 과정에서 의도적으로 에러를 발생시키는 코드이다.[각주:3] 9, 10번째 줄과 26, 27번째 줄을 보자. 저장하는 단계에서는 weights가 bias보다 먼저 선언되었고 복원하는 단계에서는 bias가 먼저 선언되었다. 따라서 각 단계에서의 weight와 bias는 서로 다른 name을 가지고 있다.

import tensorflow as tf

# Remove the previous weights and bias
tf.reset_default_graph()

save_file = './model.ckpt'

# Two Tensor Variables: weights and bias
weights = tf.Variable(tf.truncated_normal([2, 3]))
bias = tf.Variable(tf.truncated_normal([3]))

saver = tf.train.Saver()

# Print the name of Weights and Bias
print('Save Weights: {}'.format(weights.name))
print('Save Bias: {}'.format(bias.name))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

# Remove the previous weights and bias
tf.reset_default_graph()

# Two Variables: weights and bias
bias = tf.Variable(tf.truncated_normal([3]))
weights = tf.Variable(tf.truncated_normal([2, 3]))

saver = tf.train.Saver()

# Print the name of Weights and Bias
print('Load Weights: {}'.format(weights.name))
print('Load Bias: {}'.format(bias.name))

with tf.Session() as sess:
    # Load the weights and bias - ERROR
    saver.restore(sess, save_file)

    위 코드에서 발생한 문제를 해결하기 위해서는 저장 과정에서 설정한 name들과 restore할 때의 name을 맞게 설정하면 된다. 그것을 위해서는 저장 단계의 변수 선언 순서를 맞춰줄 수도 있지만 name을 명시해주는 것이 가독성 등에서 더 낫다. 아래의 코드는 name을 명시해주는 예시이다.[각주:4] 9, 10번째 줄과 26, 27번째 줄을 보자. 복원 단계의 변수명이 저장할 때의 변수명과 다르지만 name이 동일하므로 복원되는 것을 확인할 수 있다.

import tensorflow as tf

# Remove the previous weights and bias
tf.reset_default_graph()

save_file = './model.ckpt'

# Two Tensor Variables: weights and bias
weights = tf.Variable(tf.truncated_normal([2, 3]), name='w_1')
bias = tf.Variable(tf.truncated_normal([3]), name='b_1')

saver = tf.train.Saver()

# Print the name of Weights and Bias
print('Save Weights: {}'.format(weights.name))
print('Save Bias: {}'.format(bias.name))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

# Remove the previous weights and bias
tf.reset_default_graph()

# Two Variables: weights and bias 
same_bias = tf.Variable(tf.truncated_normal([3]), name='b_1')
same_weights = tf.Variable(tf.truncated_normal([2, 3]), name='w_1')

saver = tf.train.Saver()

# Print the name of Weights and Bias
print('Load Weights: {}'.format(weights.name))
print('Load Bias: {}'.format(bias.name))

with tf.Session() as sess:
    # Load the weights and bias - ERROR
    saver.restore(sess, save_file)


  1. 파일명은 save.py이다. [본문으로]
  2. 파일명은 load.py이다. [본문으로]
  3. 파일명은 save_and_load_with_error.py이다. [본문으로]
  4. 파일명은 save_and_load_without_error.py이다. [본문으로]

'IT > 기타' 카테고리의 다른 글

[python] mutable vs immutable  (0) 2017.11.28
[기타] 티스토리 기초 설정  (1) 2017.07.09
C#에서 이벤트 사용하기  (1) 2017.01.29
C#에서 Zip 및 익명형식의 배열 사용하기  (0) 2016.11.14
C#과 R을 연동하기  (2) 2016.09.27
Comments