Enable importing graph with tf.function nested in a while body inside a tf.function and computing its gradient.

PiperOrigin-RevId: 320691209
Change-Id: I77c88ec90bf8d1e242e31d49196e3abc3d8f9e9b
This commit is contained in:
Saurabh Saxena 2020-07-10 15:48:36 -07:00 committed by TensorFlower Gardener
parent 9b1ac5cd54
commit f085449f2b
2 changed files with 587 additions and 1 deletions

View File

@ -20,6 +20,8 @@ from __future__ import print_function
from absl.testing import parameterized
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import backprop
@ -28,6 +30,7 @@ from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import importer
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@ -1238,6 +1241,575 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
self._runBasicWithConfig(config)
def testImportFromSerializedWithFunctionInBody(self):
serialized = """node {
name: "Const"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 1.0
}
}
}
}
node {
name: "while/maximum_iterations"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: -1
}
}
}
}
node {
name: "while/loop_counter"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 0
}
}
}
}
node {
name: "while"
op: "StatelessWhile"
input: "while/loop_counter"
input: "while/maximum_iterations"
input: "Const"
attr {
key: "T"
value {
list {
type: DT_INT32
type: DT_INT32
type: DT_FLOAT
}
}
}
attr {
key: "_lower_using_switch_merge"
value {
b: true
}
}
attr {
key: "_num_original_outputs"
value {
i: 3
}
}
attr {
key: "_read_only_resource_inputs"
value {
list {
}
}
}
attr {
key: "body"
value {
func {
name: "while_body_822"
}
}
}
attr {
key: "cond"
value {
func {
name: "while_cond_821"
}
}
}
attr {
key: "output_shapes"
value {
list {
shape {
}
shape {
}
shape {
}
}
}
}
attr {
key: "parallel_iterations"
value {
i: 10
}
}
}
node {
name: "while/Identity"
op: "Identity"
input: "while"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Identity_1"
op: "Identity"
input: "while:1"
attr {
key: "T"
value {
type: DT_INT32
}
}
}
node {
name: "while/Identity_2"
op: "Identity"
input: "while:2"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
library {
function {
signature {
name: "while_body_822"
input_arg {
name: "while_loop_counter"
type: DT_INT32
}
input_arg {
name: "while_maximum_iterations_0"
type: DT_INT32
}
input_arg {
name: "placeholder"
type: DT_FLOAT
}
output_arg {
name: "add"
type: DT_INT32
}
output_arg {
name: "while_maximum_iterations"
type: DT_INT32
}
output_arg {
name: "partitionedcall"
type: DT_FLOAT
}
}
node_def {
name: "PartitionedCall"
op: "PartitionedCall"
input: "placeholder"
attr {
key: "Tin"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "Tout"
value {
list {
type: DT_FLOAT
}
}
}
attr {
key: "_collective_manager_ids"
value {
list {
}
}
}
attr {
key: "_read_only_resource_inputs"
value {
list {
}
}
}
attr {
key: "config"
value {
s: ""
}
}
attr {
key: "config_proto"
value {
s: ""
}
}
attr {
key: "executor_type"
value {
s: ""
}
}
attr {
key: "f"
value {
func {
name: "__inference_f_841"
}
}
}
experimental_debug_info {
original_node_names: "PartitionedCall"
}
}
node_def {
name: "add/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_INT32
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_INT32
tensor_shape {
}
int_val: 1
}
}
}
experimental_debug_info {
original_node_names: "add/y"
}
}
node_def {
name: "add_0"
op: "AddV2"
input: "while_loop_counter"
input: "add/y:output:0"
attr {
key: "T"
value {
type: DT_INT32
}
}
experimental_debug_info {
original_node_names: "add"
}
}
ret {
key: "add"
value: "add_0:z:0"
}
ret {
key: "partitionedcall"
value: "PartitionedCall:output:0"
}
ret {
key: "while_maximum_iterations"
value: "while_maximum_iterations_0"
}
arg_attr {
key: 0
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
arg_attr {
key: 1
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
arg_attr {
key: 2
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
}
function {
signature {
name: "while_cond_821"
input_arg {
name: "while_loop_counter"
type: DT_INT32
}
input_arg {
name: "while_maximum_iterations"
type: DT_INT32
}
input_arg {
name: "placeholder"
type: DT_FLOAT
}
output_arg {
name: "less"
type: DT_BOOL
}
}
node_def {
name: "Less/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 5.0
}
}
}
experimental_debug_info {
original_node_names: "Less/y"
}
}
node_def {
name: "Less"
op: "Less"
input: "placeholder"
input: "Less/y:output:0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "Less"
}
}
ret {
key: "less"
value: "Less:z:0"
}
arg_attr {
key: 0
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
arg_attr {
key: 1
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
arg_attr {
key: 2
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
}
function {
signature {
name: "__inference_f_841"
input_arg {
name: "mul_placeholder"
type: DT_FLOAT
}
output_arg {
name: "identity"
type: DT_FLOAT
}
}
node_def {
name: "mul/y"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
}
float_val: 2.0
}
}
}
experimental_debug_info {
original_node_names: "mul/y"
}
}
node_def {
name: "mul"
op: "Mul"
input: "mul_placeholder"
input: "mul/y:output:0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "mul"
}
}
node_def {
name: "Identity"
op: "Identity"
input: "mul:z:0"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
experimental_debug_info {
original_node_names: "Identity"
}
}
ret {
key: "identity"
value: "Identity:output:0"
}
arg_attr {
key: 0
value {
attr {
key: "_output_shapes"
value {
list {
shape {
}
}
}
}
}
}
}
}
versions {
producer: 399
min_consumer: 12
}
"""
# Code for generating above graph:
#
# def Body(i):
# @tf.function
# def f():
# return i * 2
# return f()
# tf.while_loop(lambda i: i < 5., Body, [tf.constant(1.)])
graph_def = graph_pb2.GraphDef()
text_format.Parse(serialized, graph_def)
@def_function.function
def F():
x, y = importer.import_graph_def(
graph_def, return_elements=["Const:0", "while:2"])
grad_out, = gradients_impl.gradients(y, x)
return grad_out
self.assertAllEqual(F(), 8.0)
def ScalarShape():
return ops.convert_to_tensor([], dtype=dtypes.int32)

View File

@ -608,8 +608,22 @@ def _GradientsHelper(ys,
except LookupError:
if is_func_call:
if is_partitioned_call:
func_name = compat.as_bytes(op.get_attr("f").name)
func_call = src_graph._get_function( # pylint: disable=protected-access
compat.as_bytes(op.get_attr("f").name))
func_name)
# When a graph is imported, the FunctionDefs are not copied over
# to each sub-graph so we recursively search the outer graphs
# for the FunctionDef.
if not func_call and hasattr(src_graph, "outer_graph"):
graph = src_graph.outer_graph
while graph is not None:
func_call = graph._get_function(func_name) # pylint: disable=protected-access
if func_call is not None:
break
if hasattr(graph, "outer_graph"):
graph = graph.outer_graph
else:
break
else:
func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
# Note that __defun is not set if the graph is