Implements AbstractTensor::Shape(PartialTensorShape*)
.
Adds shape argument to `TracingContext::AddParameter`. PiperOrigin-RevId: 342336504 Change-Id: I3d20c6992f03290866a01bb625d9879eded258ae
This commit is contained in:
parent
242e920d7f
commit
0dd94c4ad3
@ -176,6 +176,7 @@ cc_library(
|
|||||||
"//tensorflow/c:c_api_internal",
|
"//tensorflow/c:c_api_internal",
|
||||||
"//tensorflow/c:conversion_macros",
|
"//tensorflow/c:conversion_macros",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core/platform:casts",
|
"//tensorflow/core/platform:casts",
|
||||||
"//tensorflow/core/platform:types",
|
"//tensorflow/core/platform:types",
|
||||||
],
|
],
|
||||||
@ -244,6 +245,7 @@ cc_library(
|
|||||||
":c_api_unified_internal",
|
":c_api_unified_internal",
|
||||||
"//tensorflow/c:tf_status",
|
"//tensorflow/c:tf_status",
|
||||||
"//tensorflow/c:tf_status_helper",
|
"//tensorflow/c:tf_status_helper",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core/lib/llvm_rtti",
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
"//tensorflow/core/platform:errors",
|
"//tensorflow/core/platform:errors",
|
||||||
"//tensorflow/core/platform:status",
|
"//tensorflow/core/platform:status",
|
||||||
@ -288,6 +290,29 @@ tf_cuda_cc_test(
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
tf_cuda_cc_test(
|
||||||
|
name = "unified_api_test",
|
||||||
|
size = "small",
|
||||||
|
srcs = [
|
||||||
|
"unified_api_test.cc",
|
||||||
|
],
|
||||||
|
args = ["--heap_check=local"],
|
||||||
|
linkstatic = tf_kernel_tests_linkstatic(),
|
||||||
|
tags = tf_cuda_tests_tags(),
|
||||||
|
deps = [
|
||||||
|
":c_api_experimental",
|
||||||
|
":c_api_unified_internal",
|
||||||
|
":unified_api_testutil",
|
||||||
|
"//tensorflow/c:tf_status_helper",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
|
"//tensorflow/core/platform:errors",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gradients_util",
|
name = "gradients_util",
|
||||||
srcs = [
|
srcs = [
|
||||||
@ -477,8 +502,10 @@ cc_library(
|
|||||||
"//tensorflow:internal",
|
"//tensorflow:internal",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core/platform:refcount",
|
"//tensorflow/core/platform:refcount",
|
||||||
|
"//tensorflow/core/platform:status",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -17,8 +17,10 @@ limitations under the License.
|
|||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/platform/refcount.h"
|
#include "tensorflow/core/platform/refcount.h"
|
||||||
|
#include "tensorflow/core/platform/status.h"
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
|
|
||||||
// Abstract interface to a Tensor handle in either tracing or immediate
|
// Abstract interface to a Tensor handle in either tracing or immediate
|
||||||
@ -32,6 +34,9 @@ class AbstractTensorHandle : public core::RefCounted {
|
|||||||
public:
|
public:
|
||||||
// Returns tensor dtype.
|
// Returns tensor dtype.
|
||||||
virtual tensorflow::DataType DataType() const = 0;
|
virtual tensorflow::DataType DataType() const = 0;
|
||||||
|
// Returns tensor shape. If tensor has unknown rank, shape remains untouched.
|
||||||
|
virtual tensorflow::Status Shape(
|
||||||
|
tensorflow::PartialTensorShape* shape) const = 0;
|
||||||
|
|
||||||
AbstractTensorHandleKind getKind() const { return kind_; }
|
AbstractTensorHandleKind getKind() const { return kind_; }
|
||||||
|
|
||||||
|
@ -134,7 +134,9 @@ TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
|
|||||||
}
|
}
|
||||||
|
|
||||||
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||||
TF_DataType dtype, TF_Status* s) {
|
TF_DataType dtype, TF_Shape shape,
|
||||||
|
TF_Status* s) {
|
||||||
|
DCHECK_GE(shape.num_dims, -1);
|
||||||
TracingTensorHandle* t;
|
TracingTensorHandle* t;
|
||||||
TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
|
TracingContext* tracing_ctx = dyn_cast<TracingContext>(unwrap(func));
|
||||||
if (!tracing_ctx) {
|
if (!tracing_ctx) {
|
||||||
@ -143,8 +145,20 @@ TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
|||||||
"TF_AddFunctionParameter must be called on a TracingContext."));
|
"TF_AddFunctionParameter must be called on a TracingContext."));
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
tensorflow::PartialTensorShape partial_shape;
|
||||||
|
if (shape.num_dims != -1) {
|
||||||
|
DCHECK(shape.dim_sizes != nullptr);
|
||||||
|
Status status = tensorflow::PartialTensorShape::MakePartialShape(
|
||||||
|
reinterpret_cast<tensorflow::int64*>(shape.dim_sizes), shape.num_dims,
|
||||||
|
&partial_shape);
|
||||||
|
if (!status.ok()) {
|
||||||
|
Set_TF_Status_from_Status(s, status);
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
Set_TF_Status_from_Status(
|
Set_TF_Status_from_Status(
|
||||||
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), &t));
|
s, tracing_ctx->AddParameter(static_cast<DataType>(dtype), partial_shape,
|
||||||
|
&t));
|
||||||
return wrap(t);
|
return wrap(t);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,10 +64,16 @@ TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
|
|||||||
TF_Status* s);
|
TF_Status* s);
|
||||||
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
||||||
|
|
||||||
|
// Represents a (partially-defined) shape.
|
||||||
|
typedef struct TF_Shape {
|
||||||
|
int num_dims; // Must be >= -1; -1 represents unknown rank.
|
||||||
|
int64_t* dim_sizes;
|
||||||
|
} TF_Shape;
|
||||||
|
|
||||||
// Add a new parameter to a TensorFlow Function.
|
// Add a new parameter to a TensorFlow Function.
|
||||||
// TODO(aminim): what about shape?
|
|
||||||
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||||
TF_DataType dtype, TF_Status* s);
|
TF_DataType dtype, TF_Shape shape,
|
||||||
|
TF_Status* s);
|
||||||
|
|
||||||
// Create an operation suitable to use with the provided context. The operation
|
// Create an operation suitable to use with the provided context. The operation
|
||||||
// requires its type (e.g. "AddV2") to be set independently.
|
// requires its type (e.g. "AddV2") to be set independently.
|
||||||
|
@ -25,6 +25,8 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
#include "tensorflow/core/framework/shape_inference.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
@ -43,22 +45,50 @@ class GraphContext;
|
|||||||
class GraphOperation;
|
class GraphOperation;
|
||||||
class GraphTensor;
|
class GraphTensor;
|
||||||
|
|
||||||
|
auto& kUnknownDim = shape_inference::InferenceContext::kUnknownDim;
|
||||||
|
auto& kUnknownRank = shape_inference::InferenceContext::kUnknownRank;
|
||||||
|
|
||||||
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
|
// GraphTensor wraps a `TF_Output`, i.e. a pointer to TF_Operation and the index
|
||||||
// into the list of outputs for the operation.
|
// into the list of outputs for the operation.
|
||||||
class GraphTensor : public TracingTensorHandle {
|
class GraphTensor : public TracingTensorHandle {
|
||||||
public:
|
public:
|
||||||
explicit GraphTensor(TF_Output output)
|
explicit GraphTensor(TF_Output output, TF_Graph* graph)
|
||||||
: TracingTensorHandle(kGraph), output_(output) {}
|
: TracingTensorHandle(kGraph), output_(output), graph_(graph) {}
|
||||||
|
|
||||||
tensorflow::DataType DataType() const override {
|
tensorflow::DataType DataType() const override {
|
||||||
return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
|
return static_cast<tensorflow::DataType>(TF_OperationOutputType(output_));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tensorflow::Status Shape(
|
||||||
|
tensorflow::PartialTensorShape* shape) const override {
|
||||||
|
DCHECK(shape != nullptr);
|
||||||
|
TF_Status status;
|
||||||
|
int num_dims = TF_GraphGetTensorNumDims(graph_, output_, &status);
|
||||||
|
DCHECK_GE(num_dims, -1);
|
||||||
|
TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
|
||||||
|
if (num_dims == kUnknownRank) {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64> dims(num_dims, kUnknownDim);
|
||||||
|
TF_GraphGetTensorShape(graph_, output_,
|
||||||
|
reinterpret_cast<int64_t*>(dims.data()), num_dims,
|
||||||
|
&status);
|
||||||
|
TF_RETURN_IF_ERROR(StatusFromTF_Status(&status));
|
||||||
|
TF_RETURN_IF_ERROR(tensorflow::TensorShapeUtils::MakeShape(dims, shape));
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
TF_Output output_;
|
TF_Output output_;
|
||||||
|
|
||||||
// For LLVM style RTTI.
|
// For LLVM style RTTI.
|
||||||
static bool classof(const AbstractTensorHandle* ptr) {
|
static bool classof(const AbstractTensorHandle* ptr) {
|
||||||
return ptr->getKind() == kGraph;
|
return ptr->getKind() == kGraph;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
TF_Graph* graph_; // For shape inference.
|
||||||
};
|
};
|
||||||
|
|
||||||
// GraphOperation wraps and populates a TF_OperationDescription.
|
// GraphOperation wraps and populates a TF_OperationDescription.
|
||||||
@ -135,7 +165,7 @@ class GraphOperation : public TracingOperation {
|
|||||||
TF_DeleteStatus(s);
|
TF_DeleteStatus(s);
|
||||||
*num_retvals = TF_OperationNumOutputs(operation);
|
*num_retvals = TF_OperationNumOutputs(operation);
|
||||||
for (int i = 0; i < *num_retvals; ++i) {
|
for (int i = 0; i < *num_retvals; ++i) {
|
||||||
retvals[i] = new GraphTensor({operation, i});
|
retvals[i] = new GraphTensor({operation, i}, g_);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -326,12 +356,18 @@ class GraphContext : public TracingContext {
|
|||||||
return new GraphOperation(graph_.get());
|
return new GraphOperation(graph_.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
Status AddParameter(DataType dtype, TracingTensorHandle** output) override {
|
Status AddParameter(DataType dtype, const PartialTensorShape& shape,
|
||||||
|
TracingTensorHandle** output) override {
|
||||||
TracingOperationPtr operation(CreateOperation());
|
TracingOperationPtr operation(CreateOperation());
|
||||||
TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
|
TF_RETURN_IF_ERROR(operation->Reset("Placeholder", nullptr));
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
|
operation->SetOpName(absl::StrCat("_input_", inputs_.size()).c_str()));
|
||||||
TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
|
TF_RETURN_IF_ERROR(operation->SetAttrType("dtype", dtype));
|
||||||
|
if (!shape.unknown_rank()) {
|
||||||
|
TF_RETURN_IF_ERROR(operation->SetAttrShape(
|
||||||
|
"shape", reinterpret_cast<int64_t*>(shape.dim_sizes().data()),
|
||||||
|
shape.dims()));
|
||||||
|
}
|
||||||
int num_outputs = 1;
|
int num_outputs = 1;
|
||||||
std::vector<AbstractTensorHandle*> outputs(num_outputs);
|
std::vector<AbstractTensorHandle*> outputs(num_outputs);
|
||||||
TF_RETURN_IF_ERROR(operation->Execute(
|
TF_RETURN_IF_ERROR(operation->Execute(
|
||||||
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/platform/casts.h"
|
#include "tensorflow/core/platform/casts.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
@ -107,7 +108,8 @@ class TracingContext : public AbstractContext {
|
|||||||
|
|
||||||
public:
|
public:
|
||||||
// Add a function parameter and return the corresponding tensor.
|
// Add a function parameter and return the corresponding tensor.
|
||||||
virtual Status AddParameter(DataType dtype, TracingTensorHandle**) = 0;
|
virtual Status AddParameter(DataType dtype, const PartialTensorShape& shape,
|
||||||
|
TracingTensorHandle**) = 0;
|
||||||
|
|
||||||
// Finalize this context and make a function out of it. The context is in a
|
// Finalize this context and make a function out of it. The context is in a
|
||||||
// invalid state after this call and must be destroyed.
|
// invalid state after this call and must be destroyed.
|
||||||
|
@ -359,7 +359,7 @@ TEST_P(UnifiedCAPI, TestBasicGraph) {
|
|||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
|
||||||
auto* placeholder_t =
|
auto* placeholder_t =
|
||||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
|
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
|
||||||
// Build an abstract operation.
|
// Build an abstract operation.
|
||||||
@ -450,7 +450,7 @@ TEST_P(UnifiedCAPI, TestBasicGraphMatMul) {
|
|||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
|
||||||
auto* placeholder_t =
|
auto* placeholder_t =
|
||||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
|
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
|
||||||
// Build an abstract operation.
|
// Build an abstract operation.
|
||||||
@ -553,9 +553,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraph) {
|
|||||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
|
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
// Create a first "Add" computing `arg0 + arg1`.
|
// Create a first "Add" computing `arg0 + arg1`.
|
||||||
@ -709,9 +709,9 @@ TEST_P(UnifiedCAPI, TestMultiOutputGraphMatMul) {
|
|||||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
|
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name.c_str(), s);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
auto* arg0 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, s);
|
auto* arg1 = TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, s);
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
|
||||||
|
|
||||||
// Create a first "Add" computing `arg0 + arg1`.
|
// Create a first "Add" computing `arg0 + arg1`.
|
||||||
@ -975,7 +975,7 @@ TEST_P(UnifiedCAPI, TF_AbstractTensorGetEagerTensorOnGraphTensorRaises) {
|
|||||||
|
|
||||||
// Add a placeholder to the graph.
|
// Add a placeholder to the graph.
|
||||||
auto placeholder_t =
|
auto placeholder_t =
|
||||||
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
|
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, {-1, nullptr}, status.get());
|
||||||
TF_AbstractTensorGetEagerTensor(placeholder_t, status.get());
|
TF_AbstractTensorGetEagerTensor(placeholder_t, status.get());
|
||||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||||
|
|
||||||
|
@ -28,6 +28,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
|
||||||
@ -224,8 +225,10 @@ Status CreateParamsForInputs(AbstractContext* ctx,
|
|||||||
vector<AbstractTensorHandle*>* params) {
|
vector<AbstractTensorHandle*>* params) {
|
||||||
tracing::TracingTensorHandle* handle = nullptr;
|
tracing::TracingTensorHandle* handle = nullptr;
|
||||||
for (auto input : inputs) {
|
for (auto input : inputs) {
|
||||||
|
PartialTensorShape shape;
|
||||||
|
TF_RETURN_IF_ERROR(input->Shape(&shape));
|
||||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
||||||
input->DataType(), &handle));
|
input->DataType(), shape, &handle));
|
||||||
params->emplace_back(handle);
|
params->emplace_back(handle);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
205
tensorflow/c/eager/unified_api_test.cc
Normal file
205
tensorflow/c/eager/unified_api_test.cc
Normal file
@ -0,0 +1,205 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
|
#include "tensorflow/c/eager/unified_api_testutil.h"
|
||||||
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace {
|
||||||
|
class UnifiedAPI
|
||||||
|
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
|
||||||
|
protected:
|
||||||
|
void SetUp() override {
|
||||||
|
TF_StatusPtr status(TF_NewStatus());
|
||||||
|
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
|
||||||
|
Status s = StatusFromTF_Status(status.get());
|
||||||
|
CHECK_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
|
||||||
|
bool UseFunction() const { return std::get<2>(GetParam()); }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Checks that inputs[0] is a scalar.
|
||||||
|
Status TestScalarShape(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs) {
|
||||||
|
PartialTensorShape shape;
|
||||||
|
TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
|
||||||
|
if (shape.dims() != 0) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Tensor expected to have scalar shape found rank: ", shape.dims());
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(UnifiedAPI, TestTensorShapeScalar) {
|
||||||
|
if (UseFunction() && UseMlir()) {
|
||||||
|
// TODO(b/173074167): Remove this.
|
||||||
|
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
|
||||||
|
}
|
||||||
|
AbstractContextPtr ctx;
|
||||||
|
{
|
||||||
|
AbstractContext* ctx_raw = nullptr;
|
||||||
|
Status s =
|
||||||
|
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
ctx.reset(ctx_raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractTensorHandlePtr x;
|
||||||
|
{
|
||||||
|
AbstractTensorHandle* x_raw = nullptr;
|
||||||
|
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
x.reset(x_raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status s = RunModel(TestScalarShape, ctx.get(),
|
||||||
|
/*inputs=*/{x.get()},
|
||||||
|
/*outputs=*/{},
|
||||||
|
/*use_function=*/UseFunction());
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks that inputs[0] is a matrix with shape 2x4.
|
||||||
|
Status TestTensorShape2x4(AbstractContext* ctx,
|
||||||
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
|
absl::Span<AbstractTensorHandle*> outputs) {
|
||||||
|
PartialTensorShape shape;
|
||||||
|
TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
|
||||||
|
if (shape.dims() != 2) {
|
||||||
|
return errors::InvalidArgument(
|
||||||
|
"Tensor expected to have rank 2 found rank: ", shape.dims());
|
||||||
|
}
|
||||||
|
int64 dim_sizes[] = {2, 4};
|
||||||
|
for (int i = 0; i < shape.dims(); i++) {
|
||||||
|
if (shape.dim_size(i) != dim_sizes[i]) {
|
||||||
|
return errors::InvalidArgument("Dim ", i, " expected to be of size ",
|
||||||
|
dim_sizes[i],
|
||||||
|
" found: ", shape.dim_size(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(UnifiedAPI, TestTensorShape2x4) {
|
||||||
|
if (UseFunction() && UseMlir()) {
|
||||||
|
// TODO(b/173074167): Remove this.
|
||||||
|
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
|
||||||
|
}
|
||||||
|
AbstractContextPtr ctx;
|
||||||
|
{
|
||||||
|
AbstractContext* ctx_raw = nullptr;
|
||||||
|
Status s =
|
||||||
|
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
ctx.reset(ctx_raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
AbstractTensorHandlePtr x;
|
||||||
|
{
|
||||||
|
AbstractTensorHandle* x_raw = nullptr;
|
||||||
|
float data[] = {0., 0., 0., 0., 0., 0., 0., 0};
|
||||||
|
int64 dim_sizes[] = {2, 4};
|
||||||
|
Status s =
|
||||||
|
TestTensorHandleWithDimsFloat(ctx.get(), data, dim_sizes, 2, &x_raw);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
x.reset(x_raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
Status s = RunModel(TestTensorShape2x4, ctx.get(),
|
||||||
|
/*inputs=*/{x.get()},
|
||||||
|
/*outputs=*/{},
|
||||||
|
/*use_function=*/UseFunction());
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(UnifiedAPI, TestUnknownShapeTracing) {
|
||||||
|
if (!UseFunction()) {
|
||||||
|
GTEST_SKIP() << "Tracing only test.";
|
||||||
|
}
|
||||||
|
if (UseMlir()) {
|
||||||
|
// TODO(b/173074167): Remove this.
|
||||||
|
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
|
||||||
|
}
|
||||||
|
AbstractContextPtr ctx(BuildFunction("test_fn"));
|
||||||
|
AbstractTensorHandlePtr x;
|
||||||
|
{
|
||||||
|
tracing::TracingTensorHandle* x_raw = nullptr;
|
||||||
|
PartialTensorShape shape;
|
||||||
|
Status s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
|
||||||
|
DT_FLOAT, shape, &x_raw);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
x.reset(x_raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
PartialTensorShape shape;
|
||||||
|
Status s = x->Shape(&shape);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
ASSERT_TRUE(shape.unknown_rank());
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_P(UnifiedAPI, TestPartialShapeTracing) {
|
||||||
|
if (!UseFunction()) {
|
||||||
|
GTEST_SKIP() << "Tracing only test.";
|
||||||
|
}
|
||||||
|
if (UseMlir()) {
|
||||||
|
GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
|
||||||
|
}
|
||||||
|
AbstractContextPtr ctx(BuildFunction("test_fn"));
|
||||||
|
AbstractTensorHandlePtr x;
|
||||||
|
{
|
||||||
|
tracing::TracingTensorHandle* x_raw = nullptr;
|
||||||
|
PartialTensorShape shape;
|
||||||
|
int64 dim_sizes[] = {2, -1};
|
||||||
|
Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
|
||||||
|
DT_FLOAT, shape, &x_raw);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
x.reset(x_raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
PartialTensorShape shape;
|
||||||
|
Status s = x->Shape(&shape);
|
||||||
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
|
ASSERT_FALSE(shape.unknown_rank());
|
||||||
|
|
||||||
|
ASSERT_EQ(2, shape.dim_size(0));
|
||||||
|
ASSERT_EQ(-1, shape.dim_size(1));
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef PLATFORM_GOOGLE
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
UnifiedCppAPI, UnifiedAPI,
|
||||||
|
::testing::Combine(::testing::Values("graphdef", "mlir"),
|
||||||
|
/*tfrt*/ ::testing::Values(true, false),
|
||||||
|
/*use_function*/ ::testing::Values(true, false)));
|
||||||
|
#else
|
||||||
|
INSTANTIATE_TEST_SUITE_P(
|
||||||
|
UnifiedCppAPI, UnifiedAPI,
|
||||||
|
::testing::Combine(::testing::Values("graphdef", "mlir"),
|
||||||
|
/*tfrt*/ ::testing::Values(false),
|
||||||
|
/*use_function*/ ::testing::Values(true, false)));
|
||||||
|
#endif
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/c/tf_status_helper.h"
|
#include "tensorflow/c/tf_status_helper.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
|
|
||||||
@ -38,8 +39,10 @@ Status CreateParamsForInputs(AbstractContext* ctx,
|
|||||||
std::vector<AbstractTensorHandle*>* params) {
|
std::vector<AbstractTensorHandle*>* params) {
|
||||||
tracing::TracingTensorHandle* handle = nullptr;
|
tracing::TracingTensorHandle* handle = nullptr;
|
||||||
for (auto input : inputs) {
|
for (auto input : inputs) {
|
||||||
|
PartialTensorShape shape;
|
||||||
|
TF_RETURN_IF_ERROR(input->Shape(&shape));
|
||||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
||||||
input->DataType(), &handle));
|
input->DataType(), shape, &handle));
|
||||||
params->emplace_back(handle);
|
params->emplace_back(handle);
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -54,6 +54,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
|
||||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||||
#include "tensorflow/core/framework/node_def_util.h"
|
#include "tensorflow/core/framework/node_def_util.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
@ -102,6 +103,13 @@ class MlirTensor : public TracingTensorHandle {
|
|||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tensorflow::Status Shape(
|
||||||
|
tensorflow::PartialTensorShape* shape) const override {
|
||||||
|
// TODO(b/173074167): Implement this and enable tests in
|
||||||
|
// unified_api_test.cc.
|
||||||
|
return Unimplemented("MlirTensor::Shape is not implemented yet.");
|
||||||
|
}
|
||||||
|
|
||||||
Value getValue() { return value_; }
|
Value getValue() { return value_; }
|
||||||
Type getElementType() {
|
Type getElementType() {
|
||||||
return value_.getType().cast<ShapedType>().getElementType();
|
return value_.getType().cast<ShapedType>().getElementType();
|
||||||
@ -250,6 +258,7 @@ class MlirFunctionContext : public TracingContext {
|
|||||||
return new MlirAbstractOp(context_.get(), this);
|
return new MlirAbstractOp(context_.get(), this);
|
||||||
}
|
}
|
||||||
Status AddParameter(tensorflow::DataType dtype,
|
Status AddParameter(tensorflow::DataType dtype,
|
||||||
|
const tensorflow::PartialTensorShape& shape,
|
||||||
TracingTensorHandle** handle) override;
|
TracingTensorHandle** handle) override;
|
||||||
|
|
||||||
Status Finalize(OutputList* outputs, AbstractFunction** f) override;
|
Status Finalize(OutputList* outputs, AbstractFunction** f) override;
|
||||||
@ -547,8 +556,11 @@ Operation* MlirFunctionContext::CreateOperationFromState(
|
|||||||
return builder_.createOperation(state);
|
return builder_.createOperation(state);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status MlirFunctionContext::AddParameter(tensorflow::DataType dtype,
|
Status MlirFunctionContext::AddParameter(
|
||||||
TracingTensorHandle** handle) {
|
tensorflow::DataType dtype, const tensorflow::PartialTensorShape& shape,
|
||||||
|
TracingTensorHandle** handle) {
|
||||||
|
// TODO(b/173073199): Use shape. Enable tests in unified_api_test.cc once
|
||||||
|
// resolved.
|
||||||
Type type;
|
Type type;
|
||||||
TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder_, &type));
|
TF_RETURN_IF_ERROR(ConvertDataTypeToTensor(dtype, builder_, &type));
|
||||||
*handle = new MlirTensor(func_.getBody().front().addArgument(type));
|
*handle = new MlirTensor(func_.getBody().front().addArgument(type));
|
||||||
|
@ -633,6 +633,25 @@ Status TensorHandle::CopyInferenceShape(TensorHandle* other) {
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Status TensorHandle::Shape(tensorflow::PartialTensorShape* shape) const {
|
||||||
|
DCHECK(shape != nullptr);
|
||||||
|
if (!IsReady() && !inference_shape_.unknown_rank()) {
|
||||||
|
*shape = inference_shape_;
|
||||||
|
return Status::OK();
|
||||||
|
} else {
|
||||||
|
auto result = absl::visit(
|
||||||
|
[](auto& data) {
|
||||||
|
TensorShape shape;
|
||||||
|
Status s = data.Shape(&shape);
|
||||||
|
return std::make_pair(shape, s);
|
||||||
|
},
|
||||||
|
data_);
|
||||||
|
TF_RETURN_IF_ERROR(result.second);
|
||||||
|
*shape = result.first;
|
||||||
|
}
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
Status TensorHandle::NumDims(int* num_dims) const {
|
Status TensorHandle::NumDims(int* num_dims) const {
|
||||||
DCHECK(num_dims != nullptr);
|
DCHECK(num_dims != nullptr);
|
||||||
if (!IsReady() && !inference_shape_.unknown_rank()) {
|
if (!IsReady() && !inference_shape_.unknown_rank()) {
|
||||||
|
@ -125,6 +125,7 @@ class TensorHandle : public ImmediateExecutionTensorHandle {
|
|||||||
void Release() override;
|
void Release() override;
|
||||||
|
|
||||||
tensorflow::DataType DataType() const override;
|
tensorflow::DataType DataType() const override;
|
||||||
|
Status Shape(tensorflow::PartialTensorShape* shape) const override;
|
||||||
Status NumDims(int* num_dims) const override;
|
Status NumDims(int* num_dims) const override;
|
||||||
Status NumElements(int64* num_elements) const override;
|
Status NumElements(int64* num_elements) const override;
|
||||||
Status Dim(int dim_index, int64* dim) const override;
|
Status Dim(int dim_index, int64* dim) const override;
|
||||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||||
#include "tensorflow/c/eager/tfe_context_internal.h"
|
#include "tensorflow/c/eager/tfe_context_internal.h"
|
||||||
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
|
||||||
|
#include "tensorflow/core/framework/tensor_shape.h"
|
||||||
#include "tensorflow/core/framework/types.pb.h"
|
#include "tensorflow/core/framework/types.pb.h"
|
||||||
#include "tensorflow/core/lib/core/status.h"
|
#include "tensorflow/core/lib/core/status.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
@ -132,7 +133,9 @@ PYBIND11_MODULE(_unified_api, m) {
|
|||||||
.def("AddParameter",
|
.def("AddParameter",
|
||||||
[](TracingContext* self, DataType dtype) {
|
[](TracingContext* self, DataType dtype) {
|
||||||
TracingTensorHandle* handle = nullptr;
|
TracingTensorHandle* handle = nullptr;
|
||||||
Status s = self->AddParameter(dtype, &handle);
|
// TODO(srbs): Add shape argument to this function.
|
||||||
|
tensorflow::PartialTensorShape shape;
|
||||||
|
Status s = self->AddParameter(dtype, shape, &handle);
|
||||||
MaybeRaiseRegisteredFromStatus(s);
|
MaybeRaiseRegisteredFromStatus(s);
|
||||||
return static_cast<AbstractTensorHandle*>(handle);
|
return static_cast<AbstractTensorHandle*>(handle);
|
||||||
})
|
})
|
||||||
|
Loading…
Reference in New Issue
Block a user