Set handle data of function inputs and outputs.

Fixes bug when taking gradients of nested functions.

PiperOrigin-RevId: 346241008
Change-Id: I50fc01bd5d971d272058b13b5964f8cdc28b00d0
parent 57b591e3ac
commit 8318ab26da
Changed files under tensorflow/core/framework and tensorflow/python.
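For orientation, below is a minimal sketch of the pattern this change targets, condensed from the test added at the end of this diff but rewritten against the public tf.* API; the SavedModel path and the variable/function names are illustrative, not from the commit. An outer function launders a variable handle through a nested function, and before this fix the restored function could lose the handle data needed to differentiate through the read.

import tensorflow as tf

weight = tf.Variable(2.)

@tf.function(input_signature=[tf.TensorSpec(shape=(None, None), dtype=tf.float32)])
def g(x):
  weight.read_value()  # just get the tape to watch the variable
  handle = tf.identity(weight.handle)

  @tf.function
  def launder_var_handle():
    # The inner function returns a plain DT_RESOURCE tensor; without handle
    # data recorded for function inputs/outputs, the dtype/shape of the
    # variable behind it is unknown after a save/load round trip.
    return tf.identity(handle)

  return x + tf.raw_ops.ReadVariableOp(
      resource=launder_var_handle(), dtype=tf.float32)

root = tf.Module()
root.weight = weight
root.g = g
tf.saved_model.save(root, "/tmp/nested_fn")  # illustrative path
imported = tf.saved_model.load("/tmp/nested_fn")

with tf.GradientTape() as tape:
  loss = tf.reduce_sum(imported.g(tf.constant([[1., 2., 3.]])))
# Before this change, differentiating through the laundered handle could
# fail; with handle data on function inputs and outputs it yields 3.0 here.
print(tape.gradient(loss, imported.weight))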
@@ -8,6 +8,7 @@ option java_package = "org.tensorflow.framework";
 option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/op_def_go_proto";
 import "tensorflow/core/framework/attr_value.proto";
 import "tensorflow/core/framework/types.proto";
+import "tensorflow/core/framework/resource_handle.proto";
 
 // Defines an operation. A NodeDef in a GraphDef specifies an Op by
 // using the "op" field which should match the name of a OpDef.
@@ -42,6 +43,9 @@ message OpDef {
     // type, type_attr, and number_attr may be specified.
     string type_list_attr = 6;
 
+    // The handle data for resource inputs.
+    repeated ResourceHandleProto.DtypeAndShape handle_data = 7;
+
     // For inputs: if true, the inputs are required to be refs.
     // By default, inputs can be either refs or non-refs.
     // For outputs: if true, outputs are refs, otherwise they are not.
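As a hedged illustration (not part of the commit; op_def_pb2, types_pb2, and tensor_shape_pb2 are the standard generated protobuf bindings), this is roughly how the new OpDef.ArgDef field can be populated from Python once the proto change is in place:

from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2

# A resource-typed input argument, annotated with the dtype and shape of the
# tensor stored behind the handle (a float32 matrix with unknown dimensions).
arg = op_def_pb2.OpDef.ArgDef(name="var_handle", type=types_pb2.DT_RESOURCE)
dtype_and_shape = arg.handle_data.add()  # ResourceHandleProto.DtypeAndShape
dtype_and_shape.dtype = types_pb2.DT_FLOAT
dtype_and_shape.shape.CopyFrom(
    tensor_shape_pb2.TensorShapeProto(dim=[
        tensor_shape_pb2.TensorShapeProto.Dim(size=-1),
        tensor_shape_pb2.TensorShapeProto.Dim(size=-1),
    ]))
print(arg)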
@@ -18,16 +18,21 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
+import itertools
+
+
 from tensorflow.core.framework import function_pb2
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import tensor_shape_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.framework import versions_pb2
 from tensorflow.python.eager import context
+from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import versions
 from tensorflow.python.framework.func_graph import FuncGraph
+from tensorflow.python.ops import resource_variable_ops
 
 
 def function_def_to_graph(fdef, input_shapes=None):
@@ -84,6 +89,9 @@ def function_def_to_graph(fdef, input_shapes=None):
         func_graph.get_operation_by_name(fdef.control_ret[ret_name])
         for ret_name in fdef.signature.control_output
     ]
 
+    _set_handle_data(func_graph, fdef)
+
   for node in graph_def.node:
     output_shapes = node.attr.get("_output_shapes", None)
     if output_shapes is not None:
@@ -264,3 +272,19 @@ def _get_num_args(arg_def, node_def):
     return 1
   else:
     raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))
+
+
+def _set_handle_data(func_graph, fdef):
+  """Adds handle data for resource type inputs and outputs."""
+  for tensor, arg_def in itertools.chain(
+      zip(func_graph.inputs, fdef.signature.input_arg),
+      zip(func_graph.outputs, fdef.signature.output_arg)):
+    if arg_def.handle_data:
+      shape_and_dtype = arg_def.handle_data[0]
+      handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
+      handle_data.is_set = True
+      handle_data.shape_and_type.append(
+          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
+              shape=shape_and_dtype.shape, dtype=shape_and_dtype.dtype))
+      resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
+          tensor, handle_data, True)
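A hedged round-trip sketch of where _set_handle_data comes into play; it assumes the ConcreteFunction.function_def property and that the traced resource placeholder already carries handle data, neither of which is established by this diff, and the import paths are written from memory:

import tensorflow as tf
from tensorflow.core.framework import function_pb2
from tensorflow.python.framework import function_def_to_graph as fdg
from tensorflow.python.framework import ops

v = tf.Variable([[1., 2.], [3., 4.]])

@tf.function
def read_resource(handle):
  return tf.raw_ops.ReadVariableOp(resource=handle, dtype=tf.float32)

concrete = read_resource.get_concrete_function(v.handle)
fdef = function_pb2.FunctionDef()
fdef.CopyFrom(concrete.function_def)  # assumed property; copy before editing

# Do by hand what ops.py now does during serialization: copy the resource
# input's handle data into the matching input ArgDef.
handle_data = ops.get_resource_handle_data(concrete.graph.inputs[0])
if handle_data.shape_and_type:
  proto = fdef.signature.input_arg[0].handle_data.add()
  proto.dtype = handle_data.shape_and_type[0].dtype
  proto.shape.CopyFrom(handle_data.shape_and_type[0].shape)

# function_def_to_graph then restores it onto the rebuilt placeholder
# through _set_handle_data.
func_graph = fdg.function_def_to_graph(fdef)
print(ops.get_resource_handle_data(func_graph.inputs[0]))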
@@ -50,6 +50,7 @@ from tensorflow.python.eager import monitoring
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import c_api_util
 from tensorflow.python.framework import composite_tensor
+from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -3292,18 +3293,18 @@ class Graph(object):
             continue
           # TODO(b/141471245): Fix the inconsistency when inputs of func graph
           # are appended during gradient computation of while/cond.
-          for input_tensor, _ in zip(func_graph_inputs,
-                                     function_def.signature.input_arg):
-            if input_tensor.dtype == dtypes.resource:
-              # TODO(allenl): Save and restore handle data, then save the
-              # resource placeholder's shape. Right now some shape functions get
-              # confused if we set the shape of the resource placeholder (to a
-              # scalar of course) and there isn't any handle data.
-              input_shapes.list.shape.add().CopyFrom(
-                  tensor_shape.TensorShape(None).as_proto())
-            else:
-              input_shapes.list.shape.add().CopyFrom(
-                  input_tensor.get_shape().as_proto())
+          for input_tensor, arg_def in zip(func_graph_inputs,
+                                           function_def.signature.input_arg):
+            input_shapes.list.shape.add().CopyFrom(
+                input_tensor.get_shape().as_proto())
+            if input_tensor.dtype == dtypes.resource:
+              _copy_handle_data_to_arg_def(input_tensor, arg_def)
+
+          for output_tensor, arg_def in zip(func_graph.outputs,
+                                            function_def.signature.output_arg):
+            if output_tensor.dtype == dtypes.resource:
+              _copy_handle_data_to_arg_def(output_tensor, arg_def)
+
           for node in function_def.node_def:
             try:
               op = func_graph.get_operation_by_name(node.name)
@@ -6979,3 +6980,22 @@ def _get_enclosing_context(graph):
 
   if graph.building_function and hasattr(graph, "outer_graph"):
     return _get_enclosing_context(graph.outer_graph)
+
+
+def get_resource_handle_data(graph_op):
+  assert type(graph_op) == Tensor  # pylint: disable=unidiomatic-typecheck
+
+  handle_data = pywrap_tf_session.GetHandleShapeAndType(
+      graph_op.graph._c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access
+
+  return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
+      compat.as_bytes(handle_data))
+
+
+def _copy_handle_data_to_arg_def(tensor, arg_def):
+  handle_data = get_resource_handle_data(tensor)
+  if handle_data.shape_and_type:
+    shape_and_type = handle_data.shape_and_type[0]
+    proto = arg_def.handle_data.add()
+    proto.dtype = shape_and_type.dtype
+    proto.shape.CopyFrom(handle_data.shape_and_type[0].shape)
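The helper promoted into ops.py can also be exercised directly; a small hedged example (the graph setup and variable name are mine, not from the commit):

import tensorflow as tf
from tensorflow.python.framework import ops

# get_resource_handle_data expects a graph tensor, so build a small graph.
with tf.Graph().as_default():
  v = tf.Variable([[0., 0., 0.]], name="v")  # graph-mode resource variable
  # For a resource variable the returned CppShapeInferenceResult.HandleData
  # records the dtype and shape of the tensor stored behind the handle.
  print(ops.get_resource_handle_data(v.handle))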
@@ -19,20 +19,11 @@ from __future__ import division
 from __future__ import print_function
 
-from tensorflow.python.client import pywrap_tf_session
-from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.util import compat
 
 
-def get_resource_handle_data(graph_op):
-  assert type(graph_op) == ops.Tensor  # pylint: disable=unidiomatic-typecheck
-
-  handle_data = pywrap_tf_session.GetHandleShapeAndType(
-      graph_op.graph._c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access
-
-  return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
-      compat.as_bytes(handle_data))
+get_resource_handle_data = ops.get_resource_handle_data
 
 
 def copy_handle_data(source_t, target_t):
@@ -798,6 +798,39 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     self.assertIsNotNone(imported_gradient)
     self.assertAllClose(imported_gradient, 2.)
 
+  def test_nested_fn_backprop(self, cycles):
+    weight = variables.Variable(2., trainable=True)
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(dtype=dtypes.float32, shape=(None, None))])
+    def g(x):
+      weight.read_value()  # Just get the tape to watch the variable
+      handle = array_ops.identity(weight.handle)
+      @def_function.function
+      def launder_var_handle():
+        return array_ops.identity(handle)
+      return x + resource_variable_ops.read_variable_op(
+          launder_var_handle(), dtypes.float32)
+
+    root = tracking.AutoTrackable()
+    root.weight = weight
+    root.g = g
+    imported = cycle(root, cycles)
+    def get_gradient(obj, persistent):
+      with backprop.GradientTape(persistent=persistent) as t:
+        x = constant_op.constant([[1., 2., 3.], [1., -2, 3.]])
+        y = obj.g(x)
+        self.assertAllClose(y, obj.weight + x)
+        loss = math_ops.reduce_sum(y)
+        return t.gradient(loss, obj.weight)
+
+    imported_gradient = get_gradient(imported, persistent=False)
+    original_gradient = get_gradient(root, persistent=False)
+    self.assertIsNotNone(original_gradient)
+    self.assertAllClose(original_gradient, 6.)
+    self.assertIsNotNone(imported_gradient)
+    self.assertAllClose(imported_gradient, 6.)
+
   def test_restored_func_with_captured_var_backprop_float32(self, cycles):
     self._test_restored_func_with_captured_var_backprop(cycles, dtypes.float32)
 