Metagraph round-tripping for resource variables.

Change: 146519414
This commit is contained in:
Alexandre Passos 2017-02-03 14:45:21 -08:00 committed by TensorFlower Gardener
parent 6abf7cfee9
commit 987e3d9d79
5 changed files with 175 additions and 24 deletions

View File

@ -19,6 +19,9 @@ message VariableDef {
// Support for saving variables as slices of a larger variable.
SaveSliceInfoDef save_slice_info_def = 4;
// Whether to represent this as a ResourceVariable.
bool is_resource = 5;
}
message SaveSliceInfoDef {

View File

@ -1341,6 +1341,7 @@ py_library(
":framework_for_generated_wrappers",
":resource_variable_ops_gen",
":util",
":variables",
],
)

View File

@ -35,6 +35,7 @@ from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
@ -381,15 +382,21 @@ class ScopedMetaGraphTest(test.TestCase):
# Verifies that we can export a subgraph in a nested name scope containing a
# "hidden1/hidden2" and import it into "new_hidden1/new_hidden2" in a new
# graph.
def testExportNestedNames(self):
def doTestExportNestedNames(self, use_resource=False):
graph1 = ops.Graph()
with graph1.as_default():
with ops.name_scope("hidden1/hidden2/hidden3"):
images = constant_op.constant(
1.0, dtypes.float32, shape=[3, 2], name="images")
if use_resource:
weights1 = variables.Variable(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
biases1 = resource_variable_ops.ResourceVariable(
[0.1] * 3, name="biases")
else:
biases1 = variables.Variable([0.1] * 3, name="biases")
weights1 = variables.Variable(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
orig_meta_graph, var_list = meta_graph.export_scoped_meta_graph(
@ -425,6 +432,12 @@ class ScopedMetaGraphTest(test.TestCase):
for n, e in zip(nodes, expected):
self.assertEqual([e], graph2.get_operation_by_name(n).get_attr("_class"))
def testExportNestedNames(self):
self.doTestExportNestedNames(use_resource=False)
def testExportNestedNamesResource(self):
self.doTestExportNestedNames(use_resource=True)
def testPotentialCycle(self):
graph1 = ops.Graph()
with graph1.as_default():

View File

@ -20,9 +20,11 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import variables
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_resource_variable_ops import *
@ -48,8 +50,68 @@ class ResourceVariable(object):
"""
# pylint: disable=unused-argument
def __init__(self,
initial_value=None,
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
dtype=None,
variable_def=None,
import_scope=None):
"""Creates a variable.
Args:
initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
which is the initial value for the Variable. The initial value must have
a shape specified unless `validate_shape` is set to False. Can also be a
callable with no argument that returns the initial value when called.
(Note that initializer functions from init_ops.py must first be bound
to a shape before being used here.)
trainable: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
collections: List of graph collections keys. The new variable is added to
these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
validate_shape: Ignored. Provided for compatibility with tf.Variable.
caching_device: Optional device string or function describing where the
Variable should be cached for reading. Defaults to the Variable's
device. If not `None`, caches on another device. Typical use is to
cache on the device where the Ops using the Variable reside, to
deduplicate copying through `Switch` and other conditional statements.
name: Optional name for the variable. Defaults to `'Variable'` and gets
uniquified automatically.
dtype: If set, initial_value will be converted to the given type.
If None, either the datatype will be kept (if initial_value is
a Tensor) or float32 will be used (if it is a Python object convertible
to a Tensor).
variable_def: `VariableDef` protocol buffer. If not None, recreates the
`ResourceVariable` object with its contents. `variable_def` and other
arguments (except for import_scope) are mutually exclusive.
import_scope: Optional `string`. Name scope to add to the
ResourceVariable. Only used when `variable_def` is provided.
Raises:
ValueError: If the initial value is not specified, or does not have a
shape and `validate_shape` is `True`.
"""
if variable_def:
if initial_value:
raise ValueError("variable_def and initial_value are mutually "
"exclusive.")
self._init_from_proto(variable_def, import_scope=import_scope)
else:
self._init_from_args(initial_value=initial_value,
trainable=trainable,
collections=collections,
validate_shape=validate_shape,
caching_device=caching_device,
name=name,
dtype=dtype)
# pylint: disable=unused-argument
def _init_from_args(self,
initial_value=None,
trainable=True,
collections=None,
@ -139,10 +201,9 @@ class ResourceVariable(object):
self._is_initialized_op = (
gen_resource_variable_ops.var_is_initialized_op(self._handle))
if initial_value is not None:
with ops.name_scope("Create"):
with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
self._initialize_op = gen_resource_variable_ops.assign_variable_op(
self._handle, self._initial_value)
self._handle, self._initial_value, name=n)
with ops.name_scope("Read"):
self._value = gen_resource_variable_ops.read_variable_op(
self._handle, dtype=self._dtype)
@ -158,6 +219,32 @@ class ResourceVariable(object):
self._value.initializer = self._initialize_op
ops.add_to_collections(collections, self)
def _init_from_proto(self, variable_def, import_scope=None):
"""Initializes from `VariableDef` proto."""
assert isinstance(variable_def, variable_pb2.VariableDef)
if not variable_def.is_resource:
raise ValueError("Trying to restore Variable as ResourceVariable.")
# Create from variable_def.
g = ops.get_default_graph()
self._handle = g.as_graph_element(
ops.prepend_name_scope(variable_def.variable_name,
import_scope=import_scope))
self._initialize_op = g.as_graph_element(
ops.prepend_name_scope(variable_def.initializer_name,
import_scope=import_scope))
self._cached_value = g.as_graph_element(
ops.prepend_name_scope(variable_def.snapshot_name,
import_scope=import_scope))
self._value = self._cached_value
if variable_def.HasField("save_slice_info_def"):
self._save_slice_info = variables.Variable.SaveSliceInfo(
save_slice_info_def=variable_def.save_slice_info_def)
else:
self._save_slice_info = None
self._caching_device = None
self._dtype = self._handle.op.get_attr("dtype")
@property
def dtype(self):
"""The dtype of this variable."""
@ -247,6 +334,38 @@ class ResourceVariable(object):
_register_variable_read(value, collections=collections, trainable=trainable)
return array_ops.identity(value)
def to_proto(self, export_scope=None):
"""Converts a `ResourceVariable` to a `VariableDef` protocol buffer.
Args:
export_scope: Optional `string`. Name scope to remove.
Returns:
A `VariableDef` protocol buffer, or `None` if the `Variable` is not
in the specified name scope.
"""
if (export_scope is None or
self.handle.name.startswith(export_scope)):
var_def = variable_pb2.VariableDef()
var_def.variable_name = ops.strip_name_scope(
self.handle.name, export_scope)
var_def.initializer_name = ops.strip_name_scope(
self.initializer.name, export_scope)
var_def.snapshot_name = ops.strip_name_scope(
self.value().name, export_scope)
var_def.is_resource = True
if self._save_slice_info:
var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto(
export_scope=export_scope))
return var_def
else:
return None
@staticmethod
def from_proto(variable_def, import_scope=None):
return ResourceVariable(variable_def=variable_def,
import_scope=import_scope)
@staticmethod
def _OverloadAllOperators(): # pylint: disable=invalid-name
"""Register overloads for all operators."""
@ -325,3 +444,32 @@ ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
# pylint: disable=protected-access
ResourceVariable._OverloadAllOperators()
ops.register_dense_tensor_like_type(ResourceVariable)
def _to_proto_fn(v, export_scope=None):
"""Converts Variable and ResourceVariable to VariableDef for collections."""
return v.to_proto(export_scope=export_scope)
def _from_proto_fn(v, import_scope=None):
"""Creates Variable or ResourceVariable from VariableDef as needed."""
if v.is_resource:
return ResourceVariable.from_proto(v, import_scope=import_scope)
return variables.Variable.from_proto(v, import_scope=import_scope)
ops.register_proto_function(
ops.GraphKeys.GLOBAL_VARIABLES,
proto_type=variable_pb2.VariableDef,
to_proto=_to_proto_fn,
from_proto=_from_proto_fn)
ops.register_proto_function(
ops.GraphKeys.TRAINABLE_VARIABLES,
proto_type=variable_pb2.VariableDef,
to_proto=_to_proto_fn,
from_proto=_from_proto_fn)
ops.register_proto_function(
ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
proto_type=variable_pb2.VariableDef,
to_proto=_to_proto_fn,
from_proto=_from_proto_fn)

View File

@ -1326,19 +1326,5 @@ ops.register_tensor_conversion_function(
PartitionedVariable, PartitionedVariable._TensorConversionFunction)
# pylint: enable=protected-access
ops.register_dense_tensor_like_type(Variable)
ops.register_proto_function(
ops.GraphKeys.GLOBAL_VARIABLES,
proto_type=variable_pb2.VariableDef,
to_proto=Variable.to_proto,
from_proto=Variable.from_proto)
ops.register_proto_function(
ops.GraphKeys.TRAINABLE_VARIABLES,
proto_type=variable_pb2.VariableDef,
to_proto=Variable.to_proto,
from_proto=Variable.from_proto)
ops.register_proto_function(
ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
proto_type=variable_pb2.VariableDef,
to_proto=Variable.to_proto,
from_proto=Variable.from_proto)