Saving and restoring variables in TensorFlow

See also the TensorFlow documentation

In [39]:
#### Imports
import numpy as np
import tensorflow as tf
In [38]:
#### Utility to display a graph in Jupyter
from IPython.display import clear_output, Image, display, HTML

def strip_consts(graph_def, max_const_size=32):
    """Strip large constant values from graph_def."""
    strip_def = tf.GraphDef()
    for n0 in graph_def.node:
        n = strip_def.node.add() 
        n.MergeFrom(n0)
        if n.op == 'Const':
            tensor = n.attr['value'].tensor
            size = len(tensor.tensor_content)
            if size > max_const_size:
                tensor.tensor_content = tf.compat.as_bytes("<stripped %d bytes>" % size)
    return strip_def

def show_graph(graph_def, max_const_size=32):
    """Visualize TensorFlow graph."""
    if hasattr(graph_def, 'as_graph_def'):
        graph_def = graph_def.as_graph_def()
    strip_def = strip_consts(graph_def, max_const_size=max_const_size)
    code = """
        <script>
          function load() {{
            document.getElementById("{id}").pbtxt = {data};
          }}
        </script>
        <link rel="import" href="https://tensorboard.appspot.com/tf-graph-basic.build.html" onload=load()>
        <div style="height:600px">
          <tf-graph-basic id="{id}"></tf-graph-basic>
        </div>
    """.format(data=repr(str(strip_def)), id='graph'+str(np.random.rand()))

    iframe = """
        <iframe seamless style="width:1200px;height:620px;border:0" srcdoc="{}"></iframe>
    """.format(code.replace('"', '&quot;'))
    display(HTML(iframe))

Here is a typical dummy example for training the parameters of a neural network with gradient descent on a given loss objective:

In [47]:
with tf.Graph().as_default():
    
    ### Inputs
    x_train = tf.placeholder(tf.float32, shape=(10, 1), name='input')
    
    ### Training graph: only 2 layers
    ### Each layer is parametrized by a weight variable (w1, w2)
    ### Weights are variables that will be stored and updated at each iteration
    w1 = tf.Variable(tf.ones((10, 10)), name='w1')
    y = tf.matmul(w1, x_train)
    w2 = tf.Variable(tf.ones((10, 1)), name='w2')
    y = w2 * y
    
    ### Define loss / objective function
    loss = tf.reduce_mean(tf.abs(y - x_train))
    
    ### Tensorflow Gradient Descent optimizer
    learning_rate = 0.1
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    ### This returns the training operations that will update the weights
    train_op = optimizer.minimize(loss)
    
    ### Display the graph
    show_graph(tf.get_default_graph().as_graph_def())
    
    ### Operation to initialize all variables
    init = tf.global_variables_initializer()
    
    with tf.Session() as sess:
        # Initialize the variables
        # This usually *has* to be run before anything else
        sess.run(init)
        # Run `num_steps` training iterations that update the weights
        num_steps = 20
        for k in range(num_steps):
            loss_, _ = sess.run([loss, train_op], 
                                feed_dict={
                    x_train: np.random.normal(size=(10, 1))
                })
            print('\r Step %d: Loss %.3f' % (k + 1, loss_), end='')
 Step 20: Loss 0.868

Now let's build the same graph, but this time we will also save the graph structure and the weights so that we can easily reload and reuse the trained model later (e.g., for inference, or to initialize a model for further training). This is done with a Saver object in TensorFlow.

In [45]:
with tf.Graph().as_default():
    
    ### Same graph as before
    x_train = tf.placeholder(tf.float32, shape=(10, 1), name='input')
    w1 = tf.Variable(tf.ones((10, 10)), name='w1')
    y = tf.matmul(w1, x_train)
    w2 = tf.Variable(tf.ones((10, 1)), name='w2')
    y = w2 * y
    loss = tf.reduce_mean(tf.abs(y - x_train))
    learning_rate = 0.1
    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train_op = optimizer.minimize(loss)
    init = tf.global_variables_initializer()
    
    ## Useful trick: graph collections are a nice way to group and retrieve variables
    tf.add_to_collection("weights", w1)
    tf.add_to_collection("weights", w2)
    
    ## Create the Tensorflow saver object
    ## One can specify the variables to save with the var_list argument
    ## Other options are also available, e.g. how many checkpoints to keep for a given model
    saver = tf.train.Saver(var_list=tf.get_collection("weights"))
    #saver = tf.train.Saver(var_list=[w1])  ## Another example to save only w1 in the graph
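    ## max_to_keep (a standard tf.train.Saver argument, default 5) controls how many
    ## recent checkpoints are kept on disk, e.g.:
    #saver = tf.train.Saver(var_list=tf.get_collection("weights"), max_to_keep=5)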
    
    with tf.Session() as sess:
        num_steps = 20
        sess.run(init)
        for k in range(num_steps):
            loss_, _ = sess.run([loss, train_op], feed_dict={
                    x_train: np.random.normal(size=(10, 1))
                })
            print('\r Step %d: Loss %.3f' % (k + 1, loss_), end='')
           
        # These are the weights we are saving
        print('\n\nFinal weights\n-------------')
        print(sess.run(w1))
        
        # Saving!
        save_path = saver.save(sess, "./model.ckpt", global_step=k)
        print('\n\nModel saved in %s' % save_path)
 Step 20: Loss 1.094

Final weights
-------------
[[ 1.0266825   1.05772007  0.96590739  0.95724243  0.92907822  0.96611345
   0.90307653  1.01888311  0.92927408  0.93793356]
 [ 0.9495331   1.06078911  0.95566654  0.98204917  0.94157666  0.96301436
   0.92456734  1.01331544  0.92364347  0.9458065 ]
 [ 0.92919219  0.99486768  1.04470682  0.9641372   0.95812792  0.93862516
   0.93546259  0.96029496  0.97756213  0.96538091]
 [ 0.94353551  1.00090969  0.96378243  1.0284431   0.96275479  0.96283674
   0.91959196  1.01346624  0.92834085  0.95175922]
 [ 0.94587582  1.02199364  0.98864484  0.98802131  0.97426826  0.93756926
   0.92551464  1.00862062  0.94585419  0.92291611]
 [ 0.88112354  0.98502958  0.95367461  0.99699944  0.92616236  1.01558769
   0.93922436  1.00143313  0.98584831  0.97601318]
 [ 0.91298831  0.9453097   0.9586634   0.9823851   0.92361295  0.97580123
   0.97470599  1.03575051  0.97666329  0.98235124]
 [ 0.90732378  0.97126853  0.94039094  1.00005114  0.94014442  0.98188919
   0.93479311  1.04135549  0.96493596  0.96307874]
 [ 0.87640792  0.946374    0.97236532  0.99009228  0.92437834  0.99745423
   0.93931478  1.00605547  1.00048828  0.99940997]
 [ 0.91411066  0.96609396  0.97826278  0.99710542  0.91143686  0.98803157
   0.93433303  0.97239947  0.97085494  1.03977525]]


Model saved in ./model.ckpt-19

Saving should output the following files in the given save path (the sketch after this list shows how to inspect them):

  • checkpoint contains the list of checkpoints saved for the given graph, and in particular the most recent one.
  • The {}.meta file contains only the graph definition / structure, not the values of the variables. It can be used to rebuild the graph structure from scratch.
  • The remaining {}.ckpt-* files (with the V2 checkpoint format, a .index and a .data-* file) contain the values of the saved variables, but not the structure of the graph, i.e., how they are linked together.
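
As a quick sanity check, you can inspect what a checkpoint contains without rebuilding anything. This is a minimal sketch (tf.train.latest_checkpoint and tf.train.list_variables should be available in recent TF 1.x releases; both only read the checkpoint files):

# Find the most recent checkpoint recorded in the `checkpoint` file
latest = tf.train.latest_checkpoint('.')
print(latest)  # e.g. ./model.ckpt-19

# List the variables stored in that checkpoint (name and shape),
# without loading the graph structure
for name, shape in tf.train.list_variables(latest):
    print(name, shape)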

Now the full graph and the saved variable values can easily be restored.

In [46]:
with tf.Graph().as_default():
    #### First we need to define the graph
    
    ## Either you still have the original code, in which case you can rebuild the graph
    ## as previously (see the sketch at the end of this notebook)
    # ... graph
    # and also the saver object
    # ... saver = tf.train.Saver(var_list=...)
    
    ## Or you can load the graph structure from the meta file directly
    saver = tf.train.import_meta_graph('%s.meta' % save_path)
    # Since we don't have the Python variables anymore, we need another way to point to
    # the variable nodes; here we can, for instance, use the collection we defined before,
    # which has also been saved in the meta graph
    w1 = tf.get_collection("weights")[0]
    
    ### Display the graph
    show_graph(tf.get_default_graph().as_graph_def())

    with tf.Session() as sess:
        # Here we restore the variable values for the given session and associated graph
        saver.restore(sess, save_path)
            
        # Print the weights to check that their values match the ones saved above
        print('\n\nFinal weights\n-------------')
        print(sess.run(w1))

Final weights
-------------
[[ 1.0266825   1.05772007  0.96590739  0.95724243  0.92907822  0.96611345
   0.90307653  1.01888311  0.92927408  0.93793356]
 [ 0.9495331   1.06078911  0.95566654  0.98204917  0.94157666  0.96301436
   0.92456734  1.01331544  0.92364347  0.9458065 ]
 [ 0.92919219  0.99486768  1.04470682  0.9641372   0.95812792  0.93862516
   0.93546259  0.96029496  0.97756213  0.96538091]
 [ 0.94353551  1.00090969  0.96378243  1.0284431   0.96275479  0.96283674
   0.91959196  1.01346624  0.92834085  0.95175922]
 [ 0.94587582  1.02199364  0.98864484  0.98802131  0.97426826  0.93756926
   0.92551464  1.00862062  0.94585419  0.92291611]
 [ 0.88112354  0.98502958  0.95367461  0.99699944  0.92616236  1.01558769
   0.93922436  1.00143313  0.98584831  0.97601318]
 [ 0.91298831  0.9453097   0.9586634   0.9823851   0.92361295  0.97580123
   0.97470599  1.03575051  0.97666329  0.98235124]
 [ 0.90732378  0.97126853  0.94039094  1.00005114  0.94014442  0.98188919
   0.93479311  1.04135549  0.96493596  0.96307874]
 [ 0.87640792  0.946374    0.97236532  0.99009228  0.92437834  0.99745423
   0.93931478  1.00605547  1.00048828  0.99940997]
 [ 0.91411066  0.96609396  0.97826278  0.99710542  0.91143686  0.98803157
   0.93433303  0.97239947  0.97085494  1.03977525]]
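
For completeness, here is a minimal sketch of the first alternative mentioned in the comments above: if you still have the code that builds the graph, you can rebuild it and restore the saved values directly, without the .meta file. The Saver matches checkpoint entries to variables by name, so the names ('w1', 'w2') must be the same as in the training graph:

with tf.Graph().as_default():
    # Rebuild variables with the same names as at save time;
    # the Saver maps checkpoint entries to variables by name
    w1 = tf.Variable(tf.ones((10, 10)), name='w1')
    w2 = tf.Variable(tf.ones((10, 1)), name='w2')
    saver = tf.train.Saver(var_list=[w1, w2])

    with tf.Session() as sess:
        # No explicit initialization needed here: restore() itself
        # assigns the saved values to the variables
        saver.restore(sess, save_path)
        print(sess.run(w1))  # should match the weights printed above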