diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index 03467cc4e27..2393d973522 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -504,6 +504,7 @@ tf_cc_test( "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@com_google_absl//absl/types:optional", ], ) diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index ffe22a039fd..3f0c9de66d6 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -14,6 +14,8 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/c/c_api_experimental.h" + +#include "absl/types/optional.h" #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/c/eager/c_api.h" @@ -437,84 +439,105 @@ class ShapeInferenceTest : public ::testing::Test { : status_(TF_NewStatus()), tfe_context_options_(TFE_NewContextOptions()) { tfe_context_ = TFE_NewContext(tfe_context_options_, status_); CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); - matmul_op_ = TFE_NewOp(tfe_context_, "MatMul", status_); CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); } ~ShapeInferenceTest() override { - TFE_DeleteOp(matmul_op_); TFE_DeleteContextOptions(tfe_context_options_); TFE_DeleteContext(tfe_context_); TF_DeleteStatus(status_); } - void infer_matmul_shapes(TF_ShapeAndTypeList* input_shapes, - int64_t expected_rank, int64_t expected_first_dim, - int64_t expected_second_dim) { + // Checks the expected result of shape inference for the given `op`. + void CheckOutputShapes( + TFE_Op* op, + const std::vector>>& input_shapes_vec, + TF_Tensor** input_tensors, + const absl::optional>& expected_shape) { + // Create input_shapes. + TF_ShapeAndTypeList* input_shapes = + TF_NewShapeAndTypeList(input_shapes_vec.size()); + for (size_t i = 0; i < input_shapes_vec.size(); ++i) { + const auto& input_shape = input_shapes_vec[i]; + if (input_shape.has_value()) { + TF_ShapeAndTypeListSetShape(input_shapes, i, input_shape->data(), + input_shape->size()); + } else { + TF_ShapeAndTypeListSetUnknownShape(input_shapes, i); + } + } TF_ShapeAndTypeList* output_shapes; - TFE_InferShapes(matmul_op_, input_shapes, - /*input_tensors*/ nullptr, + TFE_InferShapes(op, input_shapes, input_tensors, /*input_tensors_as_shapes*/ nullptr, /*input_resource_shapes_and_types*/ nullptr, &output_shapes, /*output_resource_shapes_and_types*/ nullptr, status_); CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); CHECK_EQ(output_shapes->num_items, 1); - EXPECT_EQ(output_shapes->items[0].num_dims, expected_rank); - if (expected_rank == 2) { - EXPECT_EQ(output_shapes->items[0].dims[0], expected_first_dim); - EXPECT_EQ(output_shapes->items[0].dims[1], expected_second_dim); + + int num_dims = output_shapes->items[0].num_dims; + int64_t* dims = output_shapes->items[0].dims; + + if (!expected_shape.has_value()) { + EXPECT_EQ(num_dims, -1); + EXPECT_EQ(dims, nullptr); + return; + } + + EXPECT_EQ(num_dims, expected_shape->size()); + for (size_t i = 0; i < num_dims; ++i) { + EXPECT_EQ(dims[i], (*expected_shape)[i]); } TF_DeleteShapeAndTypeList(input_shapes); TF_DeleteShapeAndTypeList(output_shapes); } + absl::optional> make_shape( + std::vector&& dims) const { + return absl::make_optional(dims); + } + + absl::optional> unknown_shape() const { + return absl::nullopt; + } + + static constexpr int64_t kUnknownDim = + shape_inference::InferenceContext::kUnknownDim; TF_Status* status_; TFE_ContextOptions* tfe_context_options_; TFE_Context* tfe_context_; - TFE_Op* matmul_op_; }; -TEST_F(ShapeInferenceTest, InfersShapes) { +TEST_F(ShapeInferenceTest, InfersShapesFromInputShapes) { + TFE_Op* matmul_op; + matmul_op = TFE_NewOp(tfe_context_, "MatMul", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + // Infer shape when everything is known. - int64_t _3by2[] = {3, 2}; - int64_t _2by4[] = {2, 4}; - TF_ShapeAndTypeList* input_shapes = TF_NewShapeAndTypeList(/*num_shapes*/ 2); - TF_ShapeAndTypeListSetShape(input_shapes, 0, _3by2, 2); - TF_ShapeAndTypeListSetShape(input_shapes, 1, _2by4, 2); - infer_matmul_shapes(input_shapes, /*expected_rank*/ 2, - /*expected_first_dim*/ 3, /*expected_second_dim*/ 4); + CheckOutputShapes(matmul_op, + /*input_shapes*/ {make_shape({3, 2}), make_shape({2, 4})}, + /*input_tensors*/ nullptr, + /*expected_shape*/ make_shape({3, 4})); // Infer shape when second operand has unknown shape. - TF_ShapeAndTypeList* input_shapes_unknown_second = - TF_NewShapeAndTypeList(/*num_shapes*/ 2); - TF_ShapeAndTypeListSetShape(input_shapes_unknown_second, 0, _3by2, 2); - TF_ShapeAndTypeListSetUnknownShape(input_shapes_unknown_second, 1); - infer_matmul_shapes( - input_shapes_unknown_second, /*expected_rank*/ 2, - /*expected_first_dim*/ 3, - /*expected_second_dim*/ shape_inference::InferenceContext::kUnknownDim); + CheckOutputShapes(matmul_op, + /*input_shapes*/ {make_shape({3, 2}), unknown_shape()}, + /*input_tensors*/ nullptr, + /*expected_shape*/ make_shape({3, kUnknownDim})); // Infer shape when some dimensions are unknown. - int64_t _unknownby2[] = {-1, 2}; - TF_ShapeAndTypeList* input_shapes_unknown_dims = - TF_NewShapeAndTypeList(/*num_shapes*/ 2); - TF_ShapeAndTypeListSetShape(input_shapes_unknown_dims, 0, _unknownby2, 2); - TF_ShapeAndTypeListSetShape(input_shapes_unknown_dims, 1, _2by4, 2); - infer_matmul_shapes( - input_shapes_unknown_dims, /*expected_rank*/ 2, - /*expected_first_dim*/ shape_inference::InferenceContext::kUnknownDim, - /*expected_second_dim*/ 4); + CheckOutputShapes( + matmul_op, + /*input_shapes*/ {make_shape({kUnknownDim, 2}), make_shape({2, 4})}, + /*input_tensors*/ nullptr, + /*expected_shape*/ make_shape({kUnknownDim, 4})); // Infer shape when everything is unknown. - TF_ShapeAndTypeList* unknown_shapes = - TF_NewShapeAndTypeList(/*num_shapes*/ 2); - TF_ShapeAndTypeListSetUnknownShape(unknown_shapes, 0); - TF_ShapeAndTypeListSetUnknownShape(unknown_shapes, 1); - infer_matmul_shapes( - unknown_shapes, /*expected_rank*/ 2, - /*expected_first_dim*/ shape_inference::InferenceContext::kUnknownDim, - /*expected_second_dim*/ shape_inference::InferenceContext::kUnknownDim); + CheckOutputShapes(matmul_op, + /*input_shapes*/ {unknown_shape(), unknown_shape()}, + /*input_tensors*/ nullptr, + /*expected_shape*/ make_shape({kUnknownDim, kUnknownDim})); + TFE_DeleteOp(matmul_op); // TODO(bgogul): Add some death tests where status is not OK. }