Metagraph round-tripping for resource variables.
Change: 146519414
This commit is contained in:
parent
6abf7cfee9
commit
987e3d9d79
@ -19,6 +19,9 @@ message VariableDef {
|
||||
|
||||
// Support for saving variables as slices of a larger variable.
|
||||
SaveSliceInfoDef save_slice_info_def = 4;
|
||||
|
||||
// Whether to represent this as a ResourceVariable.
|
||||
bool is_resource = 5;
|
||||
}
|
||||
|
||||
message SaveSliceInfoDef {
|
||||
|
@ -1341,6 +1341,7 @@ py_library(
|
||||
":framework_for_generated_wrappers",
|
||||
":resource_variable_ops_gen",
|
||||
":util",
|
||||
":variables",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -35,6 +35,7 @@ from tensorflow.python.ops import data_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import nn_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
@ -381,15 +382,21 @@ class ScopedMetaGraphTest(test.TestCase):
|
||||
# Verifies that we can export a subgraph in a nested name scope containing a
|
||||
# "hidden1/hidden2" and import it into "new_hidden1/new_hidden2" in a new
|
||||
# graph.
|
||||
def testExportNestedNames(self):
|
||||
def doTestExportNestedNames(self, use_resource=False):
|
||||
graph1 = ops.Graph()
|
||||
with graph1.as_default():
|
||||
with ops.name_scope("hidden1/hidden2/hidden3"):
|
||||
images = constant_op.constant(
|
||||
1.0, dtypes.float32, shape=[3, 2], name="images")
|
||||
if use_resource:
|
||||
weights1 = variables.Variable(
|
||||
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
|
||||
biases1 = resource_variable_ops.ResourceVariable(
|
||||
[0.1] * 3, name="biases")
|
||||
else:
|
||||
biases1 = variables.Variable([0.1] * 3, name="biases")
|
||||
weights1 = variables.Variable(
|
||||
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
|
||||
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
|
||||
|
||||
orig_meta_graph, var_list = meta_graph.export_scoped_meta_graph(
|
||||
@ -425,6 +432,12 @@ class ScopedMetaGraphTest(test.TestCase):
|
||||
for n, e in zip(nodes, expected):
|
||||
self.assertEqual([e], graph2.get_operation_by_name(n).get_attr("_class"))
|
||||
|
||||
def testExportNestedNames(self):
|
||||
self.doTestExportNestedNames(use_resource=False)
|
||||
|
||||
def testExportNestedNamesResource(self):
|
||||
self.doTestExportNestedNames(use_resource=True)
|
||||
|
||||
def testPotentialCycle(self):
|
||||
graph1 = ops.Graph()
|
||||
with graph1.as_default():
|
||||
|
@ -20,9 +20,11 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.core.framework import attr_value_pb2
|
||||
from tensorflow.core.framework import variable_pb2
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import gen_resource_variable_ops
|
||||
from tensorflow.python.ops import variables
|
||||
# go/tf-wildcard-import
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.python.ops.gen_resource_variable_ops import *
|
||||
@ -48,8 +50,68 @@ class ResourceVariable(object):
|
||||
|
||||
"""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def __init__(self,
|
||||
initial_value=None,
|
||||
trainable=True,
|
||||
collections=None,
|
||||
validate_shape=True,
|
||||
caching_device=None,
|
||||
name=None,
|
||||
dtype=None,
|
||||
variable_def=None,
|
||||
import_scope=None):
|
||||
"""Creates a variable.
|
||||
|
||||
Args:
|
||||
initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
|
||||
which is the initial value for the Variable. The initial value must have
|
||||
a shape specified unless `validate_shape` is set to False. Can also be a
|
||||
callable with no argument that returns the initial value when called.
|
||||
(Note that initializer functions from init_ops.py must first be bound
|
||||
to a shape before being used here.)
|
||||
trainable: If `True`, the default, also adds the variable to the graph
|
||||
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
|
||||
the default list of variables to use by the `Optimizer` classes.
|
||||
collections: List of graph collections keys. The new variable is added to
|
||||
these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
|
||||
validate_shape: Ignored. Provided for compatibility with tf.Variable.
|
||||
caching_device: Optional device string or function describing where the
|
||||
Variable should be cached for reading. Defaults to the Variable's
|
||||
device. If not `None`, caches on another device. Typical use is to
|
||||
cache on the device where the Ops using the Variable reside, to
|
||||
deduplicate copying through `Switch` and other conditional statements.
|
||||
name: Optional name for the variable. Defaults to `'Variable'` and gets
|
||||
uniquified automatically.
|
||||
dtype: If set, initial_value will be converted to the given type.
|
||||
If None, either the datatype will be kept (if initial_value is
|
||||
a Tensor) or float32 will be used (if it is a Python object convertible
|
||||
to a Tensor).
|
||||
variable_def: `VariableDef` protocol buffer. If not None, recreates the
|
||||
`ResourceVariable` object with its contents. `variable_def` and other
|
||||
arguments (except for import_scope) are mutually exclusive.
|
||||
import_scope: Optional `string`. Name scope to add to the
|
||||
ResourceVariable. Only used when `variable_def` is provided.
|
||||
|
||||
Raises:
|
||||
ValueError: If the initial value is not specified, or does not have a
|
||||
shape and `validate_shape` is `True`.
|
||||
"""
|
||||
if variable_def:
|
||||
if initial_value:
|
||||
raise ValueError("variable_def and initial_value are mutually "
|
||||
"exclusive.")
|
||||
self._init_from_proto(variable_def, import_scope=import_scope)
|
||||
else:
|
||||
self._init_from_args(initial_value=initial_value,
|
||||
trainable=trainable,
|
||||
collections=collections,
|
||||
validate_shape=validate_shape,
|
||||
caching_device=caching_device,
|
||||
name=name,
|
||||
dtype=dtype)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def _init_from_args(self,
|
||||
initial_value=None,
|
||||
trainable=True,
|
||||
collections=None,
|
||||
@ -139,10 +201,9 @@ class ResourceVariable(object):
|
||||
self._is_initialized_op = (
|
||||
gen_resource_variable_ops.var_is_initialized_op(self._handle))
|
||||
if initial_value is not None:
|
||||
with ops.name_scope("Create"):
|
||||
with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
|
||||
self._initialize_op = gen_resource_variable_ops.assign_variable_op(
|
||||
self._handle, self._initial_value)
|
||||
|
||||
self._handle, self._initial_value, name=n)
|
||||
with ops.name_scope("Read"):
|
||||
self._value = gen_resource_variable_ops.read_variable_op(
|
||||
self._handle, dtype=self._dtype)
|
||||
@ -158,6 +219,32 @@ class ResourceVariable(object):
|
||||
self._value.initializer = self._initialize_op
|
||||
ops.add_to_collections(collections, self)
|
||||
|
||||
def _init_from_proto(self, variable_def, import_scope=None):
|
||||
"""Initializes from `VariableDef` proto."""
|
||||
assert isinstance(variable_def, variable_pb2.VariableDef)
|
||||
if not variable_def.is_resource:
|
||||
raise ValueError("Trying to restore Variable as ResourceVariable.")
|
||||
|
||||
# Create from variable_def.
|
||||
g = ops.get_default_graph()
|
||||
self._handle = g.as_graph_element(
|
||||
ops.prepend_name_scope(variable_def.variable_name,
|
||||
import_scope=import_scope))
|
||||
self._initialize_op = g.as_graph_element(
|
||||
ops.prepend_name_scope(variable_def.initializer_name,
|
||||
import_scope=import_scope))
|
||||
self._cached_value = g.as_graph_element(
|
||||
ops.prepend_name_scope(variable_def.snapshot_name,
|
||||
import_scope=import_scope))
|
||||
self._value = self._cached_value
|
||||
if variable_def.HasField("save_slice_info_def"):
|
||||
self._save_slice_info = variables.Variable.SaveSliceInfo(
|
||||
save_slice_info_def=variable_def.save_slice_info_def)
|
||||
else:
|
||||
self._save_slice_info = None
|
||||
self._caching_device = None
|
||||
self._dtype = self._handle.op.get_attr("dtype")
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""The dtype of this variable."""
|
||||
@ -247,6 +334,38 @@ class ResourceVariable(object):
|
||||
_register_variable_read(value, collections=collections, trainable=trainable)
|
||||
return array_ops.identity(value)
|
||||
|
||||
def to_proto(self, export_scope=None):
|
||||
"""Converts a `ResourceVariable` to a `VariableDef` protocol buffer.
|
||||
|
||||
Args:
|
||||
export_scope: Optional `string`. Name scope to remove.
|
||||
|
||||
Returns:
|
||||
A `VariableDef` protocol buffer, or `None` if the `Variable` is not
|
||||
in the specified name scope.
|
||||
"""
|
||||
if (export_scope is None or
|
||||
self.handle.name.startswith(export_scope)):
|
||||
var_def = variable_pb2.VariableDef()
|
||||
var_def.variable_name = ops.strip_name_scope(
|
||||
self.handle.name, export_scope)
|
||||
var_def.initializer_name = ops.strip_name_scope(
|
||||
self.initializer.name, export_scope)
|
||||
var_def.snapshot_name = ops.strip_name_scope(
|
||||
self.value().name, export_scope)
|
||||
var_def.is_resource = True
|
||||
if self._save_slice_info:
|
||||
var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto(
|
||||
export_scope=export_scope))
|
||||
return var_def
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def from_proto(variable_def, import_scope=None):
|
||||
return ResourceVariable(variable_def=variable_def,
|
||||
import_scope=import_scope)
|
||||
|
||||
@staticmethod
|
||||
def _OverloadAllOperators(): # pylint: disable=invalid-name
|
||||
"""Register overloads for all operators."""
|
||||
@ -325,3 +444,32 @@ ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
|
||||
# pylint: disable=protected-access
|
||||
ResourceVariable._OverloadAllOperators()
|
||||
ops.register_dense_tensor_like_type(ResourceVariable)
|
||||
|
||||
|
||||
def _to_proto_fn(v, export_scope=None):
|
||||
"""Converts Variable and ResourceVariable to VariableDef for collections."""
|
||||
return v.to_proto(export_scope=export_scope)
|
||||
|
||||
|
||||
def _from_proto_fn(v, import_scope=None):
|
||||
"""Creates Variable or ResourceVariable from VariableDef as needed."""
|
||||
if v.is_resource:
|
||||
return ResourceVariable.from_proto(v, import_scope=import_scope)
|
||||
return variables.Variable.from_proto(v, import_scope=import_scope)
|
||||
|
||||
|
||||
ops.register_proto_function(
|
||||
ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
proto_type=variable_pb2.VariableDef,
|
||||
to_proto=_to_proto_fn,
|
||||
from_proto=_from_proto_fn)
|
||||
ops.register_proto_function(
|
||||
ops.GraphKeys.TRAINABLE_VARIABLES,
|
||||
proto_type=variable_pb2.VariableDef,
|
||||
to_proto=_to_proto_fn,
|
||||
from_proto=_from_proto_fn)
|
||||
ops.register_proto_function(
|
||||
ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
|
||||
proto_type=variable_pb2.VariableDef,
|
||||
to_proto=_to_proto_fn,
|
||||
from_proto=_from_proto_fn)
|
||||
|
@ -1326,19 +1326,5 @@ ops.register_tensor_conversion_function(
|
||||
PartitionedVariable, PartitionedVariable._TensorConversionFunction)
|
||||
# pylint: enable=protected-access
|
||||
|
||||
|
||||
ops.register_dense_tensor_like_type(Variable)
|
||||
ops.register_proto_function(
|
||||
ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
proto_type=variable_pb2.VariableDef,
|
||||
to_proto=Variable.to_proto,
|
||||
from_proto=Variable.from_proto)
|
||||
ops.register_proto_function(
|
||||
ops.GraphKeys.TRAINABLE_VARIABLES,
|
||||
proto_type=variable_pb2.VariableDef,
|
||||
to_proto=Variable.to_proto,
|
||||
from_proto=Variable.from_proto)
|
||||
ops.register_proto_function(
|
||||
ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
|
||||
proto_type=variable_pb2.VariableDef,
|
||||
to_proto=Variable.to_proto,
|
||||
from_proto=Variable.from_proto)
|
||||
|
Loading…
Reference in New Issue
Block a user