Merge TensorFlow models

Posted by chunyang on December 27, 2020

Sometimes you want to transfer certain weights from multiple models into a single model, or simply merge several models into one. There are at least two ways to do that:

init_from_checkpoint

You can refer to tf.train.init_from_checkpoint.

tf.train.init_from_checkpoint(
    ckpt_dir_or_file, assignment_map
)
  • ckpt_dir_or_file
    • Can be a checkpoint directory or a checkpoint prefix: /path/to/checkpoint_dir, /path/to/checkpoint_dir/model-1234
    • Can be a saved model: /path/to/saved_model/variables/variables (the second variables is the filename prefix of variables.index)
  • assignment_map
    • key: a scope or variable name in the checkpoint
    • value: a scope, a variable name, or a variable reference in the current graph (see the sketch after this list)
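For example, a key/value pair can map a whole checkpoint scope onto a scope in the current graph. A minimal sketch; the checkpoint path and the scope names old_scope/new_scope are placeholders, not from any real model:

import tensorflow as tf

# Variables in the current graph live under "new_scope".
with tf.variable_scope("new_scope"):
    w = tf.get_variable("w", shape=[10])

# Initialize everything under new_scope/ from the tensors stored
# under old_scope/ in the checkpoint.
tf.train.init_from_checkpoint(
    "/path/to/model.ckpt", {"old_scope/": "new_scope/"})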

It is very flexible. Under the hood, init_from_checkpoint replaces a variable's initializer with a restore op, so the checkpoint value is loaded when tf.global_variables_initializer() runs:

# Simplified from the TensorFlow source: the restore op becomes
# both the initializer and the initial value of the variable.
variable._initializer_op = init_op
restore_op.set_shape(variable.shape)
variable._initial_value = restore_op

Suppose you have user-defined variables, e.g. an AwesomeVariable class that behaves like a TensorFlow Variable but uses a different back-end storage. You can write a similar function by creating the initializer op yourself:

# Create your restore op using io_ops.restore_v2

# Replace the initializer op
variable._initializer_op = init_op
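Here is a minimal sketch of such a helper. The AwesomeVariable and its assign method are hypothetical; only io_ops.restore_v2 and the protected attribute names come from TensorFlow itself:

from tensorflow.python.ops import io_ops

def init_awesome_variable_from_checkpoint(variable, ckpt_file, tensor_name):
    # Read the stored tensor straight out of the checkpoint.
    restore_op = io_ops.restore_v2(
        ckpt_file,          # checkpoint prefix
        [tensor_name],      # tensor names to read
        [""],               # "" = the full tensor, no slicing
        [variable.dtype])[0]
    # Hypothetical: whatever op writes into the custom back-end storage.
    init_op = variable.assign(restore_op)
    # Mirror what tf.train.init_from_checkpoint does internally.
    variable._initializer_op = init_op
    restore_op.set_shape(variable.shape)
    variable._initial_value = restore_op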

init_from_checkpoint can be called multiple times with different ckpt_dir_or_file values. As a result, a single model’s variables can be initialized from several source checkpoints or saved models:

import os
import tensorflow as tf

os.makedirs("./models/a", exist_ok=True)
os.makedirs("./models/b", exist_ok=True)

# Save a checkpoint that contains a single variable a = 3.
with tf.Session(graph=tf.Graph()) as session:
    tf.Variable(3, name="a")
    saver = tf.train.Saver()
    session.run(tf.global_variables_initializer())
    saver.save(session, "./models/a/model-a")


# Save a second checkpoint that contains a single variable b = 4.
with tf.Session(graph=tf.Graph()) as session:
    tf.Variable(4, name="b")
    saver = tf.train.Saver()
    session.run(tf.global_variables_initializer())
    saver.save(session, "./models/b/model-b")


# Build a fresh graph and pull each variable from its own checkpoint.
with tf.Session(graph=tf.Graph()) as session:
    a = tf.Variable(1, name="a")
    b = tf.Variable(1, name="b")
    tf.train.init_from_checkpoint("./models/a/model-a", {"a": a})
    tf.train.init_from_checkpoint("./models/b/model-b", {"b": b})
    session.run(tf.global_variables_initializer())
    print(session.run(a))  # 3
    print(session.run(b))  # 4
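Running this prints 3 and 4: both variables are created with the literal value 1, but the restore ops installed by the two init_from_checkpoint calls take over when tf.global_variables_initializer() runs.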

gen_io_ops.merge_v2_checkpoints

When TensorFlow loads a model, the only requirement is that every variable in the graph has a valid value in the checkpoint. So we can first merge the checkpoints of the different models and then restore everything in one pass.

import tensorflow as tf
from tensorflow.python.ops import gen_io_ops

# Source checkpoint prefixes and the destination prefix.
src = tf.constant(["./models/a/model-a", "./models/b/model-b"])
target = tf.constant("./models/merged_model")

op = gen_io_ops.merge_v2_checkpoints(src, target, delete_old_dirs=False)

tf.Session().run(op)
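merge_v2_checkpoints is the same op the Saver uses internally to merge sharded checkpoints: it writes one merged metadata (.index) file at the target prefix and moves the source data files over, leaving a single logical checkpoint.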

# Restore both variables from the merged checkpoint in a single call.
with tf.Session(graph=tf.Graph()) as session:
    a = tf.Variable(1, name="a")
    b = tf.Variable(1, name="b")
    saver = tf.train.Saver()
    saver.restore(session, "./models/merged_model")
    print(session.run(a))  # 3
    print(session.run(b))  # 4

Be careful: the merge consumes its inputs. The source .index and .data files are removed no matter whether delete_old_dirs is set to True or False; the flag only controls whether the source directories themselves are deleted as well.
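If you need to keep the originals, copy them before merging. A minimal sketch, reusing the paths from the example above:

import shutil

# merge_v2_checkpoints consumes its inputs, so merge from copies
# when the original checkpoints must survive.
shutil.copytree("./models/a", "./models/a_backup")
shutil.copytree("./models/b", "./models/b_backup")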