diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b071fc5db58..72c77407bc8 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -3072,6 +3072,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/executor.h", "common_runtime/executor_factory.h", "common_runtime/graph_optimizer.h", + "common_runtime/isolate_placer_inspection_required_ops_pass.h", "common_runtime/local_device.h", "common_runtime/lower_function_call_op.h", "common_runtime/lower_if_op.h", @@ -3133,6 +3134,7 @@ tf_cuda_library( "common_runtime/graph_optimizer.cc", "common_runtime/graph_runner.cc", "common_runtime/hierarchical_tree_broadcaster.cc", + "common_runtime/isolate_placer_inspection_required_ops_pass.cc", "common_runtime/local_device.cc", "common_runtime/lower_function_call_op.cc", "common_runtime/lower_functional_ops.cc", @@ -3910,6 +3912,7 @@ tf_cc_tests( "common_runtime/collective_rma_local_test.cc", "common_runtime/device_resolver_local_test.cc", "common_runtime/device_set_test.cc", + "common_runtime/isolate_placer_inspection_required_ops_pass_test.cc", "common_runtime/optimization_registry_test.cc", "common_runtime/pending_counts_test.cc", "common_runtime/placer_inspection_required_ops_utils_test.cc", diff --git a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.cc b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.cc new file mode 100644 index 00000000000..bc1915326d7 --- /dev/null +++ b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.cc @@ -0,0 +1,54 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.h" + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/dump_graph.h" + +namespace tensorflow { + +Status IsolatePlacerInspectionRequiredOpsPass::Run( + const GraphOptimizationPassOptions& options) { + if (options.graph == nullptr) { + VLOG(1) << "Not running IsolatePlacerInspectionRequiredOpsPass because no " + "graph is provided"; + return Status::OK(); + } + + VLOG(1) << "IsolatePlacerInspectionRequiredOpsPass::Run"; + + Graph* graph = options.graph->get(); + if (VLOG_IS_ON(3)) { + DumpGraphToFile("isolate_deep_ops_before", *graph, nullptr, "/tmp"); + } + + const FunctionLibraryDefinition* flib_def = + options.flib_def == nullptr ? &graph->flib_def() : options.flib_def; + Status status = IsolatePlacerInspectionRequiredOps(*flib_def, graph); + + if (VLOG_IS_ON(3) && status.ok()) { + DumpGraphToFile("isolate_deep_ops_after", *graph, nullptr, "/tmp"); + } + return status; +} + +REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 25, + IsolatePlacerInspectionRequiredOpsPass); + +} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.h b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.h new file mode 100644 index 00000000000..3d86c4538f0 --- /dev/null +++ b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.h @@ -0,0 +1,63 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_ISOLATE_PLACER_INSPECTION_REQUIRED_OPS_PASS_H_ +#define TENSORFLOW_CORE_COMMON_RUNTIME_ISOLATE_PLACER_INSPECTION_REQUIRED_OPS_PASS_H_ + +#include "tensorflow/core/common_runtime/optimization_registry.h" + +namespace tensorflow { +// Adds Identities for each input/output of function-calling ops. +// +// For example, the following graph calling a function on inputs `a` and `b` +// and producing output `y` will be rewritted to include identities on all +// edges: +// +// a b +// | | +// v v +// f (PartitionedCallOp) +// | +// v +// y +// +// is transformed to +// +// a b +// | | +// a_f (Identity) a_f (Identity) +// | | +// v v +// f (PartitionedCallOp) +// | +// f_y (Identity) +// | +// v +// y +// +// This pass is currently needed to simplify correctly placing the nodes +// producing inputs for as well as consuming output from function-calling ops. +// +// This pass should also help to implement replacing PartitionedCallOp with +// component function calls (to avoid copying input/output tensors), if we get +// to it. +class IsolatePlacerInspectionRequiredOpsPass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_COMMON_RUNTIME_ISOLATE_PLACER_INSPECTION_REQUIRED_OPS_PASS_H_ diff --git a/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass_test.cc b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass_test.cc new file mode 100644 index 00000000000..6fb01c3b28a --- /dev/null +++ b/tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass_test.cc @@ -0,0 +1,437 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/common_runtime/isolate_placer_inspection_required_ops_pass.h" + +#include +#include + +#include "absl/memory/memory.h" +#include "absl/strings/str_join.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/util/equal_graph_def.h" + +namespace tensorflow { + +using ::tensorflow::test::function::GDef; +using ::tensorflow::test::function::NDef; +using FDH = ::tensorflow::FunctionDefHelper; + +// Returns void so that we can call TF_ASSERT_OK inside it. +void RunPass(const GraphDef& original, GraphDef* rewritten, + FunctionLibraryDefinition* flib_def = nullptr) { + std::unique_ptr graph = absl::make_unique(OpRegistry::Global()); + GraphConstructorOptions opts; + TF_ASSERT_OK(ConvertGraphDefToGraph(opts, original, graph.get())); + GraphOptimizationPassOptions options; + options.graph = &graph; + options.flib_def = flib_def; + IsolatePlacerInspectionRequiredOpsPass pass; + TF_ASSERT_OK(pass.Run(options)); + graph->ToGraphDef(rewritten); +} + +void RunPassAndCompare(const GraphDef& original, const GraphDef& expected) { + GraphDef rewritten; + RunPass(original, &rewritten); + TF_EXPECT_GRAPH_EQ(expected, rewritten); +} + +void RunPassAndCompare(const GraphDef& original, + const std::vector& expected_alternatives) { + GraphDef rewritten; + RunPass(original, &rewritten); + + std::vector errors; + errors.push_back(absl::StrCat("Graphs did not match.\n Rewritten graph:\n", + SummarizeGraphDef(rewritten))); + for (const GraphDef& alternative : expected_alternatives) { + string diff; + bool graphs_equal = EqualGraphDef(rewritten, alternative, &diff); + if (graphs_equal) { + return; + } + errors.push_back(absl::StrCat(" Expected alternative:\n", + SummarizeGraphDef(alternative))); + } + EXPECT_TRUE(false) << absl::StrJoin(errors, "\n"); +} + +TEST(IsolatePlacerInspectionRequiredOpsPassTest, Basic) { + /* + * x (_Arg, DT_RESOURCE) + * | + * v + * f (PartitionedCallOp: ResourceIdentity) + * | + * v + * y (_Retval, DT_RESOURCE) + */ + FunctionDef func = test::function::ResourceIdentity(); + GraphDef original = GDef( + { + NDef("x", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"x"}, + {{"Tin", DataTypeSlice{DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE}}, + {"f", FDH::FunctionRef("ResourceIdentity", {})}}), + NDef("y", "_Retval", {"f:0"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + GraphDef expected = GDef( + { + NDef("x", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("x_f", "Identity", {"x"}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"x_f"}, + {{"Tin", DataTypeSlice{DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE}}, + {"f", FDH::FunctionRef("ResourceIdentity", {})}}), + NDef("f_y", "Identity", {"f:0"}, {{"T", DT_RESOURCE}}), + NDef("y", "_Retval", {"f_y:0"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + RunPassAndCompare(original, expected); +} + +TEST(IsolatePlacerInspectionRequiredOpsPassTest, FunctionDefinitionNotInGraph) { + /* + * x (_Arg, DT_RESOURCE) + * | + * v + * f (PartitionedCallOp: ResourceIdentity) + * | + * v + * y (_Retval, DT_RESOURCE) + */ + FunctionDef func = test::function::ResourceIdentity(); + GraphDef original = GDef({ + NDef("x", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"x"}, + {{"Tin", DataTypeSlice{DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE}}, + {"f", FDH::FunctionRef("ResourceIdentity", {})}}), + NDef("y", "_Retval", {"f:0"}, {{"T", DT_RESOURCE}}), + }); + + GraphDef expected = GDef({ + NDef("x", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("x_f", "Identity", {"x"}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"x_f"}, + {{"Tin", DataTypeSlice{DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE}}, + {"f", FDH::FunctionRef("ResourceIdentity", {})}}), + NDef("f_y", "Identity", {"f:0"}, {{"T", DT_RESOURCE}}), + NDef("y", "_Retval", {"f_y:0"}, {{"T", DT_RESOURCE}}), + }); + + FunctionLibraryDefinition flib_def(OpRegistry::Global(), {}); + TF_ASSERT_OK(flib_def.AddFunctionDef(func)); + GraphDef rewritten; + RunPass(original, &rewritten, &flib_def); + TF_EXPECT_GRAPH_EQ(expected, rewritten); +} + +TEST(IsolatePlacerInspectionRequiredOpsPassTest, MultipleInputsAndOutputs) { + /* + * a (_Arg, DT_RESOURCE) + * | b (_Arg, DT_RESOURCE) + * | | + * v v + * f (PartitionedCallOp: Swap) + * | | + * | v + * v r2 (_Retval, DT_RESOURCE) + * r1 (_Retval, DT_RESOURCE) + */ + FunctionDef func = test::function::Swap(); + GraphDef original = GDef( + { + NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"a", "b"}, + {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}}), + NDef("r1", "_Retval", {"f:0"}, {{"T", DT_RESOURCE}}), + NDef("r2", "_Retval", {"f:1"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + GraphDef expected = GDef( + { + NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("a_f", "Identity", {"a"}, {{"T", DT_RESOURCE}}), + NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("b_f", "Identity", {"b"}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"a_f", "b_f"}, + {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}}), + NDef("f_r1", "Identity", {"f:0"}, {{"T", DT_RESOURCE}}), + NDef("r1", "_Retval", {"f_r1"}, {{"T", DT_RESOURCE}}), + NDef("f_r2", "Identity", {"f:1"}, {{"T", DT_RESOURCE}}), + NDef("r2", "_Retval", {"f_r2"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + RunPassAndCompare(original, expected); +} + +TEST(IsolatePlacerInspectionRequiredOpsPassTest, UnusedOutput) { + /* + * a (_Arg, DT_RESOURCE) + * | b (_Arg, DT_RESOURCE) + * | | + * v v + * f (PartitionedCallOp: Swap) + * | | + * | v + * v + * r1 (_Retval, DT_RESOURCE) + */ + FunctionDef func = test::function::Swap(); + GraphDef original = GDef( + { + NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"a", "b"}, + {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}}), + NDef("r1", "_Retval", {"f:0"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + GraphDef expected = GDef( + { + NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("a_f", "Identity", {"a"}, {{"T", DT_RESOURCE}}), + NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("b_f", "Identity", {"b"}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"a_f", "b_f"}, + {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}}), + NDef("f_r1", "Identity", {"f:0"}, {{"T", DT_RESOURCE}}), + NDef("r1", "_Retval", {"f_r1"}, {{"T", DT_RESOURCE}}), + // Identity is created for output that was not used. + NDef("f_0", "Identity", {"f:1"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + RunPassAndCompare(original, expected); +} + +TEST(IsolatePlacerInspectionRequiredOpsPassTest, OutputsConsumedBySameOp) { + /* + * a (_Arg, DT_RESOURCE) + * | b (_Arg, DT_RESOURCE) + * | | + * v v + * f (PartitionedCallOp: Swap) + * | | + * | | + * v v + * add (Add, DT_RESOURCE) + */ + FunctionDef func = test::function::Swap(); + GraphDef original = GDef( + { + NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"a", "b"}, + {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}}), + NDef("add", "Add", {"f:0", "f:1"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + // There are two possible namings for outputs depending on map + // iteration order. + GraphDef expected1 = GDef( + { + NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("a_f", "Identity", {"a"}, {{"T", DT_RESOURCE}}), + NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("b_f", "Identity", {"b"}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"a_f", "b_f"}, + {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}}), + NDef("f_add", "Identity", {"f:0"}, {{"T", DT_RESOURCE}}), + NDef("f_add_0", "Identity", {"f:1"}, {{"T", DT_RESOURCE}}), + NDef("add", "Add", {"f_add", "f_add_0"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + GraphDef expected2 = GDef( + { + // Same as above + NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("a_f", "Identity", {"a"}, {{"T", DT_RESOURCE}}), + NDef("b", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("b_f", "Identity", {"b"}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"a_f", "b_f"}, + {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}}), + // Different from above + NDef("f_add", "Identity", {"f:1"}, {{"T", DT_RESOURCE}}), + NDef("f_add_0", "Identity", {"f:0"}, {{"T", DT_RESOURCE}}), + NDef("add", "Add", {"f_add_0", "f_add"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + RunPassAndCompare(original, {expected1, expected2}); +} + +TEST(IsolatePlacerInspectionRequiredOpsPassTest, IdenticalInputs) { + /* + * a (_Arg, DT_RESOURCE) + * | | + * | | + * v v + * f (PartitionedCallOp: Swap) + * | | + * | v + * v r2 (_Retval, DT_RESOURCE) + * r1 (_Retval, DT_RESOURCE) + */ + FunctionDef func = test::function::Swap(); + GraphDef original = GDef( + { + NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"a", "a"}, + {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}}), + NDef("r1", "_Retval", {"f:0"}, {{"T", DT_RESOURCE}}), + NDef("r2", "_Retval", {"f:1"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + // There are two possible namings for outputs depending on map + // iteration order. + GraphDef expected1 = GDef( + { + NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("a_f", "Identity", {"a"}, {{"T", DT_RESOURCE}}), + NDef("a_f_0", "Identity", {"a"}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"a_f", "a_f_0"}, + {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}}), + NDef("f_r1", "Identity", {"f:0"}, {{"T", DT_RESOURCE}}), + NDef("r1", "_Retval", {"f_r1"}, {{"T", DT_RESOURCE}}), + NDef("f_r2", "Identity", {"f:1"}, {{"T", DT_RESOURCE}}), + NDef("r2", "_Retval", {"f_r2"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + GraphDef expected2 = GDef( + { + NDef("a", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("a_f", "Identity", {"a"}, {{"T", DT_RESOURCE}}), + NDef("a_f_0", "Identity", {"a"}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", + {"a_f_0", "a_f"}, // the only different line from above + {{"Tin", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_RESOURCE, DT_RESOURCE}}, + {"f", FDH::FunctionRef("Swap", {{"T", DT_RESOURCE}})}}), + NDef("f_r1", "Identity", {"f:0"}, {{"T", DT_RESOURCE}}), + NDef("r1", "_Retval", {"f_r1"}, {{"T", DT_RESOURCE}}), + NDef("f_r2", "Identity", {"f:1"}, {{"T", DT_RESOURCE}}), + NDef("r2", "_Retval", {"f_r2"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + RunPassAndCompare(original, {expected1, expected2}); +} + +TEST(IsolatePlacerInspectionRequiredOpsPassTest, DirectCallsAreNotIsolated) { + /* + * x (_Arg, DT_RESOURCE) + * | + * v + * f (direct function call to ResourceIdentity) + * | + * v + * y (_Retval, DT_RESOURCE) + */ + FunctionDef func = test::function::ResourceIdentity(); + GraphDef original = GDef( + { + NDef("x", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("f", "ResourceIdentity", {"x"}), + NDef("y", "_Retval", {"f:0"}, {{"T", DT_RESOURCE}}), + }, + // FunctionLib + {func}); + + RunPassAndCompare(original, original); +} + +TEST(IsolatePlacerInspectionRequiredOpsPassTest, + FunctionsNotReturningResourcesAreNotIsolated) { + /* + * x (_Arg, DT_RESOURCE) + * | + * v + * f (PartitionedCallOp, ReadResourceVariable) + * | + * v + * y (_Retval, DT_FLOAT) + */ + FunctionDef func = test::function::ReadResourceVariable(); + GraphDef original = GDef( + { + NDef("x", "_Arg", {}, {{"T", DT_RESOURCE}}), + NDef("f", "PartitionedCall", {"x"}, + {{"Tin", DataTypeSlice{DT_RESOURCE}}, + {"Tout", DataTypeSlice{DT_FLOAT}}, + {"f", FDH::FunctionRef("ReadResourceVariable", {})}}), + NDef("y", "_Retval", {"f:0"}, {{"T", DT_FLOAT}}), + }, + // FunctionLib + {func}); + + RunPassAndCompare(original, original); +} + +} // namespace tensorflow