Set handle data of function inputs and outputs.

Fixes a bug when taking gradients of nested functions.

PiperOrigin-RevId: 346241008
Change-Id: I50fc01bd5d971d272058b13b5964f8cdc28b00d0
Katherine Wu 2020-12-07 20:48:20 -08:00 committed by TensorFlower Gardener
parent 57b591e3ac
commit 8318ab26da
5 changed files with 93 additions and 21 deletions
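For context on the fix: "handle data" is the dtype-and-shape metadata that TensorFlow attaches to a DT_RESOURCE tensor to describe the tensor behind the handle. Before this change, that metadata was dropped when a function was round-tripped through a FunctionDef, so shape inference inside nested functions could not see through resource handles and gradient construction failed. A minimal sketch (not part of the commit; assumes a TF 2.x build that includes this change) of inspecting that metadata:

import tensorflow as tf
from tensorflow.python.framework import ops

v = tf.Variable([[1., 2.], [3., 4.]])

@tf.function
def read(handle):
  # ReadVariableOp relies on the handle's handle data to know the dtype
  # and shape of the variable it reads.
  return tf.raw_ops.ReadVariableOp(resource=handle, dtype=tf.float32)

graph = read.get_concrete_function(v.handle).graph
resource_placeholder = graph.inputs[0]  # the DT_RESOURCE placeholder
# get_resource_handle_data is the helper this commit moves into ops.py.
print(ops.get_resource_handle_data(resource_placeholder))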
tensorflow

View File

@@ -8,6 +8,7 @@ option java_package = "org.tensorflow.framework";
 option go_package = "github.com/tensorflow/tensorflow/tensorflow/go/core/framework/op_def_go_proto";
 import "tensorflow/core/framework/attr_value.proto";
 import "tensorflow/core/framework/types.proto";
+import "tensorflow/core/framework/resource_handle.proto";

 // Defines an operation. A NodeDef in a GraphDef specifies an Op by
 // using the "op" field which should match the name of a OpDef.
@@ -42,6 +43,9 @@ message OpDef {
     // type, type_attr, and number_attr may be specified.
     string type_list_attr = 6;

+    // The handle data for resource inputs.
+    repeated ResourceHandleProto.DtypeAndShape handle_data = 7;
+
     // For inputs: if true, the inputs are required to be refs.
     // By default, inputs can be either refs or non-refs.
     // For outputs: if true, outputs are refs, otherwise they are not.
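A small sketch (mine, not from the commit) of what the new ArgDef.handle_data field carries, using the generated op_def_pb2 bindings: the dtype and shape of the tensor behind a resource argument.

from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import types_pb2

arg = op_def_pb2.OpDef.ArgDef(name="var_handle", type=types_pb2.DT_RESOURCE)
dtype_and_shape = arg.handle_data.add()  # a ResourceHandleProto.DtypeAndShape
dtype_and_shape.dtype = types_pb2.DT_FLOAT
dtype_and_shape.shape.dim.add().size = 3  # the handle points at a float32[3] tensor
print(arg)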

View File

@@ -18,16 +18,21 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+import itertools
+
 from tensorflow.core.framework import function_pb2
 from tensorflow.core.framework import graph_pb2
 from tensorflow.core.framework import tensor_shape_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.framework import versions_pb2
 from tensorflow.python.eager import context
+from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import importer
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import versions
 from tensorflow.python.framework.func_graph import FuncGraph
+from tensorflow.python.ops import resource_variable_ops


 def function_def_to_graph(fdef, input_shapes=None):
@@ -84,6 +89,9 @@ def function_def_to_graph(fdef, input_shapes=None):
        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
        for ret_name in fdef.signature.control_output
    ]
+
+    _set_handle_data(func_graph, fdef)
+
    for node in graph_def.node:
      output_shapes = node.attr.get("_output_shapes", None)
      if output_shapes is not None:
@@ -264,3 +272,19 @@ def _get_num_args(arg_def, node_def):
     return 1
   else:
     raise ValueError("Invalid arg_def:\n\n{}".format(str(arg_def)))
+
+
+def _set_handle_data(func_graph, fdef):
+  """Adds handle data for resource type inputs and outputs."""
+  for tensor, arg_def in itertools.chain(
+      zip(func_graph.inputs, fdef.signature.input_arg),
+      zip(func_graph.outputs, fdef.signature.output_arg)):
+    if arg_def.handle_data:
+      shape_and_dtype = arg_def.handle_data[0]
+      handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
+      handle_data.is_set = True
+      handle_data.shape_and_type.append(
+          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
+              shape=shape_and_dtype.shape, dtype=shape_and_dtype.dtype))
+      resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
+          tensor, handle_data, True)
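_set_handle_data is the deserialization half of the change: when a FunctionDef is turned back into a FuncGraph, any handle_data recorded on its resource args is reattached to the corresponding placeholder tensors. A hypothetical round-trip sketch (fdef is assumed to be an existing FunctionDef whose resource args carry handle_data):

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function_def_to_graph as fdg
from tensorflow.python.framework import ops

func_graph = fdg.function_def_to_graph(fdef)
for tensor in func_graph.inputs:
  if tensor.dtype == dtypes.resource:
    # Populated from arg_def.handle_data by _set_handle_data above.
    print(ops.get_resource_handle_data(tensor))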

View File

@@ -50,6 +50,7 @@ from tensorflow.python.eager import monitoring
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import c_api_util
 from tensorflow.python.framework import composite_tensor
+from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import device as pydev
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import errors
@@ -3292,18 +3293,18 @@ class Graph(object):
            continue
          # TODO(b/141471245): Fix the inconsistency when inputs of func graph
          # are appended during gradient computation of while/cond.
-          for input_tensor, _ in zip(func_graph_inputs,
-                                     function_def.signature.input_arg):
+          for input_tensor, arg_def in zip(func_graph_inputs,
+                                           function_def.signature.input_arg):
+            input_shapes.list.shape.add().CopyFrom(
+                input_tensor.get_shape().as_proto())
            if input_tensor.dtype == dtypes.resource:
-              # TODO(allenl): Save and restore handle data, then save the
-              # resource placeholder's shape. Right now some shape functions get
-              # confused if we set the shape of the resource placeholder (to a
-              # scalar of course) and there isn't any handle data.
-              input_shapes.list.shape.add().CopyFrom(
-                  tensor_shape.TensorShape(None).as_proto())
-            else:
-              input_shapes.list.shape.add().CopyFrom(
-                  input_tensor.get_shape().as_proto())
+              _copy_handle_data_to_arg_def(input_tensor, arg_def)
+
+          for output_tensor, arg_def in zip(func_graph.outputs,
+                                            function_def.signature.output_arg):
+            if output_tensor.dtype == dtypes.resource:
+              _copy_handle_data_to_arg_def(output_tensor, arg_def)
+
          for node in function_def.node_def:
            try:
              op = func_graph.get_operation_by_name(node.name)
@@ -6979,3 +6980,22 @@ def _get_enclosing_context(graph):

   if graph.building_function and hasattr(graph, "outer_graph"):
     return _get_enclosing_context(graph.outer_graph)
+
+
+def get_resource_handle_data(graph_op):
+  assert type(graph_op) == Tensor  # pylint: disable=unidiomatic-typecheck
+
+  handle_data = pywrap_tf_session.GetHandleShapeAndType(
+      graph_op.graph._c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access
+
+  return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
+      compat.as_bytes(handle_data))
+
+
+def _copy_handle_data_to_arg_def(tensor, arg_def):
+  handle_data = get_resource_handle_data(tensor)
+  if handle_data.shape_and_type:
+    shape_and_type = handle_data.shape_and_type[0]
+    proto = arg_def.handle_data.add()
+    proto.dtype = shape_and_type.dtype
+    proto.shape.CopyFrom(handle_data.shape_and_type[0].shape)
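These two helpers are the serialization half: get_resource_handle_data reads the metadata that the C++ shape-inference layer keeps for a graph tensor, and _copy_handle_data_to_arg_def writes it into the FunctionDef signature. A hedged sketch of the private helper in isolation (resource_tensor is an assumed DT_RESOURCE graph tensor, e.g. the placeholder from the first sketch above):

from tensorflow.core.framework import op_def_pb2
from tensorflow.python.framework import ops

arg_def = op_def_pb2.OpDef.ArgDef(name="w")
ops._copy_handle_data_to_arg_def(resource_tensor, arg_def)  # pylint: disable=protected-access
print(arg_def.handle_data)  # now mirrors the tensor's handle data, if any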

View File

@@ -19,20 +19,11 @@ from __future__ import division
 from __future__ import print_function

 from tensorflow.python.client import pywrap_tf_session
 from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
 from tensorflow.python.util import compat

-
-def get_resource_handle_data(graph_op):
-  assert type(graph_op) == ops.Tensor  # pylint: disable=unidiomatic-typecheck
-
-  handle_data = pywrap_tf_session.GetHandleShapeAndType(
-      graph_op.graph._c_graph, graph_op._as_tf_output())  # pylint: disable=protected-access
-
-  return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
-      compat.as_bytes(handle_data))
-
+get_resource_handle_data = ops.get_resource_handle_data

 def copy_handle_data(source_t, target_t):
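The helper now lives in ops.py, presumably so the new serialization code there can use it without a circular import; the alias keeps the old import path working. A quick check, assuming this file is tensorflow/python/ops/custom_gradient.py (the page above does not show the path):

from tensorflow.python.framework import ops
from tensorflow.python.ops import custom_gradient

# Both names resolve to the same function after this change.
assert custom_gradient.get_resource_handle_data is ops.get_resource_handle_data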

View File

@@ -798,6 +798,39 @@ class LoadTest(test.TestCase, parameterized.TestCase):
     self.assertIsNotNone(imported_gradient)
     self.assertAllClose(imported_gradient, 2.)

+  def test_nested_fn_backprop(self, cycles):
+    weight = variables.Variable(2., trainable=True)
+
+    @def_function.function(input_signature=[
+        tensor_spec.TensorSpec(dtype=dtypes.float32, shape=(None, None))])
+    def g(x):
+      weight.read_value()  # Just get the tape to watch the variable.
+      handle = array_ops.identity(weight.handle)
+      @def_function.function
+      def launder_var_handle():
+        return array_ops.identity(handle)
+      return x + resource_variable_ops.read_variable_op(
+          launder_var_handle(), dtypes.float32)
+
+    root = tracking.AutoTrackable()
+    root.weight = weight
+    root.g = g
+    imported = cycle(root, cycles)
+
+    def get_gradient(obj, persistent):
+      with backprop.GradientTape(persistent=persistent) as t:
+        x = constant_op.constant([[1., 2., 3.], [1., -2, 3.]])
+        y = obj.g(x)
+        self.assertAllClose(y, obj.weight + x)
+        loss = math_ops.reduce_sum(y)
+        return t.gradient(loss, obj.weight)
+
+    imported_gradient = get_gradient(imported, persistent=False)
+    original_gradient = get_gradient(root, persistent=False)
+    self.assertIsNotNone(original_gradient)
+    self.assertAllClose(original_gradient, 6.)
+    self.assertIsNotNone(imported_gradient)
+    self.assertAllClose(imported_gradient, 6.)
+
   def test_restored_func_with_captured_var_backprop_float32(self, cycles):
     self._test_restored_func_with_captured_var_backprop(cycles, dtypes.float32)