Clean up test for the shape inference so that expectations are obvious.
PiperOrigin-RevId: 261732924
This commit is contained in:
parent
9020ea72b4
commit
6edf8c57c7
@ -504,6 +504,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
"@com_google_absl//absl/types:optional",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -14,6 +14,8 @@ limitations under the License.
|
|||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
|
|
||||||
#include "tensorflow/c/c_api_experimental.h"
|
#include "tensorflow/c/c_api_experimental.h"
|
||||||
|
|
||||||
|
#include "absl/types/optional.h"
|
||||||
#include "tensorflow/c/c_api_internal.h"
|
#include "tensorflow/c/c_api_internal.h"
|
||||||
#include "tensorflow/c/c_test_util.h"
|
#include "tensorflow/c/c_test_util.h"
|
||||||
#include "tensorflow/c/eager/c_api.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
@ -437,84 +439,105 @@ class ShapeInferenceTest : public ::testing::Test {
|
|||||||
: status_(TF_NewStatus()), tfe_context_options_(TFE_NewContextOptions()) {
|
: status_(TF_NewStatus()), tfe_context_options_(TFE_NewContextOptions()) {
|
||||||
tfe_context_ = TFE_NewContext(tfe_context_options_, status_);
|
tfe_context_ = TFE_NewContext(tfe_context_options_, status_);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
matmul_op_ = TFE_NewOp(tfe_context_, "MatMul", status_);
|
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
}
|
}
|
||||||
|
|
||||||
~ShapeInferenceTest() override {
|
~ShapeInferenceTest() override {
|
||||||
TFE_DeleteOp(matmul_op_);
|
|
||||||
TFE_DeleteContextOptions(tfe_context_options_);
|
TFE_DeleteContextOptions(tfe_context_options_);
|
||||||
TFE_DeleteContext(tfe_context_);
|
TFE_DeleteContext(tfe_context_);
|
||||||
TF_DeleteStatus(status_);
|
TF_DeleteStatus(status_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void infer_matmul_shapes(TF_ShapeAndTypeList* input_shapes,
|
// Checks the expected result of shape inference for the given `op`.
|
||||||
int64_t expected_rank, int64_t expected_first_dim,
|
void CheckOutputShapes(
|
||||||
int64_t expected_second_dim) {
|
TFE_Op* op,
|
||||||
|
const std::vector<absl::optional<std::vector<int64_t>>>& input_shapes_vec,
|
||||||
|
TF_Tensor** input_tensors,
|
||||||
|
const absl::optional<std::vector<int64_t>>& expected_shape) {
|
||||||
|
// Create input_shapes.
|
||||||
|
TF_ShapeAndTypeList* input_shapes =
|
||||||
|
TF_NewShapeAndTypeList(input_shapes_vec.size());
|
||||||
|
for (size_t i = 0; i < input_shapes_vec.size(); ++i) {
|
||||||
|
const auto& input_shape = input_shapes_vec[i];
|
||||||
|
if (input_shape.has_value()) {
|
||||||
|
TF_ShapeAndTypeListSetShape(input_shapes, i, input_shape->data(),
|
||||||
|
input_shape->size());
|
||||||
|
} else {
|
||||||
|
TF_ShapeAndTypeListSetUnknownShape(input_shapes, i);
|
||||||
|
}
|
||||||
|
}
|
||||||
TF_ShapeAndTypeList* output_shapes;
|
TF_ShapeAndTypeList* output_shapes;
|
||||||
TFE_InferShapes(matmul_op_, input_shapes,
|
TFE_InferShapes(op, input_shapes, input_tensors,
|
||||||
/*input_tensors*/ nullptr,
|
|
||||||
/*input_tensors_as_shapes*/ nullptr,
|
/*input_tensors_as_shapes*/ nullptr,
|
||||||
/*input_resource_shapes_and_types*/ nullptr, &output_shapes,
|
/*input_resource_shapes_and_types*/ nullptr, &output_shapes,
|
||||||
/*output_resource_shapes_and_types*/ nullptr, status_);
|
/*output_resource_shapes_and_types*/ nullptr, status_);
|
||||||
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
CHECK_EQ(output_shapes->num_items, 1);
|
CHECK_EQ(output_shapes->num_items, 1);
|
||||||
EXPECT_EQ(output_shapes->items[0].num_dims, expected_rank);
|
|
||||||
if (expected_rank == 2) {
|
int num_dims = output_shapes->items[0].num_dims;
|
||||||
EXPECT_EQ(output_shapes->items[0].dims[0], expected_first_dim);
|
int64_t* dims = output_shapes->items[0].dims;
|
||||||
EXPECT_EQ(output_shapes->items[0].dims[1], expected_second_dim);
|
|
||||||
|
if (!expected_shape.has_value()) {
|
||||||
|
EXPECT_EQ(num_dims, -1);
|
||||||
|
EXPECT_EQ(dims, nullptr);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
EXPECT_EQ(num_dims, expected_shape->size());
|
||||||
|
for (size_t i = 0; i < num_dims; ++i) {
|
||||||
|
EXPECT_EQ(dims[i], (*expected_shape)[i]);
|
||||||
}
|
}
|
||||||
TF_DeleteShapeAndTypeList(input_shapes);
|
TF_DeleteShapeAndTypeList(input_shapes);
|
||||||
TF_DeleteShapeAndTypeList(output_shapes);
|
TF_DeleteShapeAndTypeList(output_shapes);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
absl::optional<std::vector<int64_t>> make_shape(
|
||||||
|
std::vector<int64_t>&& dims) const {
|
||||||
|
return absl::make_optional(dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
absl::optional<std::vector<int64_t>> unknown_shape() const {
|
||||||
|
return absl::nullopt;
|
||||||
|
}
|
||||||
|
|
||||||
|
static constexpr int64_t kUnknownDim =
|
||||||
|
shape_inference::InferenceContext::kUnknownDim;
|
||||||
TF_Status* status_;
|
TF_Status* status_;
|
||||||
TFE_ContextOptions* tfe_context_options_;
|
TFE_ContextOptions* tfe_context_options_;
|
||||||
TFE_Context* tfe_context_;
|
TFE_Context* tfe_context_;
|
||||||
TFE_Op* matmul_op_;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(ShapeInferenceTest, InfersShapes) {
|
TEST_F(ShapeInferenceTest, InfersShapesFromInputShapes) {
|
||||||
|
TFE_Op* matmul_op;
|
||||||
|
matmul_op = TFE_NewOp(tfe_context_, "MatMul", status_);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_);
|
||||||
|
|
||||||
// Infer shape when everything is known.
|
// Infer shape when everything is known.
|
||||||
int64_t _3by2[] = {3, 2};
|
CheckOutputShapes(matmul_op,
|
||||||
int64_t _2by4[] = {2, 4};
|
/*input_shapes*/ {make_shape({3, 2}), make_shape({2, 4})},
|
||||||
TF_ShapeAndTypeList* input_shapes = TF_NewShapeAndTypeList(/*num_shapes*/ 2);
|
/*input_tensors*/ nullptr,
|
||||||
TF_ShapeAndTypeListSetShape(input_shapes, 0, _3by2, 2);
|
/*expected_shape*/ make_shape({3, 4}));
|
||||||
TF_ShapeAndTypeListSetShape(input_shapes, 1, _2by4, 2);
|
|
||||||
infer_matmul_shapes(input_shapes, /*expected_rank*/ 2,
|
|
||||||
/*expected_first_dim*/ 3, /*expected_second_dim*/ 4);
|
|
||||||
|
|
||||||
// Infer shape when second operand has unknown shape.
|
// Infer shape when second operand has unknown shape.
|
||||||
TF_ShapeAndTypeList* input_shapes_unknown_second =
|
CheckOutputShapes(matmul_op,
|
||||||
TF_NewShapeAndTypeList(/*num_shapes*/ 2);
|
/*input_shapes*/ {make_shape({3, 2}), unknown_shape()},
|
||||||
TF_ShapeAndTypeListSetShape(input_shapes_unknown_second, 0, _3by2, 2);
|
/*input_tensors*/ nullptr,
|
||||||
TF_ShapeAndTypeListSetUnknownShape(input_shapes_unknown_second, 1);
|
/*expected_shape*/ make_shape({3, kUnknownDim}));
|
||||||
infer_matmul_shapes(
|
|
||||||
input_shapes_unknown_second, /*expected_rank*/ 2,
|
|
||||||
/*expected_first_dim*/ 3,
|
|
||||||
/*expected_second_dim*/ shape_inference::InferenceContext::kUnknownDim);
|
|
||||||
|
|
||||||
// Infer shape when some dimensions are unknown.
|
// Infer shape when some dimensions are unknown.
|
||||||
int64_t _unknownby2[] = {-1, 2};
|
CheckOutputShapes(
|
||||||
TF_ShapeAndTypeList* input_shapes_unknown_dims =
|
matmul_op,
|
||||||
TF_NewShapeAndTypeList(/*num_shapes*/ 2);
|
/*input_shapes*/ {make_shape({kUnknownDim, 2}), make_shape({2, 4})},
|
||||||
TF_ShapeAndTypeListSetShape(input_shapes_unknown_dims, 0, _unknownby2, 2);
|
/*input_tensors*/ nullptr,
|
||||||
TF_ShapeAndTypeListSetShape(input_shapes_unknown_dims, 1, _2by4, 2);
|
/*expected_shape*/ make_shape({kUnknownDim, 4}));
|
||||||
infer_matmul_shapes(
|
|
||||||
input_shapes_unknown_dims, /*expected_rank*/ 2,
|
|
||||||
/*expected_first_dim*/ shape_inference::InferenceContext::kUnknownDim,
|
|
||||||
/*expected_second_dim*/ 4);
|
|
||||||
|
|
||||||
// Infer shape when everything is unknown.
|
// Infer shape when everything is unknown.
|
||||||
TF_ShapeAndTypeList* unknown_shapes =
|
CheckOutputShapes(matmul_op,
|
||||||
TF_NewShapeAndTypeList(/*num_shapes*/ 2);
|
/*input_shapes*/ {unknown_shape(), unknown_shape()},
|
||||||
TF_ShapeAndTypeListSetUnknownShape(unknown_shapes, 0);
|
/*input_tensors*/ nullptr,
|
||||||
TF_ShapeAndTypeListSetUnknownShape(unknown_shapes, 1);
|
/*expected_shape*/ make_shape({kUnknownDim, kUnknownDim}));
|
||||||
infer_matmul_shapes(
|
|
||||||
unknown_shapes, /*expected_rank*/ 2,
|
|
||||||
/*expected_first_dim*/ shape_inference::InferenceContext::kUnknownDim,
|
|
||||||
/*expected_second_dim*/ shape_inference::InferenceContext::kUnknownDim);
|
|
||||||
|
|
||||||
|
TFE_DeleteOp(matmul_op);
|
||||||
// TODO(bgogul): Add some death tests where status is not OK.
|
// TODO(bgogul): Add some death tests where status is not OK.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user