diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index ced0cd03f74..a635608596c 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -359,6 +359,52 @@ tf_cc_test( ], ) +cc_library( + name = "shape_inference", + srcs = ["shape_inference.cc"], + hdrs = ["shape_inference.h"], + deps = [ + ":shape_inference_helpers", + "//tensorflow/compiler/tf2xla:dump_graph", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + deps = [ + ":shape_inference", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "shape_inference_test", + srcs = ["shape_inference_test.cc"], + deps = [ + ":shape_inference", + ":test_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:cc_ops_internal", + "//tensorflow/cc:ops", + "//tensorflow/core:framework", + "//tensorflow/core:ops", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/kernels:constant_op", + ], +) + cc_library( name = "compilation_passes", srcs = [ diff --git a/tensorflow/compiler/jit/shape_inference.cc b/tensorflow/compiler/jit/shape_inference.cc new file mode 100644 index 00000000000..80c691fe490 --- /dev/null +++ b/tensorflow/compiler/jit/shape_inference.cc @@ -0,0 +1,174 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/shape_inference.h"
+
+#include "tensorflow/compiler/jit/shape_inference_helpers.h"
+#include "tensorflow/compiler/tf2xla/dump_graph.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/graph/algorithm.h"
+
+namespace tensorflow {
+
+namespace {
+
+// Converts a shape inference handle to a PartialTensorShape.
+Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
+                                const shape_inference::ShapeHandle& handle,
+                                PartialTensorShape* shape) {
+  // The default is already unknown
+  if (!context->RankKnown(handle)) return Status::OK();
+
+  std::vector<int64> dims(context->Rank(handle));
+  for (int32 i = 0; i < dims.size(); ++i) {
+    dims[i] = context->Value(context->Dim(handle, i));
+  }
+  return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
+}
+
+Status PropagateShapes(const Graph& graph,
+                       const std::map<int, InferredShape>& arg_shapes,
+                       ShapeRefiner* shape_refiner) {
+  // Visits the nodes in topological order (reverse post-order), inferring
+  // shapes.
+  // TODO(phawkins): handle cyclic graphs.
+  std::vector<Node*> order;
+  GetReversePostOrder(graph, &order);
+
+  for (Node* n : order) {
+    // Ignore the status returned by the shape_refiner. We want the best effort
+    // shapes, even if no shape function is registered for a node.
+    Status status = shape_refiner->AddNode(n);
+    if (!status.ok()) {
+      VLOG(1) << "Shape inference failed for node: " << status;
+    }
+
+    if (n->type_string() == "_Arg") {
+      int index;
+      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+      auto it = arg_shapes.find(index);
+      if (it != arg_shapes.end()) {
+        const InferredShape& arg_shape = it->second;
+        shape_inference::InferenceContext* context =
+            shape_refiner->GetContext(n);
+
+        if (arg_shape.handle_type != DT_INVALID) {
+          shape_inference::ShapeHandle handle;
+          TF_RETURN_IF_ERROR(context->MakeShapeFromPartialTensorShape(
+              arg_shape.handle_shape, &handle));
+
+          // Sets the shape and type of the variable's value.
+          context->set_output_handle_shapes_and_types(
+              0, std::vector<shape_inference::ShapeAndType>{
+                     {handle, arg_shape.handle_type}});
+        }
+
+        shape_inference::ShapeHandle handle;
+        TF_RETURN_IF_ERROR(
+            context->MakeShapeFromPartialTensorShape(arg_shape.shape, &handle));
+        TF_RETURN_IF_ERROR(shape_refiner->SetShape(n, 0, handle));
+      }
+    }
+  }
+  return Status::OK();
+}
+
+// Store the shapes of the output tensors in a map
+Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner,
+                         GraphShapeInfo* shape_info) {
+  for (const Node* node : graph.nodes()) {
+    shape_inference::InferenceContext* context = shape_refiner.GetContext(node);
+    if (!context) continue;
+
+    auto& outputs = (*shape_info)[node->name()];
+    outputs.resize(context->num_outputs());
+    for (int i = 0; i < context->num_outputs(); ++i) {
+      auto& output = outputs[i];
+      TF_RETURN_IF_ERROR(
+          ShapeHandleToTensorShape(context, context->output(i), &output.shape));
+
+      const auto* handle_shapes_and_types =
+          context->output_handle_shapes_and_types(i);
+      if (handle_shapes_and_types != nullptr) {
+        if (handle_shapes_and_types->size() == 1) {
+          TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(
+              context, (*handle_shapes_and_types)[0].shape,
+              &output.handle_shape));
+          output.handle_type = (*handle_shapes_and_types)[0].dtype;
+        } else {
+          // otherwise, it may be resource like a Queue, which can have
+          // multiple shapes and types represented by a single handle.
+        }
+      }
+      VLOG(4) << node->name() << " output " << i << " shape"
+              << output.shape.DebugString() << " handle_type "
+              << DataTypeString(output.handle_type) << " handle_shape "
+              << output.handle_shape.DebugString();
+    }
+  }
+  return Status::OK();
+}
+
+}  // namespace
+
+Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
+                   const tensorflow::FunctionLibraryDefinition* fnlib_def,
+                   GraphShapeInfo* shape_info) {
+  ShapeRefiner shape_refiner(graph->versions(), graph->op_registry());
+  shape_refiner.set_require_shape_inference_fns(false);
+  // TODO(dlibenzi): Verify if it is worth trying to infer shaped within
+  // functions. Some functions can be called at multiple locations with
+  // difference shapes, which will trigger a shape inference based on the
+  // arguments passed at the first call.
+  // shape_refiner.set_function_library_for_shape_inference(fnlib_def);
+
+  // ShapeRefiner requires that all inputs of a node are present when
+  // ShapeRefiner::AddNode is called. To get at least some shape information in
+  // loops, we temporarily remove loop backedges and add them back again after
+  // the shape inference is complete.
+  BackEdgeHelper back_edge;
+  TF_RETURN_IF_ERROR(back_edge.Remove(graph));
+  TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes, &shape_refiner));
+  TF_RETURN_IF_ERROR(back_edge.Replace());
+
+  // Currently information does not flow "backward" from consumers to producers
+  // in the shape inference, but we consume the shapes in a second pass in case
+  // backward information flow is added in the future.
+  return StoreOutputShapes(*graph, shape_refiner, shape_info);
+}
+
+xla::StatusOr<InferredShape> MergeInferredShapes(const InferredShape& a,
+                                                 const InferredShape& b) {
+  InferredShape result;
+  TF_RETURN_IF_ERROR(a.shape.MergeWith(b.shape, &result.shape));
+
+  if (a.handle_type == DT_INVALID) {
+    result.handle_type = b.handle_type;
+  } else if (b.handle_type == DT_INVALID) {
+    result.handle_type = a.handle_type;
+  } else if (a.handle_type == b.handle_type) {
+    result.handle_type = a.handle_type;
+  } else {
+    return errors::InvalidArgument(
+        "Mismatched resource types: ", DataTypeString(a.handle_type), " vs. ",
+        DataTypeString(b.handle_type));
+  }
+  TF_RETURN_IF_ERROR(
+      a.handle_shape.MergeWith(b.handle_shape, &result.handle_shape));
+  return result;
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/shape_inference.h b/tensorflow/compiler/jit/shape_inference.h
new file mode 100644
index 00000000000..8668dbca55c
--- /dev/null
+++ b/tensorflow/compiler/jit/shape_inference.h
@@ -0,0 +1,54 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_
+#define TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_
+
+#include <map>
+#include <vector>
+
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+struct InferredShape {
+  // Shape of the argument tensor.
+  PartialTensorShape shape;
+
+  // If the argument is a resource variable, the type and shape of the
+  // variable's value.
+  DataType handle_type = DT_INVALID;
+  PartialTensorShape handle_shape;
+};
+typedef std::unordered_map<string, std::vector<InferredShape>> GraphShapeInfo;
+
+// Infer shapes for all Tensors in a graph, and save them in a map. The vector
+// for a Node contains the information about each of its outputs.
+// TODO(phawkins): this code does not infer accurate shapes for cyclic graphs.
+Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
+                   const tensorflow::FunctionLibraryDefinition* fnlib_def,
+                   GraphShapeInfo* shape_info);
+
+// Merges two InferredShapes. Return an error if the two shapes cannot be
+// merged.
+xla::StatusOr<InferredShape> MergeInferredShapes(const InferredShape& a,
+                                                 const InferredShape& b);
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_
diff --git a/tensorflow/compiler/jit/shape_inference_test.cc b/tensorflow/compiler/jit/shape_inference_test.cc
new file mode 100644
index 00000000000..9268172b1c4
--- /dev/null
+++ b/tensorflow/compiler/jit/shape_inference_test.cc
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tests for ShapeInference.
+
+#include "tensorflow/compiler/jit/shape_inference.h"
+
+#include <map>
+#include <vector>
+
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/jit/test_util.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+TEST(ShapeInferenceTest, Basics) {
+  Scope root = Scope::NewRootScope().ExitOnError();
+  auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT,
+                            ops::Placeholder::Shape({2, 3}));
+  auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT,
+                            ops::Placeholder::Shape({3}));
+  auto c = ops::Placeholder(root.WithOpName("C"), DT_FLOAT);
+  auto d = ops::Add(root.WithOpName("D"), a, b);
+  auto e = ops::Add(root.WithOpName("E"), d, c);
+  auto f = ops::Neg(root.WithOpName("F"), e);
+  auto g = ops::AddN(root.WithOpName("G"), std::initializer_list<Input>{e, f});
+
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_CHECK_OK(root.ToGraph(graph.get()));
+
+  GraphShapeInfo shape_info;
+  TF_ASSERT_OK(InferShapes(graph.get(), /*arg_shapes=*/{},
+                           /*fnlib_def=*/nullptr, &shape_info));
+
+  std::map<string, std::vector<PartialTensorShape>> expected = {
+      {"A", {PartialTensorShape({2, 3})}}, {"B", {PartialTensorShape({3})}},
+      {"C", {PartialTensorShape()}},       {"D", {PartialTensorShape({2, 3})}},
+      {"E", {PartialTensorShape()}},       {"F", {PartialTensorShape()}},
+      {"G", {PartialTensorShape()}},
+  };
+  TF_EXPECT_OK(ShapeAnnotationsMatch(*graph, shape_info, expected));
+}
+
+TEST(ShapeInferenceTest, WhileLoop) {
+  // Graph:
+  // x = array_ops.placeholder(dtypes.int32)
+  // y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
+  Graph graph(OpRegistry::Global());
+  {
+    Scope scope = Scope::NewRootScope().ExitOnError();
+
+    auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32,
+                                  ops::Placeholder::Shape({}));
+
+    auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32,
+                                   ops::Placeholder::Shape({}));
+    auto enter =
+        ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
+    // Add an unused Enter node. These should be ignored.
+    auto enter2 =
+        ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop");
+    auto merge = ops::Merge(scope.WithOpName("while/Merge"),
+                            std::initializer_list<Input>{enter, dummy});
+    auto ten = ops::Const<int32>(
+        scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
+        10);
+    auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
+    auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
+    auto switch_node =
+        ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
+    auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
+                                    switch_node.output_false);
+    auto identity = ops::Identity(scope.WithOpName("while/Identity"),
+                                  switch_node.output_true);
+    auto identity_shape =
+        ops::Const<int32>(scope.WithOpName("while/Identity/shape"), {});
+    auto identity_reshaped = ops::Reshape(
+        scope.WithOpName("while/Identity/reshaped"), identity, identity_shape);
+
+    auto one = ops::Const<int32>(
+        scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
+    auto add = ops::Add(scope.WithOpName("while/add"), identity_reshaped, one);
+    auto next_iteration =
+        ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
+
+    auto sink = ops::Identity(scope.WithOpName("sink"), exit);
+
+    // Remove the dummy node and add the loop backedge.
+    scope.graph()->RemoveNode(dummy.node());
+    scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
+
+    TF_EXPECT_OK(scope.ToGraph(&graph));
+  }
+
+  GraphShapeInfo shape_info;
+  TF_ASSERT_OK(InferShapes(&graph, /*arg_shapes=*/{}, /*fnlib_def=*/nullptr,
+                           &shape_info));
+  std::map<string, std::vector<PartialTensorShape>> expected = {
+      {"while/Identity", {PartialTensorShape()}},
+      {"while/add", {PartialTensorShape({})}},
+  };
+  TF_EXPECT_OK(ShapeAnnotationsMatch(graph, shape_info, expected));
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/test_util.cc b/tensorflow/compiler/jit/test_util.cc
new file mode 100644
index 00000000000..cada272090a
--- /dev/null
+++ b/tensorflow/compiler/jit/test_util.cc
@@ -0,0 +1,57 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/jit/test_util.h"
+
+#include "tensorflow/compiler/jit/shape_inference.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+namespace tensorflow {
+
+Status ShapeAnnotationsMatch(
+    const Graph& graph, const GraphShapeInfo& shape_info,
+    std::map<string, std::vector<PartialTensorShape>> expected_shapes) {
+  for (Node* node : graph.op_nodes()) {
+    auto sit = shape_info.find(node->name());
+    TF_RET_CHECK(sit != shape_info.end())
+        << "Missing shape information for node " << node->name();
+    std::vector<PartialTensorShape> shapes;
+    for (const auto& output : sit->second) shapes.push_back(output.shape);
+
+    auto it = expected_shapes.find(node->name());
+    if (it != expected_shapes.end()) {
+      if (!PartialTensorShapeUtils::AreIdentical(shapes, it->second)) {
+        return errors::InvalidArgument(
+            "Shape mismatch for ", node->name(), ". Expected: ",
+            PartialTensorShapeUtils::PartialShapeListString(it->second),
+            ", actual: ",
+            PartialTensorShapeUtils::PartialShapeListString(shapes));
+      }
+      expected_shapes.erase(it);
+    }
+  }
+  if (!expected_shapes.empty()) {
+    std::vector<string> missing;
+    missing.reserve(expected_shapes.size());
+    for (const auto& entry : expected_shapes) {
+      missing.push_back(entry.first);
+    }
+    return errors::InvalidArgument("Missing shapes for nodes: ",
+                                   str_util::Join(missing, ","));
+  }
+  return Status::OK();
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/jit/test_util.h b/tensorflow/compiler/jit/test_util.h
new file mode 100644
index 00000000000..0c9fee8f244
--- /dev/null
+++ b/tensorflow/compiler/jit/test_util.h
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helper functions for tests.
+
+#ifndef TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_
+#define TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_
+
+#include <map>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/compiler/jit/shape_inference.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Tests that the shapes in 'shape_info' for the nodes in `graph` match
+// `expected_shapes`. Returns an error if there are nodes in `expected_shapes`
+// that do not have shape information. Ignores nodes in `graph` that do not have
+// `expected_shapes` entries.
+Status ShapeAnnotationsMatch(
+    const Graph& graph, const GraphShapeInfo& shape_info,
+    std::map<string, std::vector<PartialTensorShape>> expected_shapes);
+
+}  // namespace tensorflow
+
+
+#endif  // TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_