Clean up test for the shape inference so that expectations are obvious.

PiperOrigin-RevId: 261732924
This commit is contained in:
A. Unique TensorFlower 2019-08-05 11:44:34 -07:00 committed by TensorFlower Gardener
parent 9020ea72b4
commit 6edf8c57c7
2 changed files with 69 additions and 45 deletions

View File

@ -504,6 +504,7 @@ tf_cc_test(
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test", "//tensorflow/core:test",
"//tensorflow/core:test_main", "//tensorflow/core:test_main",
"@com_google_absl//absl/types:optional",
], ],
) )

View File

@ -14,6 +14,8 @@ limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/c/c_api_experimental.h" #include "tensorflow/c/c_api_experimental.h"
#include "absl/types/optional.h"
#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_test_util.h" #include "tensorflow/c/c_test_util.h"
#include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/eager/c_api.h"
@ -437,84 +439,105 @@ class ShapeInferenceTest : public ::testing::Test {
: status_(TF_NewStatus()), tfe_context_options_(TFE_NewContextOptions()) { : status_(TF_NewStatus()), tfe_context_options_(TFE_NewContextOptions()) {
tfe_context_ = TFE_NewContext(tfe_context_options_, status_); tfe_context_ = TFE_NewContext(tfe_context_options_, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
matmul_op_ = TFE_NewOp(tfe_context_, "MatMul", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
} }
~ShapeInferenceTest() override { ~ShapeInferenceTest() override {
TFE_DeleteOp(matmul_op_);
TFE_DeleteContextOptions(tfe_context_options_); TFE_DeleteContextOptions(tfe_context_options_);
TFE_DeleteContext(tfe_context_); TFE_DeleteContext(tfe_context_);
TF_DeleteStatus(status_); TF_DeleteStatus(status_);
} }
void infer_matmul_shapes(TF_ShapeAndTypeList* input_shapes, // Checks the expected result of shape inference for the given `op`.
int64_t expected_rank, int64_t expected_first_dim, void CheckOutputShapes(
int64_t expected_second_dim) { TFE_Op* op,
const std::vector<absl::optional<std::vector<int64_t>>>& input_shapes_vec,
TF_Tensor** input_tensors,
const absl::optional<std::vector<int64_t>>& expected_shape) {
// Create input_shapes.
TF_ShapeAndTypeList* input_shapes =
TF_NewShapeAndTypeList(input_shapes_vec.size());
for (size_t i = 0; i < input_shapes_vec.size(); ++i) {
const auto& input_shape = input_shapes_vec[i];
if (input_shape.has_value()) {
TF_ShapeAndTypeListSetShape(input_shapes, i, input_shape->data(),
input_shape->size());
} else {
TF_ShapeAndTypeListSetUnknownShape(input_shapes, i);
}
}
TF_ShapeAndTypeList* output_shapes; TF_ShapeAndTypeList* output_shapes;
TFE_InferShapes(matmul_op_, input_shapes, TFE_InferShapes(op, input_shapes, input_tensors,
/*input_tensors*/ nullptr,
/*input_tensors_as_shapes*/ nullptr, /*input_tensors_as_shapes*/ nullptr,
/*input_resource_shapes_and_types*/ nullptr, &output_shapes, /*input_resource_shapes_and_types*/ nullptr, &output_shapes,
/*output_resource_shapes_and_types*/ nullptr, status_); /*output_resource_shapes_and_types*/ nullptr, status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
CHECK_EQ(output_shapes->num_items, 1); CHECK_EQ(output_shapes->num_items, 1);
EXPECT_EQ(output_shapes->items[0].num_dims, expected_rank);
if (expected_rank == 2) { int num_dims = output_shapes->items[0].num_dims;
EXPECT_EQ(output_shapes->items[0].dims[0], expected_first_dim); int64_t* dims = output_shapes->items[0].dims;
EXPECT_EQ(output_shapes->items[0].dims[1], expected_second_dim);
if (!expected_shape.has_value()) {
EXPECT_EQ(num_dims, -1);
EXPECT_EQ(dims, nullptr);
return;
}
EXPECT_EQ(num_dims, expected_shape->size());
for (size_t i = 0; i < num_dims; ++i) {
EXPECT_EQ(dims[i], (*expected_shape)[i]);
} }
TF_DeleteShapeAndTypeList(input_shapes); TF_DeleteShapeAndTypeList(input_shapes);
TF_DeleteShapeAndTypeList(output_shapes); TF_DeleteShapeAndTypeList(output_shapes);
} }
absl::optional<std::vector<int64_t>> make_shape(
std::vector<int64_t>&& dims) const {
return absl::make_optional(dims);
}
absl::optional<std::vector<int64_t>> unknown_shape() const {
return absl::nullopt;
}
static constexpr int64_t kUnknownDim =
shape_inference::InferenceContext::kUnknownDim;
TF_Status* status_; TF_Status* status_;
TFE_ContextOptions* tfe_context_options_; TFE_ContextOptions* tfe_context_options_;
TFE_Context* tfe_context_; TFE_Context* tfe_context_;
TFE_Op* matmul_op_;
}; };
TEST_F(ShapeInferenceTest, InfersShapes) { TEST_F(ShapeInferenceTest, InfersShapesFromInputShapes) {
TFE_Op* matmul_op;
matmul_op = TFE_NewOp(tfe_context_, "MatMul", status_);
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
// Infer shape when everything is known. // Infer shape when everything is known.
int64_t _3by2[] = {3, 2}; CheckOutputShapes(matmul_op,
int64_t _2by4[] = {2, 4}; /*input_shapes*/ {make_shape({3, 2}), make_shape({2, 4})},
TF_ShapeAndTypeList* input_shapes = TF_NewShapeAndTypeList(/*num_shapes*/ 2); /*input_tensors*/ nullptr,
TF_ShapeAndTypeListSetShape(input_shapes, 0, _3by2, 2); /*expected_shape*/ make_shape({3, 4}));
TF_ShapeAndTypeListSetShape(input_shapes, 1, _2by4, 2);
infer_matmul_shapes(input_shapes, /*expected_rank*/ 2,
/*expected_first_dim*/ 3, /*expected_second_dim*/ 4);
// Infer shape when second operand has unknown shape. // Infer shape when second operand has unknown shape.
TF_ShapeAndTypeList* input_shapes_unknown_second = CheckOutputShapes(matmul_op,
TF_NewShapeAndTypeList(/*num_shapes*/ 2); /*input_shapes*/ {make_shape({3, 2}), unknown_shape()},
TF_ShapeAndTypeListSetShape(input_shapes_unknown_second, 0, _3by2, 2); /*input_tensors*/ nullptr,
TF_ShapeAndTypeListSetUnknownShape(input_shapes_unknown_second, 1); /*expected_shape*/ make_shape({3, kUnknownDim}));
infer_matmul_shapes(
input_shapes_unknown_second, /*expected_rank*/ 2,
/*expected_first_dim*/ 3,
/*expected_second_dim*/ shape_inference::InferenceContext::kUnknownDim);
// Infer shape when some dimensions are unknown. // Infer shape when some dimensions are unknown.
int64_t _unknownby2[] = {-1, 2}; CheckOutputShapes(
TF_ShapeAndTypeList* input_shapes_unknown_dims = matmul_op,
TF_NewShapeAndTypeList(/*num_shapes*/ 2); /*input_shapes*/ {make_shape({kUnknownDim, 2}), make_shape({2, 4})},
TF_ShapeAndTypeListSetShape(input_shapes_unknown_dims, 0, _unknownby2, 2); /*input_tensors*/ nullptr,
TF_ShapeAndTypeListSetShape(input_shapes_unknown_dims, 1, _2by4, 2); /*expected_shape*/ make_shape({kUnknownDim, 4}));
infer_matmul_shapes(
input_shapes_unknown_dims, /*expected_rank*/ 2,
/*expected_first_dim*/ shape_inference::InferenceContext::kUnknownDim,
/*expected_second_dim*/ 4);
// Infer shape when everything is unknown. // Infer shape when everything is unknown.
TF_ShapeAndTypeList* unknown_shapes = CheckOutputShapes(matmul_op,
TF_NewShapeAndTypeList(/*num_shapes*/ 2); /*input_shapes*/ {unknown_shape(), unknown_shape()},
TF_ShapeAndTypeListSetUnknownShape(unknown_shapes, 0); /*input_tensors*/ nullptr,
TF_ShapeAndTypeListSetUnknownShape(unknown_shapes, 1); /*expected_shape*/ make_shape({kUnknownDim, kUnknownDim}));
infer_matmul_shapes(
unknown_shapes, /*expected_rank*/ 2,
/*expected_first_dim*/ shape_inference::InferenceContext::kUnknownDim,
/*expected_second_dim*/ shape_inference::InferenceContext::kUnknownDim);
TFE_DeleteOp(matmul_op);
// TODO(bgogul): Add some death tests where status is not OK. // TODO(bgogul): Add some death tests where status is not OK.
} }