Implements AbstractTensor::Shape(PartialTensorShape*).

Adds shape argument to `TracingContext::AddParameter`.

PiperOrigin-RevId: 342336504
Change-Id: I3d20c6992f03290866a01bb625d9879eded258ae
This commit is contained in:
Saurabh Saxena 2020-11-13 14:15:50 -08:00 committed by TensorFlower Gardener
parent 242e920d7f
commit 0dd94c4ad3
14 changed files with 358 additions and 22 deletions

View File

@ -176,6 +176,7 @@ cc_library(
"//tensorflow/c:c_api_internal",
"//tensorflow/c:conversion_macros",
"//tensorflow/c:tf_status",
"//tensorflow/core:framework",
"//tensorflow/core/platform:casts",
"//tensorflow/core/platform:types",
],
@ -244,6 +245,7 @@ cc_library(
":c_api_unified_internal",
"//tensorflow/c:tf_status",
"//tensorflow/c:tf_status_helper",
"//tensorflow/core:framework",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",
@ -288,6 +290,29 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "unified_api_test",
size = "small",
srcs = [
"unified_api_test.cc",
],
args = ["--heap_check=local"],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags(),
deps = [
":c_api_experimental",
":c_api_unified_internal",
":unified_api_testutil",
"//tensorflow/c:tf_status_helper",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:errors",
],
)
cc_library(
name = "gradients_util",
srcs = [
@ -477,8 +502,10 @@ cc_library(
"//tensorflow:internal",
],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:refcount",
"//tensorflow/core/platform:status",
],
)

View File

@ -17,8 +17,10 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
// Abstract interface to a Tensor handle in either tracing or immediate
@ -32,6 +34,9 @@ class AbstractTensorHandle : public core::RefCounted {
public:
// Returns tensor dtype.
virtual tensorflow::DataType DataType() const = 0;
// Returns tensor shape. If tensor has unknown rank, shape remains untouched.
virtual tensorflow::Status Shape(
tensorflow::PartialTensorShape* shape) const = 0;
AbstractTensorHandleKind getKind() const { return kind_; }

View File

@ -134,7 +134,9 @@ TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
}
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s) {
TF_DataType dtype, TF_Shape shape,
TF_Status* s) {
DCHECK_GE(shape.num_dims, -1);
TracingTensorHandle* t;
TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
if (!tracing_ctx) {
@ -143,8 +145,20 @@ TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
"TF_AddFunctionParameter must be called on a TracingContext."));
return nullptr;
}
tensorflow::PartialTensorShape partial_shape;
if (shape.num_dims != -1) {
DCHECK(shape.dim_sizes != nullptr);
Status status = tensorflow::PartialTensorShape::MakePartialShape(
reinterpret_cast<tensorflow::int64*>(shape.dim_sizes), shape.num_dims,
&partial_shape);
if (!status.ok()) {
Set_TF_Status_from_Status(s, status);
return nullptr;
}
}
Set_TF_Status_from_Status(
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), &t));
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), partial_shape,
&t));
return wrap(t);
}

View File

@ -64,10 +64,16 @@ TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
TF_Status* s);
void TF_DeleteExecutionContext(TF_ExecutionContext*);
// Represents a (partially-defined) shape.
typedef struct TF_Shape {
int num_dims; // Must be >= -1; -1 represents unknown rank.
int64_t* dim_sizes;
} TF_Shape;
// Add a new parameter to a TensorFlow Function.
// TODO(aminim): what about shape?
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
TF_DataType dtype, TF_Status* s);
TF_DataType dtype, TF_Shape shape,
TF_Status* s);
// Create an operation suitable to use with the provided context. The operation
// requires its type (e.g. "AddV2") to be set independently.

View File

