Metagraph round-tripping for resource variables.
Change: 146519414
parent 6abf7cfee9
commit 987e3d9d79
tensorflow/core/framework/variable.proto

@@ -19,6 +19,9 @@ message VariableDef {
 
   // Support for saving variables as slices of a larger variable.
   SaveSliceInfoDef save_slice_info_def = 4;
+
+  // Whether to represent this as a ResourceVariable.
+  bool is_resource = 5;
 }
 
 message SaveSliceInfoDef {
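For context, a minimal sketch of the new field from Python via the generated `variable_pb2` bindings (all names below are illustrative, not taken from this commit):

from tensorflow.core.framework import variable_pb2

var_def = variable_pb2.VariableDef()
var_def.variable_name = "v:0"          # illustrative op/tensor names
var_def.initializer_name = "v/Assign"
var_def.snapshot_name = "v/Read:0"
var_def.is_resource = True             # the new field added above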
tensorflow/python/BUILD

@@ -1341,6 +1341,7 @@ py_library(
         ":framework_for_generated_wrappers",
         ":resource_variable_ops_gen",
         ":util",
+        ":variables",
     ],
 )
 
tensorflow/python/framework/meta_graph_test.py

@@ -35,6 +35,7 @@ from tensorflow.python.ops import data_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
@@ -381,15 +382,21 @@ class ScopedMetaGraphTest(test.TestCase):
   # Verifies that we can export a subgraph in a nested name scope containing a
   # "hidden1/hidden2" and import it into "new_hidden1/new_hidden2" in a new
   # graph.
-  def testExportNestedNames(self):
+  def doTestExportNestedNames(self, use_resource=False):
     graph1 = ops.Graph()
     with graph1.as_default():
       with ops.name_scope("hidden1/hidden2/hidden3"):
         images = constant_op.constant(
             1.0, dtypes.float32, shape=[3, 2], name="images")
-        weights1 = variables.Variable(
-            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
-        biases1 = variables.Variable([0.1] * 3, name="biases")
+        if use_resource:
+          weights1 = variables.Variable(
+              [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
+          biases1 = resource_variable_ops.ResourceVariable(
+              [0.1] * 3, name="biases")
+        else:
+          biases1 = variables.Variable([0.1] * 3, name="biases")
+          weights1 = variables.Variable(
+              [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
         nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
 
     orig_meta_graph, var_list = meta_graph.export_scoped_meta_graph(
@@ -425,6 +432,12 @@ class ScopedMetaGraphTest(test.TestCase):
     for n, e in zip(nodes, expected):
       self.assertEqual([e], graph2.get_operation_by_name(n).get_attr("_class"))
 
+  def testExportNestedNames(self):
+    self.doTestExportNestedNames(use_resource=False)
+
+  def testExportNestedNamesResource(self):
+    self.doTestExportNestedNames(use_resource=True)
+
   def testPotentialCycle(self):
     graph1 = ops.Graph()
     with graph1.as_default():
tensorflow/python/ops/resource_variable_ops.py

@@ -20,9 +20,11 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import variable_pb2
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import variables
 # go/tf-wildcard-import
 # pylint: disable=wildcard-import
 from tensorflow.python.ops.gen_resource_variable_ops import *
@@ -48,8 +50,68 @@ class ResourceVariable(object):
 
   """
 
-  # pylint: disable=unused-argument
   def __init__(self,
+               initial_value=None,
+               trainable=True,
+               collections=None,
+               validate_shape=True,
+               caching_device=None,
+               name=None,
+               dtype=None,
+               variable_def=None,
+               import_scope=None):
+    """Creates a variable.
+
+    Args:
+      initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+        which is the initial value for the Variable. The initial value must have
+        a shape specified unless `validate_shape` is set to False. Can also be a
+        callable with no argument that returns the initial value when called.
+        (Note that initializer functions from init_ops.py must first be bound
+        to a shape before being used here.)
+      trainable: If `True`, the default, also adds the variable to the graph
+        collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
+        the default list of variables to use by the `Optimizer` classes.
+      collections: List of graph collections keys. The new variable is added to
+        these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+      validate_shape: Ignored. Provided for compatibility with tf.Variable.
+      caching_device: Optional device string or function describing where the
+        Variable should be cached for reading. Defaults to the Variable's
+        device. If not `None`, caches on another device. Typical use is to
+        cache on the device where the Ops using the Variable reside, to
+        deduplicate copying through `Switch` and other conditional statements.
+      name: Optional name for the variable. Defaults to `'Variable'` and gets
+        uniquified automatically.
+      dtype: If set, initial_value will be converted to the given type.
+        If None, either the datatype will be kept (if initial_value is
+        a Tensor) or float32 will be used (if it is a Python object convertible
+        to a Tensor).
+      variable_def: `VariableDef` protocol buffer. If not None, recreates the
+        `ResourceVariable` object with its contents. `variable_def` and other
+        arguments (except for import_scope) are mutually exclusive.
+      import_scope: Optional `string`. Name scope to add to the
+        ResourceVariable. Only used when `variable_def` is provided.
+
+    Raises:
+      ValueError: If the initial value is not specified, or does not have a
+        shape and `validate_shape` is `True`.
+    """
+    if variable_def:
+      if initial_value:
+        raise ValueError("variable_def and initial_value are mutually "
+                         "exclusive.")
+      self._init_from_proto(variable_def, import_scope=import_scope)
+    else:
+      self._init_from_args(initial_value=initial_value,
+                           trainable=trainable,
+                           collections=collections,
+                           validate_shape=validate_shape,
+                           caching_device=caching_device,
+                           name=name,
+                           dtype=dtype)
+
+  # pylint: disable=unused-argument
+  def _init_from_args(self,
                       initial_value=None,
                       trainable=True,
                       collections=None,
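A hedged usage sketch of the new constructor path: serialize a variable and rebuild the Python wrapper from the proto in the same graph (variable names are illustrative, not from the commit):

from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

with ops.Graph().as_default():
  v = resource_variable_ops.ResourceVariable([1.0, 2.0], name="v")
  proto = v.to_proto()  # a VariableDef with is_resource=True
  # Rebuild the Python wrapper from the proto; the underlying ops must
  # already exist in the current graph, and variable_def and initial_value
  # are mutually exclusive:
  v2 = resource_variable_ops.ResourceVariable(variable_def=proto)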
@@ -139,10 +201,9 @@ class ResourceVariable(object):
     self._is_initialized_op = (
         gen_resource_variable_ops.var_is_initialized_op(self._handle))
     if initial_value is not None:
-      with ops.name_scope("Create"):
+      with ops.name_scope("Assign") as n, ops.colocate_with(self._handle):
         self._initialize_op = gen_resource_variable_ops.assign_variable_op(
-            self._handle, self._initial_value)
-
+            self._handle, self._initial_value, name=n)
     with ops.name_scope("Read"):
       self._value = gen_resource_variable_ops.read_variable_op(
           self._handle, dtype=self._dtype)
@@ -158,6 +219,32 @@ class ResourceVariable(object):
     self._value.initializer = self._initialize_op
     ops.add_to_collections(collections, self)
 
+  def _init_from_proto(self, variable_def, import_scope=None):
+    """Initializes from `VariableDef` proto."""
+    assert isinstance(variable_def, variable_pb2.VariableDef)
+    if not variable_def.is_resource:
+      raise ValueError("Trying to restore Variable as ResourceVariable.")
+
+    # Create from variable_def.
+    g = ops.get_default_graph()
+    self._handle = g.as_graph_element(
+        ops.prepend_name_scope(variable_def.variable_name,
+                               import_scope=import_scope))
+    self._initialize_op = g.as_graph_element(
+        ops.prepend_name_scope(variable_def.initializer_name,
+                               import_scope=import_scope))
+    self._cached_value = g.as_graph_element(
+        ops.prepend_name_scope(variable_def.snapshot_name,
+                               import_scope=import_scope))
+    self._value = self._cached_value
+    if variable_def.HasField("save_slice_info_def"):
+      self._save_slice_info = variables.Variable.SaveSliceInfo(
+          save_slice_info_def=variable_def.save_slice_info_def)
+    else:
+      self._save_slice_info = None
+    self._caching_device = None
+    self._dtype = self._handle.op.get_attr("dtype")
+
   @property
   def dtype(self):
     """The dtype of this variable."""
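For reference, the two name-scope helpers used in `_init_from_proto` (and in `to_proto` below) behave roughly as follows; this is a sketch inferred from their usage here, with illustrative names, and it omits corner cases such as control-input ("^") names:

from tensorflow.python.framework import ops

# Prepending happens on import, stripping on export:
assert ops.prepend_name_scope("v/Read:0", import_scope="new") == "new/v/Read:0"
assert ops.strip_name_scope("new/v/Read:0", "new") == "v/Read:0"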
@@ -247,6 +334,38 @@ class ResourceVariable(object):
     _register_variable_read(value, collections=collections, trainable=trainable)
     return array_ops.identity(value)
 
+  def to_proto(self, export_scope=None):
+    """Converts a `ResourceVariable` to a `VariableDef` protocol buffer.
+
+    Args:
+      export_scope: Optional `string`. Name scope to remove.
+
+    Returns:
+      A `VariableDef` protocol buffer, or `None` if the `Variable` is not
+      in the specified name scope.
+    """
+    if (export_scope is None or
+        self.handle.name.startswith(export_scope)):
+      var_def = variable_pb2.VariableDef()
+      var_def.variable_name = ops.strip_name_scope(
+          self.handle.name, export_scope)
+      var_def.initializer_name = ops.strip_name_scope(
+          self.initializer.name, export_scope)
+      var_def.snapshot_name = ops.strip_name_scope(
+          self.value().name, export_scope)
+      var_def.is_resource = True
+      if self._save_slice_info:
+        var_def.save_slice_info_def.MergeFrom(self._save_slice_info.to_proto(
+            export_scope=export_scope))
+      return var_def
+    else:
+      return None
+
+  @staticmethod
+  def from_proto(variable_def, import_scope=None):
+    return ResourceVariable(variable_def=variable_def,
+                            import_scope=import_scope)
+
   @staticmethod
   def _OverloadAllOperators():  # pylint: disable=invalid-name
     """Register overloads for all operators."""
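A hedged sketch of the scope handling in `to_proto` above (variable and scope names are illustrative):

from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

with ops.Graph().as_default():
  with ops.name_scope("outer"):
    v = resource_variable_ops.ResourceVariable([1.0], name="v")
  proto = v.to_proto(export_scope="outer")  # names stored without "outer/"
  assert proto.is_resource
  # A variable outside the requested scope serializes to None:
  assert v.to_proto(export_scope="elsewhere") is None
  # from_proto(proto, import_scope=...) would rebuild the wrapper, but only
  # in a graph that already contains the re-scoped ops, e.g. after a
  # metagraph import.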
@@ -325,3 +444,32 @@ ops.register_tensor_conversion_function(ResourceVariable, _dense_var_to_tensor)
 # pylint: disable=protected-access
 ResourceVariable._OverloadAllOperators()
 ops.register_dense_tensor_like_type(ResourceVariable)
+
+
+def _to_proto_fn(v, export_scope=None):
+  """Converts Variable and ResourceVariable to VariableDef for collections."""
+  return v.to_proto(export_scope=export_scope)
+
+
+def _from_proto_fn(v, import_scope=None):
+  """Creates Variable or ResourceVariable from VariableDef as needed."""
+  if v.is_resource:
+    return ResourceVariable.from_proto(v, import_scope=import_scope)
+  return variables.Variable.from_proto(v, import_scope=import_scope)
+
+
+ops.register_proto_function(
+    ops.GraphKeys.GLOBAL_VARIABLES,
+    proto_type=variable_pb2.VariableDef,
+    to_proto=_to_proto_fn,
+    from_proto=_from_proto_fn)
+ops.register_proto_function(
+    ops.GraphKeys.TRAINABLE_VARIABLES,
+    proto_type=variable_pb2.VariableDef,
+    to_proto=_to_proto_fn,
+    from_proto=_from_proto_fn)
+ops.register_proto_function(
+    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
+    proto_type=variable_pb2.VariableDef,
+    to_proto=_to_proto_fn,
+    from_proto=_from_proto_fn)
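Taken together with the final hunk below, which removes the plain-`Variable` registrations from variables.py, these calls make collection round-tripping dispatch on `is_resource`. A minimal end-to-end sketch using the internal APIs exercised by the test above (not part of the commit; names illustrative):

from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import resource_variable_ops

graph1 = ops.Graph()
with graph1.as_default():
  resource_variable_ops.ResourceVariable([0.1] * 3, name="v")
# Variables in GLOBAL_VARIABLES are serialized through _to_proto_fn.
meta_graph_def, _ = meta_graph.export_scoped_meta_graph(graph=graph1)

graph2 = ops.Graph()
with graph2.as_default():
  meta_graph.import_scoped_meta_graph(meta_graph_def)
  # _from_proto_fn sees is_resource=True and rebuilds a ResourceVariable,
  # not a plain Variable:
  v2 = graph2.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0]
  assert isinstance(v2, resource_variable_ops.ResourceVariable)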
tensorflow/python/ops/variables.py

@@ -1326,19 +1326,5 @@ ops.register_tensor_conversion_function(
     PartitionedVariable, PartitionedVariable._TensorConversionFunction)
 # pylint: enable=protected-access
 
+
 ops.register_dense_tensor_like_type(Variable)
-ops.register_proto_function(
-    ops.GraphKeys.GLOBAL_VARIABLES,
-    proto_type=variable_pb2.VariableDef,
-    to_proto=Variable.to_proto,
-    from_proto=Variable.from_proto)
-ops.register_proto_function(
-    ops.GraphKeys.TRAINABLE_VARIABLES,
-    proto_type=variable_pb2.VariableDef,
-    to_proto=Variable.to_proto,
-    from_proto=Variable.from_proto)
-ops.register_proto_function(
-    ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
-    proto_type=variable_pb2.VariableDef,
-    to_proto=Variable.to_proto,
-    from_proto=Variable.from_proto)