Move shape inference utility to open source.
PiperOrigin-RevId: 217259772
This commit is contained in:
parent
5473b48a76
commit
062439348e
@ -359,6 +359,52 @@ tf_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "shape_inference",
|
||||
srcs = ["shape_inference.cc"],
|
||||
hdrs = ["shape_inference.h"],
|
||||
deps = [
|
||||
":shape_inference_helpers",
|
||||
"//tensorflow/compiler/tf2xla:dump_graph",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:graph",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "test_util",
|
||||
testonly = 1,
|
||||
srcs = ["test_util.cc"],
|
||||
hdrs = ["test_util.h"],
|
||||
deps = [
|
||||
":shape_inference",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "shape_inference_test",
|
||||
srcs = ["shape_inference_test.cc"],
|
||||
deps = [
|
||||
":shape_inference",
|
||||
":test_util",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/cc:ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:ops",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/kernels:constant_op",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "compilation_passes",
|
||||
srcs = [
|
||||
|
174
tensorflow/compiler/jit/shape_inference.cc
Normal file
174
tensorflow/compiler/jit/shape_inference.cc
Normal file
@ -0,0 +1,174 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/shape_inference.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/shape_inference_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/dump_graph.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/framework/shape_inference.h"
|
||||
#include "tensorflow/core/graph/algorithm.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Converts a shape inference handle to a PartialTensorShape.
|
||||
Status ShapeHandleToTensorShape(shape_inference::InferenceContext* context,
|
||||
const shape_inference::ShapeHandle& handle,
|
||||
PartialTensorShape* shape) {
|
||||
// The default is already unknown
|
||||
if (!context->RankKnown(handle)) return Status::OK();
|
||||
|
||||
std::vector<int64> dims(context->Rank(handle));
|
||||
for (int32 i = 0; i < dims.size(); ++i) {
|
||||
dims[i] = context->Value(context->Dim(handle, i));
|
||||
}
|
||||
return PartialTensorShape::MakePartialShape(dims.data(), dims.size(), shape);
|
||||
}
|
||||
|
||||
Status PropagateShapes(const Graph& graph,
|
||||
const std::map<int, InferredShape>& arg_shapes,
|
||||
ShapeRefiner* shape_refiner) {
|
||||
// Visits the nodes in topological order (reverse post-order), inferring
|
||||
// shapes.
|
||||
// TODO(phawkins): handle cyclic graphs.
|
||||
std::vector<Node*> order;
|
||||
GetReversePostOrder(graph, &order);
|
||||
|
||||
for (Node* n : order) {
|
||||
// Ignore the status returned by the shape_refiner. We want the best effort
|
||||
// shapes, even if no shape function is registered for a node.
|
||||
Status status = shape_refiner->AddNode(n);
|
||||
if (!status.ok()) {
|
||||
VLOG(1) << "Shape inference failed for node: " << status;
|
||||
}
|
||||
|
||||
if (n->type_string() == "_Arg") {
|
||||
int index;
|
||||
TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
|
||||
auto it = arg_shapes.find(index);
|
||||
if (it != arg_shapes.end()) {
|
||||
const InferredShape& arg_shape = it->second;
|
||||
shape_inference::InferenceContext* context =
|
||||
shape_refiner->GetContext(n);
|
||||
|
||||
if (arg_shape.handle_type != DT_INVALID) {
|
||||
shape_inference::ShapeHandle handle;
|
||||
TF_RETURN_IF_ERROR(context->MakeShapeFromPartialTensorShape(
|
||||
arg_shape.handle_shape, &handle));
|
||||
|
||||
// Sets the shape and type of the variable's value.
|
||||
context->set_output_handle_shapes_and_types(
|
||||
0, std::vector<shape_inference::ShapeAndType>{
|
||||
{handle, arg_shape.handle_type}});
|
||||
}
|
||||
|
||||
shape_inference::ShapeHandle handle;
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->MakeShapeFromPartialTensorShape(arg_shape.shape, &handle));
|
||||
TF_RETURN_IF_ERROR(shape_refiner->SetShape(n, 0, handle));
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Store the shapes of the output tensors in a map
|
||||
Status StoreOutputShapes(const Graph& graph, const ShapeRefiner& shape_refiner,
|
||||
GraphShapeInfo* shape_info) {
|
||||
for (const Node* node : graph.nodes()) {
|
||||
shape_inference::InferenceContext* context = shape_refiner.GetContext(node);
|
||||
if (!context) continue;
|
||||
|
||||
auto& outputs = (*shape_info)[node->name()];
|
||||
outputs.resize(context->num_outputs());
|
||||
for (int i = 0; i < context->num_outputs(); ++i) {
|
||||
auto& output = outputs[i];
|
||||
TF_RETURN_IF_ERROR(
|
||||
ShapeHandleToTensorShape(context, context->output(i), &output.shape));
|
||||
|
||||
const auto* handle_shapes_and_types =
|
||||
context->output_handle_shapes_and_types(i);
|
||||
if (handle_shapes_and_types != nullptr) {
|
||||
if (handle_shapes_and_types->size() == 1) {
|
||||
TF_RETURN_IF_ERROR(ShapeHandleToTensorShape(
|
||||
context, (*handle_shapes_and_types)[0].shape,
|
||||
&output.handle_shape));
|
||||
output.handle_type = (*handle_shapes_and_types)[0].dtype;
|
||||
} else {
|
||||
// otherwise, it may be resource like a Queue, which can have
|
||||
// multiple shapes and types represented by a single handle.
|
||||
}
|
||||
}
|
||||
VLOG(4) << node->name() << " output " << i << " shape"
|
||||
<< output.shape.DebugString() << " handle_type "
|
||||
<< DataTypeString(output.handle_type) << " handle_shape "
|
||||
<< output.handle_shape.DebugString();
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
|
||||
const tensorflow::FunctionLibraryDefinition* fnlib_def,
|
||||
GraphShapeInfo* shape_info) {
|
||||
ShapeRefiner shape_refiner(graph->versions(), graph->op_registry());
|
||||
shape_refiner.set_require_shape_inference_fns(false);
|
||||
// TODO(dlibenzi): Verify if it is worth trying to infer shaped within
|
||||
// functions. Some functions can be called at multiple locations with
|
||||
// difference shapes, which will trigger a shape inference based on the
|
||||
// arguments passed at the first call.
|
||||
// shape_refiner.set_function_library_for_shape_inference(fnlib_def);
|
||||
|
||||
// ShapeRefiner requires that all inputs of a node are present when
|
||||
// ShapeRefiner::AddNode is called. To get at least some shape information in
|
||||
// loops, we temporarily remove loop backedges and add them back again after
|
||||
// the shape inference is complete.
|
||||
BackEdgeHelper back_edge;
|
||||
TF_RETURN_IF_ERROR(back_edge.Remove(graph));
|
||||
TF_RETURN_IF_ERROR(PropagateShapes(*graph, arg_shapes, &shape_refiner));
|
||||
TF_RETURN_IF_ERROR(back_edge.Replace());
|
||||
|
||||
// Currently information does not flow "backward" from consumers to producers
|
||||
// in the shape inference, but we consume the shapes in a second pass in case
|
||||
// backward information flow is added in the future.
|
||||
return StoreOutputShapes(*graph, shape_refiner, shape_info);
|
||||
}
|
||||
|
||||
xla::StatusOr<InferredShape> MergeInferredShapes(const InferredShape& a,
|
||||
const InferredShape& b) {
|
||||
InferredShape result;
|
||||
TF_RETURN_IF_ERROR(a.shape.MergeWith(b.shape, &result.shape));
|
||||
|
||||
if (a.handle_type == DT_INVALID) {
|
||||
result.handle_type = b.handle_type;
|
||||
} else if (b.handle_type == DT_INVALID) {
|
||||
result.handle_type = a.handle_type;
|
||||
} else if (a.handle_type == b.handle_type) {
|
||||
result.handle_type = a.handle_type;
|
||||
} else {
|
||||
return errors::InvalidArgument(
|
||||
"Mismatched resource types: ", DataTypeString(a.handle_type), " vs. ",
|
||||
DataTypeString(b.handle_type));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
a.handle_shape.MergeWith(b.handle_shape, &result.handle_shape));
|
||||
return result;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
54
tensorflow/compiler/jit/shape_inference.h
Normal file
54
tensorflow/compiler/jit/shape_inference.h
Normal file
@ -0,0 +1,54 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/common_runtime/optimization_registry.h"
|
||||
#include "tensorflow/core/common_runtime/shape_refiner.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
struct InferredShape {
|
||||
// Shape of the argument tensor.
|
||||
PartialTensorShape shape;
|
||||
|
||||
// If the argument is a resource variable, the type and shape of the
|
||||
// variable's value.
|
||||
DataType handle_type = DT_INVALID;
|
||||
PartialTensorShape handle_shape;
|
||||
};
|
||||
typedef std::unordered_map<string, std::vector<InferredShape>> GraphShapeInfo;
|
||||
|
||||
// Infer shapes for all Tensors in a graph, and save them in a map. The vector
|
||||
// for a Node contains the information about each of its outputs.
|
||||
// TODO(phawkins): this code does not infer accurate shapes for cyclic graphs.
|
||||
Status InferShapes(Graph* graph, const std::map<int, InferredShape>& arg_shapes,
|
||||
const tensorflow::FunctionLibraryDefinition* fnlib_def,
|
||||
GraphShapeInfo* shape_info);
|
||||
|
||||
// Merges two InferredShapes. Return an error if the two shapes cannot be
|
||||
// merged.
|
||||
xla::StatusOr<InferredShape> MergeInferredShapes(const InferredShape& a,
|
||||
const InferredShape& b);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_SHAPE_INFERENCE_H_
|
124
tensorflow/compiler/jit/shape_inference_test.cc
Normal file
124
tensorflow/compiler/jit/shape_inference_test.cc
Normal file
@ -0,0 +1,124 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Tests for ShapeInference.
|
||||
|
||||
#include "tensorflow/compiler/jit/shape_inference.h"
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/cc/framework/ops.h"
|
||||
#include "tensorflow/cc/ops/control_flow_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/compiler/jit/test_util.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(ShapeInferenceTest, Basics) {
|
||||
Scope root = Scope::NewRootScope().ExitOnError();
|
||||
auto a = ops::Placeholder(root.WithOpName("A"), DT_FLOAT,
|
||||
ops::Placeholder::Shape({2, 3}));
|
||||
auto b = ops::Placeholder(root.WithOpName("B"), DT_FLOAT,
|
||||
ops::Placeholder::Shape({3}));
|
||||
auto c = ops::Placeholder(root.WithOpName("C"), DT_FLOAT);
|
||||
auto d = ops::Add(root.WithOpName("D"), a, b);
|
||||
auto e = ops::Add(root.WithOpName("E"), d, c);
|
||||
auto f = ops::Neg(root.WithOpName("F"), e);
|
||||
auto g = ops::AddN(root.WithOpName("G"), std::initializer_list<Output>{e, f});
|
||||
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_CHECK_OK(root.ToGraph(graph.get()));
|
||||
|
||||
GraphShapeInfo shape_info;
|
||||
TF_ASSERT_OK(InferShapes(graph.get(), /*arg_shapes=*/{},
|
||||
/*fnlib_def=*/nullptr, &shape_info));
|
||||
|
||||
std::map<string, std::vector<PartialTensorShape>> expected = {
|
||||
{"A", {PartialTensorShape({2, 3})}}, {"B", {PartialTensorShape({3})}},
|
||||
{"C", {PartialTensorShape()}}, {"D", {PartialTensorShape({2, 3})}},
|
||||
{"E", {PartialTensorShape()}}, {"F", {PartialTensorShape()}},
|
||||
{"G", {PartialTensorShape()}},
|
||||
};
|
||||
TF_EXPECT_OK(ShapeAnnotationsMatch(*graph, shape_info, expected));
|
||||
}
|
||||
|
||||
TEST(ShapeInferenceTest, WhileLoop) {
|
||||
// Graph:
|
||||
// x = array_ops.placeholder(dtypes.int32)
|
||||
// y = control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x])
|
||||
Graph graph(OpRegistry::Global());
|
||||
{
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
|
||||
auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32,
|
||||
ops::Placeholder::Shape({}));
|
||||
|
||||
auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32,
|
||||
ops::Placeholder::Shape({}));
|
||||
auto enter =
|
||||
ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop");
|
||||
// Add an unused Enter node. These should be ignored.
|
||||
auto enter2 =
|
||||
ops::internal::Enter(scope.WithOpName("while/Enter2"), source, "aloop");
|
||||
auto merge = ops::Merge(scope.WithOpName("while/Merge"),
|
||||
std::initializer_list<Input>{enter, dummy});
|
||||
auto ten = ops::Const<int32>(
|
||||
scope.WithOpName("while/Less/y").WithControlDependencies(merge.output),
|
||||
10);
|
||||
auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten);
|
||||
auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less);
|
||||
auto switch_node =
|
||||
ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond);
|
||||
auto exit = ops::internal::Exit(scope.WithOpName("while/Exit"),
|
||||
switch_node.output_false);
|
||||
auto identity = ops::Identity(scope.WithOpName("while/Identity"),
|
||||
switch_node.output_true);
|
||||
auto identity_shape =
|
||||
ops::Const<int32>(scope.WithOpName("while/Identity/shape"), {});
|
||||
auto identity_reshaped = ops::Reshape(
|
||||
scope.WithOpName("while/Identity/reshaped"), identity, identity_shape);
|
||||
|
||||
auto one = ops::Const<int32>(
|
||||
scope.WithOpName("while/add/y").WithControlDependencies(identity), 1);
|
||||
auto add = ops::Add(scope.WithOpName("while/add"), identity_reshaped, one);
|
||||
auto next_iteration =
|
||||
ops::NextIteration(scope.WithOpName("while/NextIteration"), add);
|
||||
|
||||
auto sink = ops::Identity(scope.WithOpName("sink"), exit);
|
||||
|
||||
// Remove the dummy node and add the loop backedge.
|
||||
scope.graph()->RemoveNode(dummy.node());
|
||||
scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1);
|
||||
|
||||
TF_EXPECT_OK(scope.ToGraph(&graph));
|
||||
}
|
||||
|
||||
GraphShapeInfo shape_info;
|
||||
TF_ASSERT_OK(InferShapes(&graph, /*arg_shapes=*/{}, /*fnlib_def=*/nullptr,
|
||||
&shape_info));
|
||||
std::map<string, std::vector<PartialTensorShape>> expected = {
|
||||
{"while/Identity", {PartialTensorShape()}},
|
||||
{"while/add", {PartialTensorShape({})}},
|
||||
};
|
||||
TF_EXPECT_OK(ShapeAnnotationsMatch(graph, shape_info, expected));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
57
tensorflow/compiler/jit/test_util.cc
Normal file
57
tensorflow/compiler/jit/test_util.cc
Normal file
@ -0,0 +1,57 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/jit/test_util.h"
|
||||
|
||||
#include "tensorflow/compiler/jit/shape_inference.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Status ShapeAnnotationsMatch(
|
||||
const Graph& graph, const GraphShapeInfo& shape_info,
|
||||
std::map<string, std::vector<PartialTensorShape>> expected_shapes) {
|
||||
for (Node* node : graph.op_nodes()) {
|
||||
auto sit = shape_info.find(node->name());
|
||||
TF_RET_CHECK(sit != shape_info.end())
|
||||
<< "Missing shape information for node " << node->name();
|
||||
std::vector<PartialTensorShape> shapes;
|
||||
for (const auto& output : sit->second) shapes.push_back(output.shape);
|
||||
|
||||
auto it = expected_shapes.find(node->name());
|
||||
if (it != expected_shapes.end()) {
|
||||
if (!PartialTensorShapeUtils::AreIdentical(shapes, it->second)) {
|
||||
return errors::InvalidArgument(
|
||||
"Shape mismatch for ", node->name(), ". Expected: ",
|
||||
PartialTensorShapeUtils::PartialShapeListString(it->second),
|
||||
", actual: ",
|
||||
PartialTensorShapeUtils::PartialShapeListString(shapes));
|
||||
}
|
||||
expected_shapes.erase(it);
|
||||
}
|
||||
}
|
||||
if (!expected_shapes.empty()) {
|
||||
std::vector<string> missing;
|
||||
missing.reserve(expected_shapes.size());
|
||||
for (const auto& entry : expected_shapes) {
|
||||
missing.push_back(entry.first);
|
||||
}
|
||||
return errors::InvalidArgument("Missing shapes for nodes: ",
|
||||
str_util::Join(missing, ","));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
44
tensorflow/compiler/jit/test_util.h
Normal file
44
tensorflow/compiler/jit/test_util.h
Normal file
@ -0,0 +1,44 @@
|
||||
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// Helper functions for tests.
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_
|
||||
#define TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_
|
||||
|
||||
#include <map>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/jit/shape_inference.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/partial_tensor_shape.h"
|
||||
#include "tensorflow/core/graph/graph.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Tests that the shapes in 'shape_info' for the nodes in `graph` match
|
||||
// `expected_shapes`. Returns an error if there are nodes in `expected_shapes`
|
||||
// that do not have shape information. Ignores nodes in `graph` that do not have
|
||||
// `expected_shapes` entries.
|
||||
Status ShapeAnnotationsMatch(
|
||||
const Graph& graph, const GraphShapeInfo& shape_info,
|
||||
std::map<string, std::vector<PartialTensorShape>> expected_shapes);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_JIT_TEST_UTIL_H_
|
Loading…
Reference in New Issue
Block a user