@ -25,6 +25,8 @@ limitations under the License.
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
@ -43,22 +45,50 @@ class GraphContext;
class GraphOperation;
class GraphTensor;
auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim;
auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank;
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
// into the list of outputs for the operation.
class GraphTensor : public TracingTensorHandle {
public:
explicit GraphTensor(TF_Output output)
: TracingTensorHandle(kGraph), output_(output) {}
explicit GraphTensor(TF_Output output, TF_Graph* graph)
: TracingTensorHandle(kGraph), output_(output), graph_(graph) {}
tensorflow::DataType DataType() const override {
return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
}
tensorflow::Status Shape(
tensorflow::PartialTensorShape* shape) const override {
DCHECK(shape != nullptr);
TF_Status status;
int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status);
DCHECK_GE(num_dims, -1);
TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
if (num_dims == kUnknownRank) {
return Status::OK();
}
std::vector<int64> dims(num_dims, kUnknownDim);
TF_GraphGetTensorShape(graph_, output_,
reinterpret_cast<int64_t*>(dims.data()), num_dims,
&status);
TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
return Status::OK();
}
TF_Output output_;
// For LLVM style RTTI.
static bool classof(const AbstractTensorHandle* ptr) {
return ptr->getKind() == kGraph;
}
private:
TF_Graph* graph_; // For shape inference.
};
// GraphOperation wraps and populates a TF_OperationDescription.
@ -135,7 +165,7 @@ class GraphOperation : public TracingOperation {
TF_DeleteStatus(s);
*num_retvals = TF_OperationNumOutputs(operation);
for (int i = 0; i < *num_retvals; ++i) {
retvals[i] = new GraphTensor({operation, i});
retvals[i] = new GraphTensor({operation, i}, g_);
}
return Status::OK();
}
@ -326,12 +356,18 @@ class GraphContext : public TracingContext {
return new GraphOperation(graph_.get());
}
Status AddParameter(DataType dtype, TracingTensorHandle** output) override {
Status AddParameter(DataType dtype, const PartialTensorShape& shape,
TracingTensorHandle** output) override {
TracingOperationPtr operation(CreateOperation());
TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
TF_RETURN_IF_ERROR(
operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
if (!shape.unknown_rank()) {
TF_RETURN_IF_ERROR(operation->SetAttrShape(
"shape", reinterpret_cast<int64_t*>(shape.dim_sizes().data()),
shape.dims()));
}
int num_outputs = 1;
std::vector<AbstractTensorHandle*> outputs(num_outputs);
TF_RETURN_IF_ERROR(operation->Execute(

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/types.h"
@ -107,7 +108,8 @@ class TracingContext : public AbstractContext {
public:
// Add a function parameter and return the corresponding tensor.
virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0;
virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape,
TracingTensorHandle**) = 0;
// Finalize this context and make a function out of it. The context is in a
// invalid state after this call and must be destroyed.

View File

@ -359,7 +359,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
auto* placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
@ -450,7 +450,7 @@ TEST_P(UnifiedCAPI, TestBasicGraphMatMul) {
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
auto* placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
// Build an abstract operation.
@ -553,9 +553,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Create a first "Add" computing `arg0 + arg1`.
@ -709,9 +709,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraphMatMul) {
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
// Create a first "Add" computing `arg0 + arg1`.
@ -975,7 +975,7 @@ TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) {
// Add a placeholder to the graph.
auto placeholder_t =
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
TF_AbstractTensorGetEagerTensor(placeholder_t, status.get());
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));

View File

@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
@ -224,8 +225,10 @@ Status CreateParamsForInputs(AbstractContext* ctx,
vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(input->Shape(&shape));
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), &handle));
input->DataType(), shape, &handle));
params->emplace_back(handle);
}
return Status::OK();
@ -314,4 +317,4 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
}
} // namespace gradients
} // namespace tensorflow
} // namespace tensorflow

View File

