diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index a39365b6c5e..5c085a4b9b8 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -176,6 +176,7 @@ cc_library( "//tensorflow/c:c_api_internal", "//tensorflow/c:conversion_macros", "//tensorflow/c:tf_status", + "//tensorflow/core:framework", "//tensorflow/core/platform:casts", "//tensorflow/core/platform:types", ], @@ -244,6 +245,7 @@ cc_library( ":c_api_unified_internal", "//tensorflow/c:tf_status", "//tensorflow/c:tf_status_helper", + "//tensorflow/core:framework", "//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/platform:errors", "//tensorflow/core/platform:status", @@ -288,6 +290,29 @@ tf_cuda_cc_test( ], ) +tf_cuda_cc_test( + name = "unified_api_test", + size = "small", + srcs = [ + "unified_api_test.cc", + ], + args = ["--heap_check=local"], + linkstatic = tf_kernel_tests_linkstatic(), + tags = tf_cuda_tests_tags(), + deps = [ + ":c_api_experimental", + ":c_api_unified_internal", + ":unified_api_testutil", + "//tensorflow/c:tf_status_helper", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/llvm_rtti", + "//tensorflow/core/platform:errors", + ], +) + cc_library( name = "gradients_util", srcs = [ @@ -477,8 +502,10 @@ cc_library( "//tensorflow:internal", ], deps = [ + "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:refcount", + "//tensorflow/core/platform:status", ], ) diff --git a/tensorflow/c/eager/abstract_tensor_handle.h b/tensorflow/c/eager/abstract_tensor_handle.h index 37e6d1bf29c..1ca4a9a8ecb 100644 --- a/tensorflow/c/eager/abstract_tensor_handle.h +++ b/tensorflow/c/eager/abstract_tensor_handle.h @@ -17,8 +17,10 @@ limitations under the License. #include +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/refcount.h" +#include "tensorflow/core/platform/status.h" namespace tensorflow { // Abstract interface to a Tensor handle in either tracing or immediate @@ -32,6 +34,9 @@ class AbstractTensorHandle : public core::RefCounted { public: // Returns tensor dtype. virtual tensorflow::DataType DataType() const = 0; + // Returns tensor shape. If tensor has unknown rank, shape remains untouched. + virtual tensorflow::Status Shape( + tensorflow::PartialTensorShape* shape) const = 0; AbstractTensorHandleKind getKind() const { return kind_; } diff --git a/tensorflow/c/eager/c_api_unified_experimental.cc b/tensorflow/c/eager/c_api_unified_experimental.cc index 2d290df19ce..f89d3e84cf4 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.cc +++ b/tensorflow/c/eager/c_api_unified_experimental.cc @@ -134,7 +134,9 @@ TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx, } TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, - TF_DataType dtype, TF_Status* s) { + TF_DataType dtype, TF_Shape shape, + TF_Status* s) { + DCHECK_GE(shape.num_dims, -1); TracingTensorHandle* t; TracingContext* tracing_ctx = dyn_cast(unwrap(func)); if (!tracing_ctx) { @@ -143,8 +145,20 @@ TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, "TF_AddFunctionParameter must be called on a TracingContext.")); return nullptr; } + tensorflow::PartialTensorShape partial_shape; + if (shape.num_dims != -1) { + DCHECK(shape.dim_sizes != nullptr); + Status status = tensorflow::PartialTensorShape::MakePartialShape( + reinterpret_cast(shape.dim_sizes), shape.num_dims, + &partial_shape); + if (!status.ok()) { + Set_TF_Status_from_Status(s, status); + return nullptr; + } + } Set_TF_Status_from_Status( - s, tracing_ctx->AddParameter(static_cast(dtype), &t)); + s, tracing_ctx->AddParameter(static_cast(dtype), partial_shape, + &t)); return wrap(t); } diff --git a/tensorflow/c/eager/c_api_unified_experimental.h b/tensorflow/c/eager/c_api_unified_experimental.h index d216b4e694b..ee22695632f 100644 --- a/tensorflow/c/eager/c_api_unified_experimental.h +++ b/tensorflow/c/eager/c_api_unified_experimental.h @@ -64,10 +64,16 @@ TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*, TF_Status* s); void TF_DeleteExecutionContext(TF_ExecutionContext*); +// Represents a (partially-defined) shape. +typedef struct TF_Shape { + int num_dims; // Must be >= -1; -1 represents unknown rank. + int64_t* dim_sizes; +} TF_Shape; + // Add a new parameter to a TensorFlow Function. -// TODO(aminim): what about shape? TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func, - TF_DataType dtype, TF_Status* s); + TF_DataType dtype, TF_Shape shape, + TF_Status* s); // Create an operation suitable to use with the provided context. The operation // requires its type (e.g. "AddV2") to be set independently. diff --git a/tensorflow/c/eager/c_api_unified_experimental_graph.cc b/tensorflow/c/eager/c_api_unified_experimental_graph.cc index 0e9d6c18157..b229abb0cb6 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_graph.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_graph.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" @@ -43,22 +45,50 @@ class GraphContext; class GraphOperation; class GraphTensor; +auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim; +auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank; + // GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index // into the list of outputs for the operation. class GraphTensor : public TracingTensorHandle { public: - explicit GraphTensor(TF_Output output) - : TracingTensorHandle(kGraph), output_(output) {} + explicit GraphTensor(TF_Output output, TF_Graph* graph) + : TracingTensorHandle(kGraph), output_(output), graph_(graph) {} tensorflow::DataType DataType() const override { return static_cast(TF_OperationOutputType(output_)); } + + tensorflow::Status Shape( + tensorflow::PartialTensorShape* shape) const override { + DCHECK(shape != nullptr); + TF_Status status; + int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status); + DCHECK_GE(num_dims, -1); + TF_RETURN_IF_ERROR(StatusFromTF_Status(&status)); + if (num_dims == kUnknownRank) { + return Status::OK(); + } + + std::vector dims(num_dims, kUnknownDim); + TF_GraphGetTensorShape(graph_, output_, + reinterpret_cast(dims.data()), num_dims, + &status); + TF_RETURN_IF_ERROR(StatusFromTF_Status(&status)); + TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape)); + + return Status::OK(); + } + TF_Output output_; // For LLVM style RTTI. static bool classof(const AbstractTensorHandle* ptr) { return ptr->getKind() == kGraph; } + + private: + TF_Graph* graph_; // For shape inference. }; // GraphOperation wraps and populates a TF_OperationDescription. @@ -135,7 +165,7 @@ class GraphOperation : public TracingOperation { TF_DeleteStatus(s); *num_retvals = TF_OperationNumOutputs(operation); for (int i = 0; i < *num_retvals; ++i) { - retvals[i] = new GraphTensor({operation, i}); + retvals[i] = new GraphTensor({operation, i}, g_); } return Status::OK(); } @@ -326,12 +356,18 @@ class GraphContext : public TracingContext { return new GraphOperation(graph_.get()); } - Status AddParameter(DataType dtype, TracingTensorHandle** output) override { + Status AddParameter(DataType dtype, const PartialTensorShape& shape, + TracingTensorHandle** output) override { TracingOperationPtr operation(CreateOperation()); TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr)); TF_RETURN_IF_ERROR( operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str())); TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype)); + if (!shape.unknown_rank()) { + TF_RETURN_IF_ERROR(operation->SetAttrShape( + "shape", reinterpret_cast(shape.dim_sizes().data()), + shape.dims())); + } int num_outputs = 1; std::vector outputs(num_outputs); TF_RETURN_IF_ERROR(operation->Execute( diff --git a/tensorflow/c/eager/c_api_unified_experimental_internal.h b/tensorflow/c/eager/c_api_unified_experimental_internal.h index 9433fe8f120..cd0d7610c7f 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_internal.h +++ b/tensorflow/c/eager/c_api_unified_experimental_internal.h @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/tf_datatype.h" #include "tensorflow/c/tf_status.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/platform/casts.h" #include "tensorflow/core/platform/types.h" @@ -107,7 +108,8 @@ class TracingContext : public AbstractContext { public: // Add a function parameter and return the corresponding tensor. - virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0; + virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape, + TracingTensorHandle**) = 0; // Finalize this context and make a function out of it. The context is in a // invalid state after this call and must be destroyed. diff --git a/tensorflow/c/eager/c_api_unified_experimental_test.cc b/tensorflow/c/eager/c_api_unified_experimental_test.cc index 432ddb4b2d4..71dcfc4dcd2 100644 --- a/tensorflow/c/eager/c_api_unified_experimental_test.cc +++ b/tensorflow/c/eager/c_api_unified_experimental_test.cc @@ -359,7 +359,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); auto* placeholder_t = - TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get()); + TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build an abstract operation. @@ -450,7 +450,7 @@ TEST_P(UnifiedCAPI, TestBasicGraphMatMul) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); auto* placeholder_t = - TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get()); + TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get()); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); // Build an abstract operation. @@ -553,9 +553,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) { TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); // Create a first "Add" computing `arg0 + arg1`. @@ -709,9 +709,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraphMatMul) { TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s); + auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); // Create a first "Add" computing `arg0 + arg1`. @@ -975,7 +975,7 @@ TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) { // Add a placeholder to the graph. auto placeholder_t = - TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get()); + TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get()); TF_AbstractTensorGetEagerTensor(placeholder_t, status.get()); ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get())); diff --git a/tensorflow/c/eager/gradients_util.cc b/tensorflow/c/eager/gradients_util.cc index e53faf4a3f3..72e523d7ba7 100644 --- a/tensorflow/c/eager/gradients_util.cc +++ b/tensorflow/c/eager/gradients_util.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/c/experimental/ops/nn_ops.h" #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" @@ -224,8 +225,10 @@ Status CreateParamsForInputs(AbstractContext* ctx, vector* params) { tracing::TracingTensorHandle* handle = nullptr; for (auto input : inputs) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(input->Shape(&shape)); TF_RETURN_IF_ERROR(dyn_cast(ctx)->AddParameter( - input->DataType(), &handle)); + input->DataType(), shape, &handle)); params->emplace_back(handle); } return Status::OK(); @@ -314,4 +317,4 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { } } // namespace gradients -} // namespace tensorflow \ No newline at end of file +} // namespace tensorflow diff --git a/tensorflow/c/eager/unified_api_test.cc b/tensorflow/c/eager/unified_api_test.cc new file mode 100644 index 00000000000..52de726b291 --- /dev/null +++ b/tensorflow/c/eager/unified_api_test.cc @@ -0,0 +1,205 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/unified_api_testutil.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +class UnifiedAPI + : public ::testing::TestWithParam> { + protected: + void SetUp() override { + TF_StatusPtr status(TF_NewStatus()); + TF_SetTracingImplementation(std::get<0>(GetParam()), status.get()); + Status s = StatusFromTF_Status(status.get()); + CHECK_EQ(errors::OK, s.code()) << s.error_message(); + } + + public: + bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; } + bool UseFunction() const { return std::get<2>(GetParam()); } +}; + +// Checks that inputs[0] is a scalar. +Status TestScalarShape(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape)); + if (shape.dims() != 0) { + return errors::InvalidArgument( + "Tensor expected to have scalar shape found rank: ", shape.dims()); + } + return Status::OK(); +} + +TEST_P(UnifiedAPI, TestTensorShapeScalar) { + if (UseFunction() && UseMlir()) { + // TODO(b/173074167): Remove this. + GTEST_SKIP() << "MlirTensor::Shape is not implemented yet."; + } + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + Status s = RunModel(TestScalarShape, ctx.get(), + /*inputs=*/{x.get()}, + /*outputs=*/{}, + /*use_function=*/UseFunction()); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); +} + +// Checks that inputs[0] is a matrix with shape 2x4. +Status TestTensorShape2x4(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape)); + if (shape.dims() != 2) { + return errors::InvalidArgument( + "Tensor expected to have rank 2 found rank: ", shape.dims()); + } + int64 dim_sizes[] = {2, 4}; + for (int i = 0; i < shape.dims(); i++) { + if (shape.dim_size(i) != dim_sizes[i]) { + return errors::InvalidArgument("Dim ", i, " expected to be of size ", + dim_sizes[i], + " found: ", shape.dim_size(i)); + } + } + return Status::OK(); +} + +TEST_P(UnifiedAPI, TestTensorShape2x4) { + if (UseFunction() && UseMlir()) { + // TODO(b/173074167): Remove this. + GTEST_SKIP() << "MlirTensor::Shape is not implemented yet."; + } + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + float data[] = {0., 0., 0., 0., 0., 0., 0., 0}; + int64 dim_sizes[] = {2, 4}; + Status s = + TestTensorHandleWithDimsFloat(ctx.get(), data, dim_sizes, 2, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + Status s = RunModel(TestTensorShape2x4, ctx.get(), + /*inputs=*/{x.get()}, + /*outputs=*/{}, + /*use_function=*/UseFunction()); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); +} + +TEST_P(UnifiedAPI, TestUnknownShapeTracing) { + if (!UseFunction()) { + GTEST_SKIP() << "Tracing only test."; + } + if (UseMlir()) { + // TODO(b/173074167): Remove this. + GTEST_SKIP() << "MlirTensor::Shape is not implemented yet."; + } + AbstractContextPtr ctx(BuildFunction("test_fn")); + AbstractTensorHandlePtr x; + { + tracing::TracingTensorHandle* x_raw = nullptr; + PartialTensorShape shape; + Status s = dyn_cast(ctx.get())->AddParameter( + DT_FLOAT, shape, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + PartialTensorShape shape; + Status s = x->Shape(&shape); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_TRUE(shape.unknown_rank()); +} + +TEST_P(UnifiedAPI, TestPartialShapeTracing) { + if (!UseFunction()) { + GTEST_SKIP() << "Tracing only test."; + } + if (UseMlir()) { + GTEST_SKIP() << "MlirTensor::Shape is not implemented yet."; + } + AbstractContextPtr ctx(BuildFunction("test_fn")); + AbstractTensorHandlePtr x; + { + tracing::TracingTensorHandle* x_raw = nullptr; + PartialTensorShape shape; + int64 dim_sizes[] = {2, -1}; + Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + s = dyn_cast(ctx.get())->AddParameter( + DT_FLOAT, shape, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + PartialTensorShape shape; + Status s = x->Shape(&shape); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_FALSE(shape.unknown_rank()); + + ASSERT_EQ(2, shape.dim_size(0)); + ASSERT_EQ(-1, shape.dim_size(1)); +} + +#ifdef PLATFORM_GOOGLE +INSTANTIATE_TEST_SUITE_P( + UnifiedCppAPI, UnifiedAPI, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(true, false), + /*use_function*/ ::testing::Values(true, false))); +#else +INSTANTIATE_TEST_SUITE_P( + UnifiedCppAPI, UnifiedAPI, + ::testing::Combine(::testing::Values("graphdef", "mlir"), + /*tfrt*/ ::testing::Values(false), + /*use_function*/ ::testing::Values(true, false))); +#endif +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/eager/unified_api_testutil.cc b/tensorflow/c/eager/unified_api_testutil.cc index 5b20b01e42d..9e8683df0e7 100644 --- a/tensorflow/c/eager/unified_api_testutil.cc +++ b/tensorflow/c/eager/unified_api_testutil.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" @@ -38,8 +39,10 @@ Status CreateParamsForInputs(AbstractContext* ctx, std::vector* params) { tracing::TracingTensorHandle* handle = nullptr; for (auto input : inputs) { + PartialTensorShape shape; + TF_RETURN_IF_ERROR(input->Shape(&shape)); TF_RETURN_IF_ERROR(dyn_cast(ctx)->AddParameter( - input->DataType(), &handle)); + input->DataType(), shape, &handle)); params->emplace_back(handle); } return Status::OK(); diff --git a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc index 32c51f2e2bd..d2cc3069c18 100644 --- a/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc +++ b/tensorflow/compiler/mlir/tensorflow/c/c_api_unified_experimental_mlir.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/errors.h" @@ -102,6 +103,13 @@ class MlirTensor : public TracingTensorHandle { return type; } + tensorflow::Status Shape( + tensorflow::PartialTensorShape* shape) const override { + // TODO(b/173074167): Implement this and enable tests in + // unified_api_test.cc. + return Unimplemented("MlirTensor::Shape is not implemented yet."); + } + Value getValue() { return value_; } Type getElementType() { return value_.getType().cast().getElementType(); @@ -250,6 +258,7 @@ class MlirFunctionContext : public TracingContext { return new MlirAbstractOp(context_.get(), this); } Status AddParameter(tensorflow::DataType dtype, + const tensorflow::PartialTensorShape& shape, TracingTensorHandle** handle) override; Status Finalize(OutputList* outputs, AbstractFunction** f) override; @@ -547,8 +556,11 @@ Operation* MlirFunctionContext::CreateOperationFromState( return builder_.createOperation(state); } -Status MlirFunctionContext::AddParameter(tensorflow::DataType dtype, - TracingTensorHandle** handle) { +Status MlirFunctionContext::AddParameter( + tensorflow::DataType dtype, const tensorflow::PartialTensorShape& shape, + TracingTensorHandle** handle) { + // TODO(b/173073199): Use shape. Enable tests in unified_api_test.cc once + // resolved. Type type; TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder_, &type)); *handle = new MlirTensor(func_.getBody().front().addArgument(type)); diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index da37ad1b480..ca13190748c 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -633,6 +633,25 @@ Status TensorHandle::CopyInferenceShape(TensorHandle* other) { return Status::OK(); } +Status TensorHandle::Shape(tensorflow::PartialTensorShape* shape) const { + DCHECK(shape != nullptr); + if (!IsReady() && !inference_shape_.unknown_rank()) { + *shape = inference_shape_; + return Status::OK(); + } else { + auto result = absl::visit( + [](auto& data) { + TensorShape shape; + Status s = data.Shape(&shape); + return std::make_pair(shape, s); + }, + data_); + TF_RETURN_IF_ERROR(result.second); + *shape = result.first; + } + return Status::OK(); +} + Status TensorHandle::NumDims(int* num_dims) const { DCHECK(num_dims != nullptr); if (!IsReady() && !inference_shape_.unknown_rank()) { diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index b2bb24f5bc0..396af4166c7 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -125,6 +125,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle { void Release() override; tensorflow::DataType DataType() const override; + Status Shape(tensorflow::PartialTensorShape* shape) const override; Status NumDims(int* num_dims) const override; Status NumElements(int64* num_elements) const override; Status Dim(int dim_index, int64* dim) const override; diff --git a/tensorflow/python/framework/experimental/unified_api.cc b/tensorflow/python/framework/experimental/unified_api.cc index 96bf2232a1e..f12353e3700 100644 --- a/tensorflow/python/framework/experimental/unified_api.cc +++ b/tensorflow/python/framework/experimental/unified_api.cc @@ -30,6 +30,7 @@ limitations under the License. #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/eager/tfe_context_internal.h" #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" +#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" @@ -132,7 +133,9 @@ PYBIND11_MODULE(_unified_api, m) { .def("AddParameter", [](TracingContext* self, DataType dtype) { TracingTensorHandle* handle = nullptr; - Status s = self->AddParameter(dtype, &handle); + // TODO(srbs): Add shape argument to this function. + tensorflow::PartialTensorShape shape; + Status s = self->AddParameter(dtype, shape, &handle); MaybeRaiseRegisteredFromStatus(s); return static_cast(handle); })