Introduce a higher-level function handling in the tracing oriented unified API
This patch intends to make function tracing more of a first class concept in the API. It tries to move away from the "flat graph" model with "placeholder" operation introduced with the expectation to turn them into function parameters later. Instead the user starts by creating an empty function which is an ExecutionContext (and as such can trace operations). Function parameters can get added to this context using a dedicated API returning an AbstractTensor. The diff in UnifiedCAPI/TestBasicGraph is probably a good illustration of the change from a client point of view. Another important point of this patch is to make it so that no C public API is defined in the `c_api_unified_experimental_graph.cc` file, instead the implementation is dispatched based on a registered factory function to create the tracing context. This will allow to swap the tracing implementation through injection later. PiperOrigin-RevId: 311529850 Change-Id: I822047f4306835abc0e044dc87c14179596f64bd
This commit is contained in:
parent
2a5910906a
commit
ec2cc2903f
|
@ -448,6 +448,8 @@ tf_cuda_library(
|
||||||
"//conditions:default": [],
|
"//conditions:default": [],
|
||||||
}) + [
|
}) + [
|
||||||
"@com_google_absl//absl/memory",
|
"@com_google_absl//absl/memory",
|
||||||
|
"@com_google_absl//absl/strings",
|
||||||
|
"@com_google_absl//absl/container:flat_hash_map",
|
||||||
"//tensorflow/c:tf_status_helper",
|
"//tensorflow/c:tf_status_helper",
|
||||||
"//tensorflow/core/distributed_runtime/eager:eager_client",
|
"//tensorflow/core/distributed_runtime/eager:eager_client",
|
||||||
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
"//tensorflow/core/distributed_runtime/rpc/eager:grpc_eager_client",
|
||||||
|
|
|
@ -17,6 +17,8 @@ limitations under the License.
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/container/flat_hash_map.h"
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
|
@ -26,6 +28,51 @@ using tensorflow::string;
|
||||||
using tensorflow::internal::OutputList;
|
using tensorflow::internal::OutputList;
|
||||||
using tensorflow::internal::unwrap;
|
using tensorflow::internal::unwrap;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace internal {
|
||||||
|
typedef absl::flat_hash_map<std::string, FactoryFunction> FactoriesMap;
|
||||||
|
|
||||||
|
static FactoriesMap& GetFactories() {
|
||||||
|
static FactoriesMap* factories = new FactoriesMap;
|
||||||
|
return *factories;
|
||||||
|
}
|
||||||
|
|
||||||
|
static const char* default_factory = "<unset>";
|
||||||
|
|
||||||
|
void RegisterTracingEngineFactory(const string& name, FactoryFunction factory) {
|
||||||
|
assert((!GetFactories().count(name)) ||
|
||||||
|
(GetFactories()[name] == factory) &&
|
||||||
|
"Duplicate tracing factory registration");
|
||||||
|
GetFactories()[name] = factory;
|
||||||
|
}
|
||||||
|
|
||||||
|
void SetDefaultTracingEngine(const char* name) { default_factory = name; }
|
||||||
|
|
||||||
|
static ExecutionContext* CreateTracingExecutionContext(const char* fn_name,
|
||||||
|
TF_Status* s) {
|
||||||
|
auto entry = GetFactories().find(default_factory);
|
||||||
|
if (entry != GetFactories().end()) return entry->second(fn_name, s);
|
||||||
|
string msg = absl::StrCat(
|
||||||
|
"No tracing engine factory has been registered with the key '",
|
||||||
|
default_factory, "' (available: ");
|
||||||
|
// Ensure deterministic (sorted) order in the error message
|
||||||
|
std::set<string> factories_sorted;
|
||||||
|
for (const auto& factory : GetFactories())
|
||||||
|
factories_sorted.insert(factory.first);
|
||||||
|
const char* comma = "";
|
||||||
|
for (const string& factory : factories_sorted) {
|
||||||
|
msg += comma + factory;
|
||||||
|
comma = ", ";
|
||||||
|
}
|
||||||
|
msg += ")";
|
||||||
|
|
||||||
|
TF_SetStatus(s, TF_INVALID_ARGUMENT, msg.c_str());
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // end namespace internal
|
||||||
|
} // end namespace tensorflow
|
||||||
|
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
// Public C API entry points
|
// Public C API entry points
|
||||||
//
|
//
|
||||||
|
@ -36,6 +83,28 @@ using tensorflow::internal::unwrap;
|
||||||
//
|
//
|
||||||
// =============================================================================
|
// =============================================================================
|
||||||
|
|
||||||
|
void TF_SetTracingImplementation(const char* name) {
|
||||||
|
tensorflow::internal::SetDefaultTracingEngine(name);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Creates a new TensorFlow function, it is an execution context attached to a
|
||||||
|
// given tracing context.
|
||||||
|
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* s) {
|
||||||
|
return wrap(tensorflow::internal::CreateTracingExecutionContext(fn_name, s));
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
|
||||||
|
TF_OutputList* outputs, TF_Status* s) {
|
||||||
|
auto* func = wrap(unwrap(ctx)->Finalize(unwrap(outputs), s));
|
||||||
|
TF_DeleteExecutionContext(ctx);
|
||||||
|
return func;
|
||||||
|
}
|
||||||
|
|
||||||
|
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||||
|
TF_DataType dtype, TF_Status* s) {
|
||||||
|
return wrap(unwrap(func)->AddParameter(dtype, s));
|
||||||
|
}
|
||||||
|
|
||||||
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
|
void TF_DeleteExecutionContext(TF_ExecutionContext* c) { delete unwrap(c); }
|
||||||
|
|
||||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
|
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* c) {
|
||||||
|
|
|
@ -49,15 +49,26 @@ typedef struct TF_AbstractOp TF_AbstractOp;
|
||||||
// setting functional attributes of other composite ops e.g. control flow.
|
// setting functional attributes of other composite ops e.g. control flow.
|
||||||
typedef struct TF_AbstractFunction TF_AbstractFunction;
|
typedef struct TF_AbstractFunction TF_AbstractFunction;
|
||||||
|
|
||||||
// Creates a context for tracing the execution of operations into a function.
|
// This allows the client to swap the implementation of the tracing engine.
|
||||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s);
|
// Any future call to TF_CreateFunction will use the implementation defined
|
||||||
|
// here.
|
||||||
|
void TF_SetTracingImplementation(const char* name);
|
||||||
|
|
||||||
|
// Creates a new TensorFlow function. A Function is an execution context, and as
|
||||||
|
// such it can trace operations through TF_ExecuteOperation. After completing
|
||||||
|
// tracing, a function can be obtained by TF_FinalizeFunction.
|
||||||
|
TF_ExecutionContext* TF_CreateFunction(const char* fn_name, TF_Status* status);
|
||||||
|
|
||||||
// Creates a context for eager execution of operations.
|
// Creates a context for eager execution of operations.
|
||||||
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
|
TF_ExecutionContext* TF_NewEagerExecutionContext(TFE_ContextOptions*,
|
||||||
TF_Status* s);
|
TF_Status* s);
|
||||||
|
|
||||||
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
void TF_DeleteExecutionContext(TF_ExecutionContext*);
|
||||||
|
|
||||||
|
// Add a new parameter to a TensorFlow Function.
|
||||||
|
// TODO(aminim): what about shape?
|
||||||
|
TF_AbstractTensor* TF_AddFunctionParameter(TF_ExecutionContext* func,
|
||||||
|
TF_DataType dtype, TF_Status* s);
|
||||||
|
|
||||||
// Create an operation suitable to use with the provided context. The operation
|
// Create an operation suitable to use with the provided context. The operation
|
||||||
// requires its type (e.g. "AddV2") to be set independently.
|
// requires its type (e.g. "AddV2") to be set independently.
|
||||||
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
|
TF_AbstractOp* TF_NewAbstractOp(TF_ExecutionContext* ctx);
|
||||||
|
@ -100,13 +111,12 @@ void TF_ExecuteOperation(TF_AbstractOp* op, int num_inputs,
|
||||||
TF_ExecutionContext* ctx, TF_Status* s);
|
TF_ExecutionContext* ctx, TF_Status* s);
|
||||||
|
|
||||||
// Creates a new TF_AbstractFunction from the current tracing states in the
|
// Creates a new TF_AbstractFunction from the current tracing states in the
|
||||||
// context. The returned TF_GraphToFunction must be deleted by the client.
|
// context. The provided `ctx` is consumed by this API call and deleted.
|
||||||
|
// The returned TF_AbstractFunction must be deleted by the client,
|
||||||
// TODO(aminim): clarify the contract on the state of the context after this
|
// TODO(aminim): clarify the contract on the state of the context after this
|
||||||
// call.
|
// call.
|
||||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
TF_AbstractFunction* TF_FinalizeFunction(TF_ExecutionContext* ctx,
|
||||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
TF_OutputList*, TF_Status*);
|
||||||
const TF_AbstractTensor* inputs, int num_outputs,
|
|
||||||
const TF_AbstractTensor* outputs, TF_Status* status);
|
|
||||||
|
|
||||||
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
|
void TF_DeleteAbstractFunction(TF_AbstractFunction*);
|
||||||
|
|
||||||
|
|
|
@ -123,6 +123,17 @@ class EagerContext : public ExecutionContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
|
||||||
|
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||||
|
"Can't add function parameter on an eager context.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
|
||||||
|
TF_SetStatus(s, TF_INVALID_ARGUMENT,
|
||||||
|
"Can't use finalize function on an eager context.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
|
void RegisterFunction(AbstractFunction* afunc, TF_Status* s) override {
|
||||||
auto* func = afunc->GetTfFunction(s);
|
auto* func = afunc->GetTfFunction(s);
|
||||||
if (!func) {
|
if (!func) {
|
||||||
|
|
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
#include "tensorflow/c/c_api.h"
|
#include "tensorflow/c/c_api.h"
|
||||||
#include "tensorflow/c/eager/c_api_internal.h"
|
#include "tensorflow/c/eager/c_api_internal.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
|
@ -114,12 +115,14 @@ struct GraphFunction : public AbstractFunction {
|
||||||
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
static constexpr AbstractFunctionKind kKind = kGraphFunc;
|
||||||
};
|
};
|
||||||
|
|
||||||
// GraphContext wraps a TF_Graph and manages the "execution" of operation, i.e.
|
// GraphContext wraps a TF_Graph modeling a single function and manages the
|
||||||
// adding them to the graph.
|
// "execution" of operation, i.e. adding them to the function.
|
||||||
class GraphContext : public ExecutionContext {
|
class GraphContext : public ExecutionContext {
|
||||||
public:
|
public:
|
||||||
GraphContext()
|
explicit GraphContext(const char* name)
|
||||||
: ExecutionContext(kKind), graph_(new TF_Graph(), TF_DeleteGraph) {}
|
: ExecutionContext(kKind),
|
||||||
|
graph_(new TF_Graph(), TF_DeleteGraph),
|
||||||
|
name_(name) {}
|
||||||
|
|
||||||
AbstractOp* CreateOperation() override {
|
AbstractOp* CreateOperation() override {
|
||||||
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
// TODO(srbs): Should the lifetime of this op be tied to the context.
|
||||||
|
@ -164,24 +167,38 @@ class GraphContext : public ExecutionContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_Function* ToFunction(const char* fn_name, int num_inputs,
|
AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) override {
|
||||||
const GraphTensor* inputs, int num_outputs,
|
TF_OperationDescription* opdesc =
|
||||||
const GraphTensor* outputs, TF_Status* status) const {
|
TF_NewOperation(graph_.get(), "Placeholder",
|
||||||
std::vector<TF_Output> graph_inputs;
|
absl::StrCat("_input_", inputs_.size()).c_str());
|
||||||
graph_inputs.resize(num_inputs);
|
TF_SetAttrType(opdesc, "dtype", dtype);
|
||||||
std::vector<TF_Output> graph_outputs;
|
auto* operation = TF_FinishOperation(opdesc, s);
|
||||||
graph_outputs.resize(num_outputs);
|
if (!s->status.ok()) return nullptr;
|
||||||
for (int i = 0; i < num_inputs; i++) {
|
|
||||||
graph_inputs[i] = inputs[i].output;
|
inputs_.push_back(TF_Output{operation, 0});
|
||||||
}
|
return new GraphTensor(inputs_.back(), this);
|
||||||
for (int i = 0; i < num_outputs; i++) {
|
|
||||||
graph_outputs[i] = outputs[i].output;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return TF_GraphToFunction(graph_.get(), fn_name, 0, -1, nullptr,
|
AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) override {
|
||||||
graph_inputs.size(), graph_inputs.data(),
|
std::unique_ptr<GraphFunction> func(new GraphFunction);
|
||||||
graph_outputs.size(), graph_outputs.data(),
|
std::vector<TF_Output> graph_outputs;
|
||||||
nullptr, nullptr, fn_name, status);
|
graph_outputs.reserve(outputs->outputs.size());
|
||||||
|
for (AbstractTensor* abstract_output : outputs->outputs) {
|
||||||
|
GraphTensor* output = dyncast<GraphTensor>(abstract_output);
|
||||||
|
if (!output) {
|
||||||
|
TF_SetStatus(s, TF_UNIMPLEMENTED,
|
||||||
|
"Returning a non-graph tensor from a function has not "
|
||||||
|
"been implemented yet.");
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
graph_outputs.push_back(output->output);
|
||||||
|
}
|
||||||
|
|
||||||
|
func->func = TF_GraphToFunction(
|
||||||
|
graph_.get(), name_, 0, -1, nullptr, inputs_.size(), inputs_.data(),
|
||||||
|
graph_outputs.size(), graph_outputs.data(), nullptr, nullptr, name_, s);
|
||||||
|
if (TF_GetCode(s) != TF_OK) return nullptr;
|
||||||
|
return func.release();
|
||||||
}
|
}
|
||||||
|
|
||||||
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
void RegisterFunction(AbstractFunction* func, TF_Status* s) override {
|
||||||
|
@ -195,54 +212,20 @@ class GraphContext : public ExecutionContext {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> graph_;
|
||||||
|
std::vector<TF_Output> inputs_;
|
||||||
|
const char* name_;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Helper that converts the graph currently held in the context into a function.
|
static ExecutionContext* GraphTracingFactory(const char* name, TF_Status* s) {
|
||||||
static AbstractFunction* ExecutionContextToFunction(
|
return new GraphContext(name);
|
||||||
const ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
|
||||||
const AbstractTensor* inputs, int num_outputs,
|
|
||||||
const AbstractTensor* outputs, TF_Status* status) {
|
|
||||||
auto* graph_ctx = dyncast<const GraphContext>(fn_body);
|
|
||||||
if (graph_ctx == nullptr) {
|
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT,
|
|
||||||
"fn_body is not a TF_GraphContext.");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto* graph_inputs = dyncast<const GraphTensor>(inputs);
|
|
||||||
if (!graph_inputs) {
|
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "inputs aren't GraphTensors.");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
auto* graph_outputs = dyncast<const GraphTensor>(outputs);
|
|
||||||
if (!graph_outputs) {
|
|
||||||
TF_SetStatus(status, TF_INVALID_ARGUMENT, "outputs aren't GraphTensors.");
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
GraphFunction* func = new GraphFunction;
|
|
||||||
func->func = graph_ctx->ToFunction(fn_name, num_inputs, graph_inputs,
|
|
||||||
num_outputs, graph_outputs, status);
|
|
||||||
return func;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Register the tracing implemented in this file as the default tracing engine.
|
||||||
|
static bool register_tracing = [] {
|
||||||
|
RegisterTracingEngineFactory("graphdef", GraphTracingFactory);
|
||||||
|
SetDefaultTracingEngine("graphdef");
|
||||||
|
return true;
|
||||||
|
}();
|
||||||
|
|
||||||
} // namespace internal
|
} // namespace internal
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
// =============================================================================
|
|
||||||
// Public C API entry points
|
|
||||||
// These are only the entry points specific to the Graph API.
|
|
||||||
// =============================================================================
|
|
||||||
|
|
||||||
using tensorflow::internal::unwrap;
|
|
||||||
|
|
||||||
TF_ExecutionContext* TF_NewGraphExecutionContext(TF_Status* s) {
|
|
||||||
return wrap(new tensorflow::internal::GraphContext());
|
|
||||||
}
|
|
||||||
|
|
||||||
TF_AbstractFunction* TF_ExecutionContextToFunction(
|
|
||||||
const TF_ExecutionContext* fn_body, const char* fn_name, int num_inputs,
|
|
||||||
const TF_AbstractTensor* inputs, int num_outputs,
|
|
||||||
const TF_AbstractTensor* outputs, TF_Status* status) {
|
|
||||||
return wrap(ExecutionContextToFunction(unwrap(fn_body), fn_name, num_inputs,
|
|
||||||
unwrap(inputs), num_outputs,
|
|
||||||
unwrap(outputs), status));
|
|
||||||
}
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||||
#include "tensorflow/c/tf_datatype.h"
|
#include "tensorflow/c/tf_datatype.h"
|
||||||
#include "tensorflow/c/tf_status.h"
|
#include "tensorflow/c/tf_status.h"
|
||||||
#include "tensorflow/core/platform/casts.h"
|
#include "tensorflow/core/platform/casts.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace internal {
|
namespace internal {
|
||||||
|
@ -148,6 +149,17 @@ struct ExecutionContext {
|
||||||
// Creates an empty AbstractOperation suitable to use with this context.
|
// Creates an empty AbstractOperation suitable to use with this context.
|
||||||
virtual AbstractOp* CreateOperation() = 0;
|
virtual AbstractOp* CreateOperation() = 0;
|
||||||
|
|
||||||
|
// Add a function parameter and return the corresponding tensor.
|
||||||
|
// This is only valid with an ExecutionContext obtained from a TracingContext,
|
||||||
|
// it'll always error out with an eager context.
|
||||||
|
virtual AbstractTensor* AddParameter(TF_DataType dtype, TF_Status* s) = 0;
|
||||||
|
|
||||||
|
// Finalize this context and make a function out of it. The context is in a
|
||||||
|
// invalid state after this call and must be destroyed.
|
||||||
|
// This is only valid with an ExecutionContext obtained from a TracingContext,
|
||||||
|
// it'll always error out with an eager context.
|
||||||
|
virtual AbstractFunction* Finalize(OutputList* outputs, TF_Status* s) = 0;
|
||||||
|
|
||||||
// Registers a functions with this context, after this the function is
|
// Registers a functions with this context, after this the function is
|
||||||
// available to be called/referenced by its name in this context.
|
// available to be called/referenced by its name in this context.
|
||||||
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
|
virtual void RegisterFunction(AbstractFunction* func, TF_Status* s) = 0;
|
||||||
|
@ -156,6 +168,11 @@ struct ExecutionContext {
|
||||||
const ExecutionContextKind k;
|
const ExecutionContextKind k;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
typedef ExecutionContext* (*FactoryFunction)(const char* fn_name, TF_Status*);
|
||||||
|
void SetDefaultTracingEngine(const char* name);
|
||||||
|
void RegisterTracingEngineFactory(const ::tensorflow::string& name,
|
||||||
|
FactoryFunction factory);
|
||||||
|
|
||||||
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
|
// Create utilities to wrap/unwrap: this convert from the C opaque types to the
|
||||||
// C++ implementation, and back.
|
// C++ implementation, and back.
|
||||||
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
|
#define MAKE_WRAP_UNWRAP(C_TYPEDEF, CPP_CLASS) \
|
||||||
|
|
|
@ -29,7 +29,12 @@ using tensorflow::string;
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
TEST(UnifiedCAPI, TestBasicEager) {
|
class UnifiedCAPI : public ::testing::TestWithParam<const char*> {
|
||||||
|
protected:
|
||||||
|
void SetUp() override { TF_SetTracingImplementation(GetParam()); }
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_P(UnifiedCAPI, TestBasicEager) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
@ -81,33 +86,18 @@ TEST(UnifiedCAPI, TestBasicEager) {
|
||||||
TF_DeleteExecutionContext(ctx);
|
TF_DeleteExecutionContext(ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnifiedCAPI, TestBasicGraph) {
|
TEST_P(UnifiedCAPI, TestBasicGraph) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
// Start a new function / execution context.
|
||||||
|
string fn_name = "double";
|
||||||
|
TF_ExecutionContext* graph_ctx =
|
||||||
|
TF_CreateFunction(fn_name.c_str(), status.get());
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
|
||||||
// Add a placeholder to the graph.
|
auto* placeholder_t =
|
||||||
auto* placeholder_op = TF_NewAbstractOp(graph_ctx);
|
TF_AddFunctionParameter(graph_ctx, TF_FLOAT, status.get());
|
||||||
TF_AbstractOpSetOpType(placeholder_op, "Placeholder", status.get());
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
TF_AbstractOpSetOpName(placeholder_op, "my_ph", status.get());
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
|
||||||
TF_AbstractOpSetAttrType(placeholder_op, "dtype", TF_FLOAT, status.get());
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
|
||||||
|
|
||||||
// Build inputs and outputs.
|
|
||||||
TF_OutputList* placeholder_outputs = TF_NewOutputList();
|
|
||||||
|
|
||||||
// Execute.
|
|
||||||
TF_ExecuteOperation(placeholder_op, 0, nullptr, placeholder_outputs,
|
|
||||||
graph_ctx, status.get());
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
|
||||||
ASSERT_EQ(1, TF_OutputListNumOutputs(placeholder_outputs));
|
|
||||||
TF_AbstractTensor* placeholder_t = TF_OutputListGet(placeholder_outputs, 0);
|
|
||||||
|
|
||||||
// Delete placeholder op.
|
|
||||||
TF_DeleteAbstractOp(placeholder_op);
|
|
||||||
|
|
||||||
// Build an abstract operation.
|
// Build an abstract operation.
|
||||||
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
auto* add_op = TF_NewAbstractOp(graph_ctx);
|
||||||
|
@ -123,17 +113,13 @@ TEST(UnifiedCAPI, TestBasicGraph) {
|
||||||
// Execute.
|
// Execute.
|
||||||
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
|
TF_ExecuteOperation(add_op, 2, inputs, add_outputs, graph_ctx, status.get());
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
TF_AbstractTensor* output_t = TF_OutputListGet(add_outputs, 0);
|
|
||||||
|
|
||||||
// Clean up operation and inputs.
|
// Clean up operation and inputs.
|
||||||
TF_DeleteAbstractOp(add_op);
|
TF_DeleteAbstractOp(add_op);
|
||||||
|
|
||||||
string fn_name = "double";
|
TF_AbstractFunction* func =
|
||||||
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
|
TF_FinalizeFunction(graph_ctx, add_outputs, status.get());
|
||||||
graph_ctx, fn_name.c_str(), 1, placeholder_t, 1, output_t, status.get());
|
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
TF_DeleteAbstractTensor(placeholder_t);
|
|
||||||
TF_DeleteAbstractTensor(output_t);
|
|
||||||
|
|
||||||
// Build eager context.
|
// Build eager context.
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
@ -174,18 +160,16 @@ TEST(UnifiedCAPI, TestBasicGraph) {
|
||||||
ASSERT_EQ(*f_value, 4.0);
|
ASSERT_EQ(*f_value, 4.0);
|
||||||
|
|
||||||
TF_DeleteOutputList(add_outputs);
|
TF_DeleteOutputList(add_outputs);
|
||||||
TF_DeleteOutputList(placeholder_outputs);
|
|
||||||
TF_DeleteAbstractOp(fn_op);
|
TF_DeleteAbstractOp(fn_op);
|
||||||
TF_DeleteAbstractTensor(input_t);
|
TF_DeleteAbstractTensor(input_t);
|
||||||
TF_DeleteAbstractTensor(final_result);
|
TF_DeleteAbstractTensor(final_result);
|
||||||
TF_DeleteTensor(f_t);
|
TF_DeleteTensor(f_t);
|
||||||
TF_DeleteAbstractFunction(func);
|
TF_DeleteAbstractFunction(func);
|
||||||
|
|
||||||
TF_DeleteExecutionContext(graph_ctx);
|
|
||||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
TEST_P(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||||
|
@ -193,18 +177,15 @@ TEST(UnifiedCAPI, TF_ExecutionContextToFunctionWithEagerContextRaises) {
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
TFE_DeleteContextOptions(opts);
|
TFE_DeleteContextOptions(opts);
|
||||||
|
|
||||||
TF_AbstractFunction* func = TF_ExecutionContextToFunction(
|
TF_AbstractFunction* func = TF_FinalizeFunction(ctx, nullptr, status.get());
|
||||||
ctx, nullptr, 0, nullptr, 0, nullptr, status.get());
|
|
||||||
ASSERT_EQ(nullptr, func);
|
ASSERT_EQ(nullptr, func);
|
||||||
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
ASSERT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(status.get()));
|
||||||
|
|
||||||
TF_DeleteExecutionContext(ctx);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
TEST_P(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
|
||||||
// Add a placeholder to the graph.
|
// Add a placeholder to the graph.
|
||||||
|
@ -222,10 +203,10 @@ TEST(UnifiedCAPI, TF_CallingSetOpTypeAfterFinishingOpBuildingRaises) {
|
||||||
TF_DeleteExecutionContext(graph_ctx);
|
TF_DeleteExecutionContext(graph_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
TEST_P(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
|
||||||
// Add a placeholder to the graph.
|
// Add a placeholder to the graph.
|
||||||
|
@ -243,7 +224,7 @@ TEST(UnifiedCAPI, TF_CallingSetOpNameAfterFinishingOpBuildingRaises) {
|
||||||
TF_DeleteExecutionContext(graph_ctx);
|
TF_DeleteExecutionContext(graph_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
TEST_P(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||||
// Build an Eager context.
|
// Build an Eager context.
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
|
@ -273,7 +254,8 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
|
||||||
// Build a Graph context.
|
// Build a Graph context.
|
||||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
|
||||||
// Execute eager op using graph context.
|
// Execute eager op using graph context.
|
||||||
|
@ -289,10 +271,11 @@ TEST(UnifiedCAPI, TestExecutingEagerOpInGraphModeRaises) {
|
||||||
TF_DeleteExecutionContext(graph_ctx);
|
TF_DeleteExecutionContext(graph_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
TEST_P(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
TF_ExecutionContext* graph_ctx = TF_NewGraphExecutionContext(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
TF_ExecutionContext* graph_ctx = TF_CreateFunction("some_func", status.get());
|
||||||
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get());
|
||||||
|
|
||||||
// Add a placeholder to the graph.
|
// Add a placeholder to the graph.
|
||||||
|
@ -349,5 +332,7 @@ TEST(UnifiedCAPI, TestExecutingGraphOpInEagerModeRaises) {
|
||||||
TF_DeleteExecutionContext(eager_execution_ctx);
|
TF_DeleteExecutionContext(eager_execution_ctx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_SUITE_P(Tracing, UnifiedCAPI, ::testing::Values("graphdef"));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
Loading…
Reference in New Issue