@ -0,0 +1,205 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace {
class UnifiedAPI
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
}
public:
bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
bool UseFunction() const { return std::get<2>(GetParam()); }
};
// Checks that inputs[0] is a scalar.
Status TestScalarShape(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
if (shape.dims() != 0) {
return errors::InvalidArgument(
"Tensor expected to have scalar shape found rank: ", shape.dims());
}
return Status::OK();
}
TEST_P(UnifiedAPI, TestTensorShapeScalar) {
if (UseFunction() && UseMlir()) {
// TODO(b/173074167): Remove this.
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
}
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
Status s = RunModel(TestScalarShape, ctx.get(),
/*inputs=*/{x.get()},
/*outputs=*/{},
/*use_function=*/UseFunction());
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
// Checks that inputs[0] is a matrix with shape 2x4.
Status TestTensorShape2x4(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
if (shape.dims() != 2) {
return errors::InvalidArgument(
"Tensor expected to have rank 2 found rank: ", shape.dims());
}
int64 dim_sizes[] = {2, 4};
for (int i = 0; i < shape.dims(); i++) {
if (shape.dim_size(i) != dim_sizes[i]) {
return errors::InvalidArgument("Dim ", i, " expected to be of size ",
dim_sizes[i],
" found: ", shape.dim_size(i));
}
}
return Status::OK();
}
TEST_P(UnifiedAPI, TestTensorShape2x4) {
if (UseFunction() && UseMlir()) {
// TODO(b/173074167): Remove this.
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
}
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
float data[] = {0., 0., 0., 0., 0., 0., 0., 0};
int64 dim_sizes[] = {2, 4};
Status s =
TestTensorHandleWithDimsFloat(ctx.get(), data, dim_sizes, 2, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
Status s = RunModel(TestTensorShape2x4, ctx.get(),
/*inputs=*/{x.get()},
/*outputs=*/{},
/*use_function=*/UseFunction());
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
TEST_P(UnifiedAPI, TestUnknownShapeTracing) {
if (!UseFunction()) {
GTEST_SKIP() << "Tracing only test.";
}
if (UseMlir()) {
// TODO(b/173074167): Remove this.
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
}
AbstractContextPtr ctx(BuildFunction("test_fn"));
AbstractTensorHandlePtr x;
{
tracing::TracingTensorHandle* x_raw = nullptr;
PartialTensorShape shape;
Status s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
DT_FLOAT, shape, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
PartialTensorShape shape;
Status s = x->Shape(&shape);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ASSERT_TRUE(shape.unknown_rank());
}
TEST_P(UnifiedAPI, TestPartialShapeTracing) {
if (!UseFunction()) {
GTEST_SKIP() << "Tracing only test.";
}
if (UseMlir()) {
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
}
AbstractContextPtr ctx(BuildFunction("test_fn"));
AbstractTensorHandlePtr x;
{
tracing::TracingTensorHandle* x_raw = nullptr;
PartialTensorShape shape;
int64 dim_sizes[] = {2, -1};
Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
DT_FLOAT, shape, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
PartialTensorShape shape;
Status s = x->Shape(&shape);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ASSERT_FALSE(shape.unknown_rank());
ASSERT_EQ(2, shape.dim_size(0));
ASSERT_EQ(-1, shape.dim_size(1));
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCppAPI, UnifiedAPI,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(true, false),
/*use_function*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCppAPI, UnifiedAPI,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*use_function*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace tensorflow

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_status.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
@ -38,8 +39,10 @@ Status CreateParamsForInputs(AbstractContext* ctx,
std::vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(input->Shape(&shape));
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), &handle));
input->DataType(), shape, &handle));
params->emplace_back(handle);
}
return Status::OK();

View File

@ -54,6 +54,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
@ -102,6 +103,13 @@ class MlirTensor : public TracingTensorHandle {
return type;
}
tensorflow::Status Shape(
tensorflow::PartialTensorShape* shape) const override {
// TODO(b/173074167): Implement this and enable tests in
// unified_api_test.cc.
return Unimplemented("MlirTensor::Shape is not implemented yet.");
}
Value getValue() { return value_; }
Type getElementType() {
return value_.getType().cast<ShapedType>().getElementType();
@ -250,6 +258,7 @@ class MlirFunctionContext : public TracingContext {
return new MlirAbstractOp(context_.get(), this);
}
Status AddParameter(tensorflow::DataType dtype,
const tensorflow::PartialTensorShape& shape,
TracingTensorHandle** handle) override;
Status Finalize(OutputList* outputs, AbstractFunction** f) override;
@ -547,8 +556,11 @@ Operation* MlirFunctionContext::CreateOperationFromState(
return builder_.createOperation(state);
}
Status MlirFunctionContext::AddParameter(tensorflow::DataType dtype,
TracingTensorHandle** handle) {
Status MlirFunctionContext::AddParameter(
tensorflow::DataType dtype, const tensorflow::PartialTensorShape& shape,
TracingTensorHandle** handle) {
// TODO(b/173073199): Use shape. Enable tests in unified_api_test.cc once
// resolved.
Type type;
TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder_, &type));
*handle = new MlirTensor(func_.getBody().front().addArgument(type));

View File

@ -633,6 +633,25 @@ Status TensorHandle::CopyInferenceShape(TensorHandle* other) {
return Status::OK();
}
Status TensorHandle::Shape(tensorflow::PartialTensorShape* shape) const {
DCHECK(shape != nullptr);
if (!IsReady() && !inference_shape_.unknown_rank()) {
*shape = inference_shape_;
return Status::OK();
} else {
auto result = absl::visit(
[](auto& data) {
TensorShape shape;
Status s = data.Shape(&shape);
return std::make_pair(shape, s);
},
data_);
TF_RETURN_IF_ERROR(result.second);
*shape = result.first;
}
return Status::OK();
}
Status TensorHandle::NumDims(int* num_dims) const {
DCHECK(num_dims != nullptr);
if (!IsReady() && !inference_shape_.unknown_rank()) {

View File

@ -125,6 +125,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
void Release() override;
tensorflow::DataType DataType() const override;
Status Shape(tensorflow::PartialTensorShape* shape) const override;
Status NumDims(int* num_dims) const override;
Status NumElements(int64* num_elements) const override;
Status Dim(int dim_index, int64* dim) const override;

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
#include "tensorflow/c/eager/tfe_context_internal.h"
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
@ -132,7 +133,9 @@ PYBIND11_MODULE(_unified_api, m) {
.def("AddParameter",
[](TracingContext* self, DataType dtype) {
TracingTensorHandle* handle = nullptr;
Status s = self->AddParameter(dtype, &handle);
// TODO(srbs): Add shape argument to this function.
tensorflow::PartialTensorShape shape;
Status s = self->AddParameter(dtype, shape, &handle);
MaybeRaiseRegisteredFromStatus(s);
return static_cast<AbstractTensorHandle*>(handle);
})