Enable importing graph with tf.function nested in a while body inside a tf.function and computing its gradient.
PiperOrigin-RevId: 320691209 Change-Id: I77c88ec90bf8d1e242e31d49196e3abc3d8f9e9b
This commit is contained in:
parent
9b1ac5cd54
commit
f085449f2b
@ -20,6 +20,8 @@ from __future__ import print_function
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from google.protobuf import text_format
|
||||
from tensorflow.core.framework import graph_pb2
|
||||
from tensorflow.core.protobuf import config_pb2
|
||||
from tensorflow.core.protobuf import rewriter_config_pb2
|
||||
from tensorflow.python.eager import backprop
|
||||
@ -28,6 +30,7 @@ from tensorflow.python.eager import def_function
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import function
|
||||
from tensorflow.python.framework import importer
|
||||
from tensorflow.python.framework import meta_graph
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import test_util
|
||||
@ -1238,6 +1241,575 @@ class WhileV2Test(test.TestCase, parameterized.TestCase):
|
||||
config.experimental.executor_type = "SINGLE_THREADED_EXECUTOR"
|
||||
self._runBasicWithConfig(config)
|
||||
|
||||
def testImportFromSerializedWithFunctionInBody(self):
|
||||
serialized = """node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
}
|
||||
float_val: 1.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/maximum_iterations"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: -1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/loop_counter"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while"
|
||||
op: "StatelessWhile"
|
||||
input: "while/loop_counter"
|
||||
input: "while/maximum_iterations"
|
||||
input: "Const"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
list {
|
||||
type: DT_INT32
|
||||
type: DT_INT32
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_lower_using_switch_merge"
|
||||
value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_num_original_outputs"
|
||||
value {
|
||||
i: 3
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_read_only_resource_inputs"
|
||||
value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "body"
|
||||
value {
|
||||
func {
|
||||
name: "while_body_822"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "cond"
|
||||
value {
|
||||
func {
|
||||
name: "while_cond_821"
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
shape {
|
||||
}
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "parallel_iterations"
|
||||
value {
|
||||
i: 10
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/Identity"
|
||||
op: "Identity"
|
||||
input: "while"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/Identity_1"
|
||||
op: "Identity"
|
||||
input: "while:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "while/Identity_2"
|
||||
op: "Identity"
|
||||
input: "while:2"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "while_body_822"
|
||||
input_arg {
|
||||
name: "while_loop_counter"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "while_maximum_iterations_0"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "placeholder"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "add"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "while_maximum_iterations"
|
||||
type: DT_INT32
|
||||
}
|
||||
output_arg {
|
||||
name: "partitionedcall"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "PartitionedCall"
|
||||
op: "PartitionedCall"
|
||||
input: "placeholder"
|
||||
attr {
|
||||
key: "Tin"
|
||||
value {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "Tout"
|
||||
value {
|
||||
list {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_collective_manager_ids"
|
||||
value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_read_only_resource_inputs"
|
||||
value {
|
||||
list {
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "config"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "config_proto"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "executor_type"
|
||||
value {
|
||||
s: ""
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "f"
|
||||
value {
|
||||
func {
|
||||
name: "__inference_f_841"
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "PartitionedCall"
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "add/y"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_INT32
|
||||
tensor_shape {
|
||||
}
|
||||
int_val: 1
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "add/y"
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "add_0"
|
||||
op: "AddV2"
|
||||
input: "while_loop_counter"
|
||||
input: "add/y:output:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_INT32
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "add"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "add"
|
||||
value: "add_0:z:0"
|
||||
}
|
||||
ret {
|
||||
key: "partitionedcall"
|
||||
value: "PartitionedCall:output:0"
|
||||
}
|
||||
ret {
|
||||
key: "while_maximum_iterations"
|
||||
value: "while_maximum_iterations_0"
|
||||
}
|
||||
arg_attr {
|
||||
key: 0
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
key: 1
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
key: 2
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "while_cond_821"
|
||||
input_arg {
|
||||
name: "while_loop_counter"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "while_maximum_iterations"
|
||||
type: DT_INT32
|
||||
}
|
||||
input_arg {
|
||||
name: "placeholder"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "less"
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "Less/y"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
}
|
||||
float_val: 5.0
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "Less/y"
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "Less"
|
||||
op: "Less"
|
||||
input: "placeholder"
|
||||
input: "Less/y:output:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "Less"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "less"
|
||||
value: "Less:z:0"
|
||||
}
|
||||
arg_attr {
|
||||
key: 0
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
key: 1
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
arg_attr {
|
||||
key: 2
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "__inference_f_841"
|
||||
input_arg {
|
||||
name: "mul_placeholder"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
output_arg {
|
||||
name: "identity"
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "mul/y"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
}
|
||||
float_val: 2.0
|
||||
}
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "mul/y"
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "mul"
|
||||
op: "Mul"
|
||||
input: "mul_placeholder"
|
||||
input: "mul/y:output:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "mul"
|
||||
}
|
||||
}
|
||||
node_def {
|
||||
name: "Identity"
|
||||
op: "Identity"
|
||||
input: "mul:z:0"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
experimental_debug_info {
|
||||
original_node_names: "Identity"
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "identity"
|
||||
value: "Identity:output:0"
|
||||
}
|
||||
arg_attr {
|
||||
key: 0
|
||||
value {
|
||||
attr {
|
||||
key: "_output_shapes"
|
||||
value {
|
||||
list {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
versions {
|
||||
producer: 399
|
||||
min_consumer: 12
|
||||
}
|
||||
"""
|
||||
# Code for generating above graph:
|
||||
#
|
||||
# def Body(i):
|
||||
# @tf.function
|
||||
# def f():
|
||||
# return i * 2
|
||||
# return f()
|
||||
# tf.while_loop(lambda i: i < 5., Body, [tf.constant(1.)])
|
||||
graph_def = graph_pb2.GraphDef()
|
||||
text_format.Parse(serialized, graph_def)
|
||||
@def_function.function
|
||||
def F():
|
||||
x, y = importer.import_graph_def(
|
||||
graph_def, return_elements=["Const:0", "while:2"])
|
||||
grad_out, = gradients_impl.gradients(y, x)
|
||||
return grad_out
|
||||
self.assertAllEqual(F(), 8.0)
|
||||
|
||||
|
||||
def ScalarShape():
|
||||
return ops.convert_to_tensor([], dtype=dtypes.int32)
|
||||
|
@ -608,8 +608,22 @@ def _GradientsHelper(ys,
|
||||
except LookupError:
|
||||
if is_func_call:
|
||||
if is_partitioned_call:
|
||||
func_name = compat.as_bytes(op.get_attr("f").name)
|
||||
func_call = src_graph._get_function( # pylint: disable=protected-access
|
||||
compat.as_bytes(op.get_attr("f").name))
|
||||
func_name)
|
||||
# When a graph is imported, the FunctionDefs are not copied over
|
||||
# to each sub-graph so we recursively search the outer graphs
|
||||
# for the FunctionDef.
|
||||
if not func_call and hasattr(src_graph, "outer_graph"):
|
||||
graph = src_graph.outer_graph
|
||||
while graph is not None:
|
||||
func_call = graph._get_function(func_name) # pylint: disable=protected-access
|
||||
if func_call is not None:
|
||||
break
|
||||
if hasattr(graph, "outer_graph"):
|
||||
graph = graph.outer_graph
|
||||
else:
|
||||
break
|
||||
else:
|
||||
func_call = src_graph._get_function(op.type) # pylint: disable=protected-access
|
||||
# Note that __defun is not set if the graph is
|
||||
|
Loading…
x
Reference in New Issue
Block a user