Implements AbstractTensorHandle::Shape(PartialTensorShape*).

Adds shape argument to `TracingContext::AddParameter`.

PiperOrigin-RevId: 342336504
Change-Id: I3d20c6992f03290866a01bb625d9879eded258ae
parent 242e920d7f
commit 0dd94c4ad3
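For context, a minimal sketch (not part of this commit) of how a caller drives the updated C API. The helper name ExampleAddShapedParameter and the omission of per-call error checks are illustrative only; the entry points it uses (TF_SetTracingImplementation, TF_CreateFunction, TF_AddFunctionParameter, and the new TF_Shape struct) are the ones introduced or exercised by this change, as the updated tests below show.

// Sketch: tracing a single float parameter with a partially known shape
// through the new TF_Shape argument. Error checks elided for brevity.
#include <cstdint>

#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/c/tf_status.h"

void ExampleAddShapedParameter() {
  TF_Status* s = TF_NewStatus();
  TF_SetTracingImplementation("graphdef", s);
  TF_ExecutionContext* graph_ctx = TF_CreateFunction("example_fn", s);

  int64_t dims[] = {2, -1};  // A 2 x ? matrix; -1 marks an unknown dimension.
  TF_Shape shape = {/*num_dims=*/2, /*dim_sizes=*/dims};
  // Passing {-1, nullptr} instead declares an unknown rank, as the updated
  // tests below do.
  TF_AbstractTensor* arg =
      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, shape, s);

  TF_DeleteAbstractTensor(arg);
  TF_DeleteExecutionContext(graph_ctx);
  TF_DeleteStatus(s);
}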
@@ -176,6 +176,7 @@ cc_library(
         "//tensorflow/c:c_api_internal",
         "//tensorflow/c:conversion_macros",
         "//tensorflow/c:tf_status",
+        "//tensorflow/core:framework",
         "//tensorflow/core/platform:casts",
         "//tensorflow/core/platform:types",
     ],
@@ -244,6 +245,7 @@ cc_library(
         ":c_api_unified_internal",
         "//tensorflow/c:tf_status",
         "//tensorflow/c:tf_status_helper",
+        "//tensorflow/core:framework",
         "//tensorflow/core/lib/llvm_rtti",
         "//tensorflow/core/platform:errors",
         "//tensorflow/core/platform:status",
@@ -288,6 +290,29 @@ tf_cuda_cc_test(
     ],
 )
 
+tf_cuda_cc_test(
+    name = "unified_api_test",
+    size = "small",
+    srcs = [
+        "unified_api_test.cc",
+    ],
+    args = ["--heap_check=local"],
+    linkstatic = tf_kernel_tests_linkstatic(),
+    tags = tf_cuda_tests_tags(),
+    deps = [
+        ":c_api_experimental",
+        ":c_api_unified_internal",
+        ":unified_api_testutil",
+        "//tensorflow/c:tf_status_helper",
+        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core/lib/llvm_rtti",
+        "//tensorflow/core/platform:errors",
+    ],
+)
+
 cc_library(
     name = "gradients_util",
     srcs = [
@@ -477,8 +502,10 @@ cc_library(
         "//tensorflow:internal",
     ],
     deps = [
+        "//tensorflow/core:framework",
+        "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/platform:refcount",
         "//tensorflow/core/platform:status",
     ],
 )
 
@@ -17,8 +17,10 @@ limitations under the License.
 
 #include <memory>
 
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/platform/refcount.h"
 #include "tensorflow/core/platform/status.h"
 namespace tensorflow {
 
 // Abstract interface to a Tensor handle in either tracing or immediate
@@ -32,6 +34,9 @@ class AbstractTensorHandle : public core::RefCounted {
  public:
   // Returns tensor dtype.
   virtual tensorflow::DataType DataType() const = 0;
+  // Returns tensor shape. If tensor has unknown rank, shape remains untouched.
+  virtual tensorflow::Status Shape(
+      tensorflow::PartialTensorShape* shape) const = 0;
 
   AbstractTensorHandleKind getKind() const { return kind_; }
 
@@ -134,7 +134,9 @@ TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
 }
 
 TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
-                                           TF_DataType dtype, TF_Status* s) {
+                                           TF_DataType dtype, TF_Shape shape,
+                                           TF_Status* s) {
+  DCHECK_GE(shape.num_dims, -1);
   TracingTensorHandle* t;
   TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
   if (!tracing_ctx) {
@@ -143,8 +145,20 @@ TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
         "TF_AddFunctionParameter must be called on a TracingContext."));
     return nullptr;
   }
+  tensorflow::PartialTensorShape partial_shape;
+  if (shape.num_dims != -1) {
+    DCHECK(shape.dim_sizes != nullptr);
+    Status status = tensorflow::PartialTensorShape::MakePartialShape(
+        reinterpret_cast<tensorflow::int64*>(shape.dim_sizes), shape.num_dims,
+        &partial_shape);
+    if (!status.ok()) {
+      Set_TF_Status_from_Status(s, status);
+      return nullptr;
+    }
+  }
   Set_TF_Status_from_Status(
-      s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), &t));
+      s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), partial_shape,
+                                   &t));
   return wrap(t);
 }
 
@@ -64,10 +64,16 @@ TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
                                                  TF_Status* s);
 void TF_DeleteExecutionContext(TF_ExecutionContext*);
 
+// Represents a (partially-defined) shape.
+typedef struct TF_Shape {
+  int num_dims;  // Must be >= -1; -1 represents unknown rank.
+  int64_t* dim_sizes;
+} TF_Shape;
+
 // Add a new parameter to a TensorFlow Function.
-// TODO(aminim): what about shape?
 TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
-                                           TF_DataType dtype, TF_Status* s);
+                                           TF_DataType dtype, TF_Shape shape,
+                                           TF_Status* s);
 
 // Create an operation suitable to use with the provided context. The operation
 // requires its type (e.g. "AddV2") to be set independently.
@@ -25,6 +25,8 @@ limitations under the License.
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/errors.h"
@@ -43,22 +45,50 @@ class GraphContext;
 class GraphOperation;
 class GraphTensor;
 
+auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim;
+auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank;
+
 // GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
 // into the list of outputs for the operation.
 class GraphTensor : public TracingTensorHandle {
  public:
-  explicit GraphTensor(TF_Output output)
-      : TracingTensorHandle(kGraph), output_(output) {}
+  explicit GraphTensor(TF_Output output, TF_Graph* graph)
+      : TracingTensorHandle(kGraph), output_(output), graph_(graph) {}
 
   tensorflow::DataType DataType() const override {
     return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
   }
 
+  tensorflow::Status Shape(
+      tensorflow::PartialTensorShape* shape) const override {
+    DCHECK(shape != nullptr);
+    TF_Status status;
+    int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status);
+    DCHECK_GE(num_dims, -1);
+    TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
+    if (num_dims == kUnknownRank) {
+      return Status::OK();
+    }
+
+    std::vector<int64> dims(num_dims, kUnknownDim);
+    TF_GraphGetTensorShape(graph_, output_,
+                           reinterpret_cast<int64_t*>(dims.data()), num_dims,
+                           &status);
+    TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
+    TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
+
+    return Status::OK();
+  }
+
   TF_Output output_;
 
   // For LLVM style RTTI.
   static bool classof(const AbstractTensorHandle* ptr) {
     return ptr->getKind() == kGraph;
   }
 
+ private:
+  TF_Graph* graph_;  // For shape inference.
 };
 
 // GraphOperation wraps and populates a TF_OperationDescription.
@@ -135,7 +165,7 @@ class GraphOperation : public TracingOperation {
     TF_DeleteStatus(s);
     *num_retvals = TF_OperationNumOutputs(operation);
     for (int i = 0; i < *num_retvals; ++i) {
-      retvals[i] = new GraphTensor({operation, i});
+      retvals[i] = new GraphTensor({operation, i}, g_);
     }
     return Status::OK();
   }
@@ -326,12 +356,18 @@ class GraphContext : public TracingContext {
     return new GraphOperation(graph_.get());
   }
 
-  Status AddParameter(DataType dtype, TracingTensorHandle** output) override {
+  Status AddParameter(DataType dtype, const PartialTensorShape& shape,
+                      TracingTensorHandle** output) override {
     TracingOperationPtr operation(CreateOperation());
     TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
     TF_RETURN_IF_ERROR(
         operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
     TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
+    if (!shape.unknown_rank()) {
+      TF_RETURN_IF_ERROR(operation->SetAttrShape(
+          "shape", reinterpret_cast<int64_t*>(shape.dim_sizes().data()),
+          shape.dims()));
+    }
     int num_outputs = 1;
     std::vector<AbstractTensorHandle*> outputs(num_outputs);
     TF_RETURN_IF_ERROR(operation->Execute(
@@ -26,6 +26,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_unified_experimental.h"
 #include "tensorflow/c/tf_datatype.h"
 #include "tensorflow/c/tf_status.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/platform/casts.h"
 #include "tensorflow/core/platform/types.h"
 
@@ -107,7 +108,8 @@ class TracingContext : public AbstractContext {
 
  public:
   // Add a function parameter and return the corresponding tensor.
-  virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0;
+  virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape,
+                              TracingTensorHandle**) = 0;
 
   // Finalize this context and make a function out of it. The context is in a
   // invalid state after this call and must be destroyed.
@@ -359,7 +359,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
   auto* placeholder_t =
-      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
+      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
   // Build an abstract operation.
@@ -450,7 +450,7 @@ TEST_P(UnifiedCAPI, TestBasicGraphMatMul) {
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
   auto* placeholder_t =
-      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
+      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
   ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
 
   // Build an abstract operation.
@@ -553,9 +553,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
   TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
 
-  auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
+  auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
-  auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
+  auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
 
   // Create a first "Add" computing `arg0 + arg1`.
@@ -709,9 +709,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraphMatMul) {
   TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
 
-  auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
+  auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
-  auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
+  auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
 
   // Create a first "Add" computing `arg0 + arg1`.
@@ -975,7 +975,7 @@ TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) {
 
   // Add a placeholder to the graph.
   auto placeholder_t =
-      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
+      TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
   TF_AbstractTensorGetEagerTensor(placeholder_t, status.get());
   ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
 
@@ -28,6 +28,7 @@ limitations under the License.
 #include "tensorflow/c/experimental/ops/nn_ops.h"
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/c/tf_tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/errors.h"
 
@@ -224,8 +225,10 @@ Status CreateParamsForInputs(AbstractContext* ctx,
                              vector<AbstractTensorHandle*>* params) {
   tracing::TracingTensorHandle* handle = nullptr;
   for (auto input : inputs) {
+    PartialTensorShape shape;
+    TF_RETURN_IF_ERROR(input->Shape(&shape));
     TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
-        input->DataType(), &handle));
+        input->DataType(), shape, &handle));
     params->emplace_back(handle);
   }
   return Status::OK();
@@ -314,4 +317,4 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
 }
 
 }  // namespace gradients
-}  // namespace tensorflow
+}  // namespace tensorflow
tensorflow/c/eager/unified_api_test.cc (new file, 205 lines)
@@ -0,0 +1,205 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {
class UnifiedAPI
    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
 protected:
  void SetUp() override {
    TF_StatusPtr status(TF_NewStatus());
    TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
    Status s = StatusFromTF_Status(status.get());
    CHECK_EQ(errors::OK, s.code()) << s.error_message();
  }

 public:
  bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
  bool UseFunction() const { return std::get<2>(GetParam()); }
};

// Checks that inputs[0] is a scalar.
Status TestScalarShape(AbstractContext* ctx,
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs) {
  PartialTensorShape shape;
  TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
  if (shape.dims() != 0) {
    return errors::InvalidArgument(
        "Tensor expected to have scalar shape found rank: ", shape.dims());
  }
  return Status::OK();
}

TEST_P(UnifiedAPI, TestTensorShapeScalar) {
  if (UseFunction() && UseMlir()) {
    // TODO(b/173074167): Remove this.
    GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
  }
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  Status s = RunModel(TestScalarShape, ctx.get(),
                      /*inputs=*/{x.get()},
                      /*outputs=*/{},
                      /*use_function=*/UseFunction());
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}

// Checks that inputs[0] is a matrix with shape 2x4.
Status TestTensorShape2x4(AbstractContext* ctx,
                          absl::Span<AbstractTensorHandle* const> inputs,
                          absl::Span<AbstractTensorHandle*> outputs) {
  PartialTensorShape shape;
  TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
  if (shape.dims() != 2) {
    return errors::InvalidArgument(
        "Tensor expected to have rank 2 found rank: ", shape.dims());
  }
  int64 dim_sizes[] = {2, 4};
  for (int i = 0; i < shape.dims(); i++) {
    if (shape.dim_size(i) != dim_sizes[i]) {
      return errors::InvalidArgument("Dim ", i, " expected to be of size ",
                                     dim_sizes[i],
                                     " found: ", shape.dim_size(i));
    }
  }
  return Status::OK();
}

TEST_P(UnifiedAPI, TestTensorShape2x4) {
  if (UseFunction() && UseMlir()) {
    // TODO(b/173074167): Remove this.
    GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
  }
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    float data[] = {0., 0., 0., 0., 0., 0., 0., 0};
    int64 dim_sizes[] = {2, 4};
    Status s =
        TestTensorHandleWithDimsFloat(ctx.get(), data, dim_sizes, 2, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  Status s = RunModel(TestTensorShape2x4, ctx.get(),
                      /*inputs=*/{x.get()},
                      /*outputs=*/{},
                      /*use_function=*/UseFunction());
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}

TEST_P(UnifiedAPI, TestUnknownShapeTracing) {
  if (!UseFunction()) {
    GTEST_SKIP() << "Tracing only test.";
  }
  if (UseMlir()) {
    // TODO(b/173074167): Remove this.
    GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
  }
  AbstractContextPtr ctx(BuildFunction("test_fn"));
  AbstractTensorHandlePtr x;
  {
    tracing::TracingTensorHandle* x_raw = nullptr;
    PartialTensorShape shape;
    Status s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
        DT_FLOAT, shape, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  PartialTensorShape shape;
  Status s = x->Shape(&shape);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  ASSERT_TRUE(shape.unknown_rank());
}

TEST_P(UnifiedAPI, TestPartialShapeTracing) {
  if (!UseFunction()) {
    GTEST_SKIP() << "Tracing only test.";
  }
  if (UseMlir()) {
    GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
  }
  AbstractContextPtr ctx(BuildFunction("test_fn"));
  AbstractTensorHandlePtr x;
  {
    tracing::TracingTensorHandle* x_raw = nullptr;
    PartialTensorShape shape;
    int64 dim_sizes[] = {2, -1};
    Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
        DT_FLOAT, shape, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  PartialTensorShape shape;
  Status s = x->Shape(&shape);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  ASSERT_FALSE(shape.unknown_rank());

  ASSERT_EQ(2, shape.dim_size(0));
  ASSERT_EQ(-1, shape.dim_size(1));
}

#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCppAPI, UnifiedAPI,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(true, false),
                       /*use_function*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
    UnifiedCppAPI, UnifiedAPI,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*use_function*/ ::testing::Values(true, false)));
#endif
}  // namespace
}  // namespace tensorflow
@@ -21,6 +21,7 @@ limitations under the License.
 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
 #include "tensorflow/c/tf_status.h"
 #include "tensorflow/c/tf_status_helper.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/errors.h"
 
@@ -38,8 +39,10 @@ Status CreateParamsForInputs(AbstractContext* ctx,
                              std::vector<AbstractTensorHandle*>* params) {
   tracing::TracingTensorHandle* handle = nullptr;
   for (auto input : inputs) {
+    PartialTensorShape shape;
+    TF_RETURN_IF_ERROR(input->Shape(&shape));
     TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
-        input->DataType(), &handle));
+        input->DataType(), shape, &handle));
     params->emplace_back(handle);
   }
   return Status::OK();
@@ -54,6 +54,7 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
 #include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
 #include "tensorflow/core/platform/errors.h"
@@ -102,6 +103,13 @@ class MlirTensor : public TracingTensorHandle {
     return type;
   }
 
+  tensorflow::Status Shape(
+      tensorflow::PartialTensorShape* shape) const override {
+    // TODO(b/173074167): Implement this and enable tests in
+    // unified_api_test.cc.
+    return Unimplemented("MlirTensor::Shape is not implemented yet.");
+  }
+
   Value getValue() { return value_; }
   Type getElementType() {
     return value_.getType().cast<ShapedType>().getElementType();
@@ -250,6 +258,7 @@ class MlirFunctionContext : public TracingContext {
     return new MlirAbstractOp(context_.get(), this);
   }
   Status AddParameter(tensorflow::DataType dtype,
+                      const tensorflow::PartialTensorShape& shape,
                       TracingTensorHandle** handle) override;
 
   Status Finalize(OutputList* outputs, AbstractFunction** f) override;
@@ -547,8 +556,11 @@ Operation* MlirFunctionContext::CreateOperationFromState(
   return builder_.createOperation(state);
 }
 
-Status MlirFunctionContext::AddParameter(tensorflow::DataType dtype,
-                                         TracingTensorHandle** handle) {
+Status MlirFunctionContext::AddParameter(
+    tensorflow::DataType dtype, const tensorflow::PartialTensorShape& shape,
+    TracingTensorHandle** handle) {
+  // TODO(b/173073199): Use shape. Enable tests in unified_api_test.cc once
+  // resolved.
   Type type;
   TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder_, &type));
   *handle = new MlirTensor(func_.getBody().front().addArgument(type));
@@ -633,6 +633,25 @@ Status TensorHandle::CopyInferenceShape(TensorHandle* other) {
   return Status::OK();
 }
 
+Status TensorHandle::Shape(tensorflow::PartialTensorShape* shape) const {
+  DCHECK(shape != nullptr);
+  if (!IsReady() && !inference_shape_.unknown_rank()) {
+    *shape = inference_shape_;
+    return Status::OK();
+  } else {
+    auto result = absl::visit(
+        [](auto& data) {
+          TensorShape shape;
+          Status s = data.Shape(&shape);
+          return std::make_pair(shape, s);
+        },
+        data_);
+    TF_RETURN_IF_ERROR(result.second);
+    *shape = result.first;
+  }
+  return Status::OK();
+}
+
 Status TensorHandle::NumDims(int* num_dims) const {
   DCHECK(num_dims != nullptr);
   if (!IsReady() && !inference_shape_.unknown_rank()) {
@@ -125,6 +125,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
   void Release() override;
 
   tensorflow::DataType DataType() const override;
+  Status Shape(tensorflow::PartialTensorShape* shape) const override;
   Status NumDims(int* num_dims) const override;
   Status NumElements(int64* num_elements) const override;
   Status Dim(int dim_index, int64* dim) const override;
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
 #include "tensorflow/c/eager/tfe_context_internal.h"
 #include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.pb.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
@@ -132,7 +133,9 @@ PYBIND11_MODULE(_unified_api, m) {
       .def("AddParameter",
            [](TracingContext* self, DataType dtype) {
              TracingTensorHandle* handle = nullptr;
-             Status s = self->AddParameter(dtype, &handle);
+             // TODO(srbs): Add shape argument to this function.
+             tensorflow::PartialTensorShape shape;
+             Status s = self->AddParameter(dtype, shape, &handle);
              MaybeRaiseRegisteredFromStatus(s);
              return static_cast<AbstractTensorHandle*>(handle);
            })
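To round out the picture, a short sketch (not in the diff) of how client code can consume the new AbstractTensorHandle::Shape interface. The helper GetRank is hypothetical; the calls it makes mirror CreateParamsForInputs and the new unified_api_test.cc above.

// Sketch: querying the rank of any AbstractTensorHandle through the new
// Shape(PartialTensorShape*) interface added by this change.
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {

// Hypothetical helper: writes the rank of `handle` to `rank`, or -1 when the
// rank is unknown (e.g. a Placeholder traced with TF_Shape{-1, nullptr}).
Status GetRank(AbstractTensorHandle* handle, int* rank) {
  PartialTensorShape shape;
  TF_RETURN_IF_ERROR(handle->Shape(&shape));
  // Shape() leaves `shape` untouched (unknown rank) when no rank is known.
  *rank = shape.unknown_rank() ? -1 : shape.dims();
  return Status::OK();
}

}  // namespace tensorflow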