From 038d0cdabcad49398d51b960486f4fa38f532449 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 5 Aug 2019 13:51:51 -0700 Subject: [PATCH] Handle input_tensors in C shape inference API. PiperOrigin-RevId: 261761496 --- tensorflow/c/c_api_experimental.cc | 25 +++++++++++-- tensorflow/c/c_api_experimental.h | 13 +++++-- tensorflow/c/c_api_experimental_test.cc | 50 +++++++++++++++++++++---- 3 files changed, 74 insertions(+), 14 deletions(-) diff --git a/tensorflow/c/c_api_experimental.cc b/tensorflow/c/c_api_experimental.cc index ea1d538de16..af1c0ea6833 100644 --- a/tensorflow/c/c_api_experimental.cc +++ b/tensorflow/c/c_api_experimental.cc @@ -1053,6 +1053,10 @@ void TF_DeleteShapeAndTypeListArray(TF_ShapeAndTypeList** shape_list_array, delete[] shape_list_array; } +namespace tensorflow { +Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); +} // namespace tensorflow + void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, TF_Tensor** input_tensors, TF_ShapeAndTypeList* input_tensors_as_shapes, @@ -1082,10 +1086,26 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, tensorflow::OpRegistry::Global()->LookUp(node_def.op(), &op_reg_data); if (!status->status.ok()) return; + // Initialize a input_tensor vector with `nullptr` values. + std::vector input_tensors_vector(num_inputs, nullptr); + // A vector to keep track of newly created `tf::Tensor` objects. + std::vector all_input_tensors; + // Update the vector with information from `input_tensors` if provided. + if (input_tensors != nullptr) { + for (int i = 0; i < num_inputs; ++i) { + if (input_tensors[i] == nullptr) continue; + all_input_tensors.emplace_back(); + Tensor& input_tensor = all_input_tensors.back(); + status->status = TF_TensorToTensor(input_tensors[i], &input_tensor); + if (!status->status.ok()) return; + input_tensors_vector[i] = &input_tensor; + } + } + // Create an inference context with dummy values, which will be updated later. InferenceContext c(TF_GRAPH_DEF_VERSION, &node_def, op_reg_data->op_def, - std::vector(num_inputs), - std::vector(num_inputs, nullptr), {}, + std::vector(num_inputs), input_tensors_vector, + {}, std::vector>>()); // Set input_shapes. @@ -1102,7 +1122,6 @@ void TFE_InferShapes(TFE_Op* tfe_op, TF_ShapeAndTypeList* input_shapes, c.SetInput(i, c.MakeShape(dims)); } - // TODO(bgogul): Handle input_tensors. // TODO(bgogul): Handle input_tensors_as_shapes. // TODO(bgogul): Handle input_resource_shapes_and_types. diff --git a/tensorflow/c/c_api_experimental.h b/tensorflow/c/c_api_experimental.h index fb2b039f268..126db2640f6 100644 --- a/tensorflow/c/c_api_experimental.h +++ b/tensorflow/c/c_api_experimental.h @@ -378,10 +378,15 @@ TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeList( TF_CAPI_EXPORT extern void TF_DeleteShapeAndTypeListArray( TF_ShapeAndTypeList** shape_list_array, int num_items); -// Infer shapes for the given `node_def`. The arguments mimic the arguments of -// the `shape_inference::InferenceContext` constructor. The types need not be -// set in `input_shapes` as it is not used for shape inference. The number of -// `input_tensors` should be the same as the number of items in `input_shapes`. +// Infer shapes for the given `op`. The arguments mimic the arguments of the +// `shape_inference::InferenceContext` constructor. Note the following: +// - The inputs of the `op` are not used for shape inference. So, it is +// OK to not have the inputs properly set in `op`. See `input_tensors` +// if you want shape inference to consider the input tensors of the +// op for shape inference. +// - The types need not be set in `input_shapes` as it is not used. +// - The number of `input_tensors` should be the same as the number of items +// in `input_shapes`. // // The results are returned in `output_shapes` and // `output_resource_shapes_and_types`. The caller is responsible for freeing the diff --git a/tensorflow/c/c_api_experimental_test.cc b/tensorflow/c/c_api_experimental_test.cc index 3f0c9de66d6..4b49b90e293 100644 --- a/tensorflow/c/c_api_experimental_test.cc +++ b/tensorflow/c/c_api_experimental_test.cc @@ -439,7 +439,6 @@ class ShapeInferenceTest : public ::testing::Test { : status_(TF_NewStatus()), tfe_context_options_(TFE_NewContextOptions()) { tfe_context_ = TFE_NewContext(tfe_context_options_, status_); CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); - CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); } ~ShapeInferenceTest() override { @@ -452,7 +451,7 @@ class ShapeInferenceTest : public ::testing::Test { void CheckOutputShapes( TFE_Op* op, const std::vector>>& input_shapes_vec, - TF_Tensor** input_tensors, + const std::vector& input_tensors, const absl::optional>& expected_shape) { // Create input_shapes. TF_ShapeAndTypeList* input_shapes = @@ -467,7 +466,10 @@ class ShapeInferenceTest : public ::testing::Test { } } TF_ShapeAndTypeList* output_shapes; - TFE_InferShapes(op, input_shapes, input_tensors, + TFE_InferShapes(op, input_shapes, + input_tensors.empty() + ? nullptr + : const_cast(input_tensors.data()), /*input_tensors_as_shapes*/ nullptr, /*input_resource_shapes_and_types*/ nullptr, &output_shapes, /*output_resource_shapes_and_types*/ nullptr, status_); @@ -515,31 +517,65 @@ TEST_F(ShapeInferenceTest, InfersShapesFromInputShapes) { // Infer shape when everything is known. CheckOutputShapes(matmul_op, /*input_shapes*/ {make_shape({3, 2}), make_shape({2, 4})}, - /*input_tensors*/ nullptr, + /*input_tensors*/ {}, /*expected_shape*/ make_shape({3, 4})); // Infer shape when second operand has unknown shape. CheckOutputShapes(matmul_op, /*input_shapes*/ {make_shape({3, 2}), unknown_shape()}, - /*input_tensors*/ nullptr, + /*input_tensors*/ {}, /*expected_shape*/ make_shape({3, kUnknownDim})); // Infer shape when some dimensions are unknown. CheckOutputShapes( matmul_op, /*input_shapes*/ {make_shape({kUnknownDim, 2}), make_shape({2, 4})}, - /*input_tensors*/ nullptr, + /*input_tensors*/ {}, /*expected_shape*/ make_shape({kUnknownDim, 4})); // Infer shape when everything is unknown. CheckOutputShapes(matmul_op, /*input_shapes*/ {unknown_shape(), unknown_shape()}, - /*input_tensors*/ nullptr, + /*input_tensors*/ {}, /*expected_shape*/ make_shape({kUnknownDim, kUnknownDim})); TFE_DeleteOp(matmul_op); // TODO(bgogul): Add some death tests where status is not OK. } +TEST_F(ShapeInferenceTest, InfersShapesFromInputTensors) { + // Prepare some tensors for shape. + TF_Tensor* tensor_1X6 = Int32Tensor({1, 6}); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TF_Tensor* tensor_1X1X6 = Int32Tensor({1, 1, 6}); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + + TFE_Op* reshape_op = TFE_NewOp(tfe_context_, "Reshape", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpSetAttrType(reshape_op, "T", TF_FLOAT); + TFE_OpSetAttrType(reshape_op, "Tshape", TF_INT32); + CheckOutputShapes(reshape_op, + /* input_shapes*/ {unknown_shape(), unknown_shape()}, + /* input_tensors*/ {nullptr, tensor_1X6}, + /*expected_shape*/ make_shape({1, 6})); + TFE_DeleteOp(reshape_op); + reshape_op = nullptr; + + TFE_Op* fill_op = TFE_NewOp(tfe_context_, "Fill", status_); + CHECK_EQ(TF_OK, TF_GetCode(status_)) << TF_Message(status_); + TFE_OpSetAttrType(fill_op, "T", TF_FLOAT); + TFE_OpSetAttrType(fill_op, "Tshape", TF_INT32); + + CheckOutputShapes(fill_op, + /* input_shapes*/ {unknown_shape(), unknown_shape()}, + /* input_tensors*/ {tensor_1X1X6, nullptr}, + /*expected_shape*/ make_shape({1, 1, 6})); + TFE_DeleteOp(fill_op); + fill_op = nullptr; + + TF_DeleteTensor(tensor_1X1X6); + TF_DeleteTensor(tensor_1X6); +} + } // namespace } // namespace tensorflow